generation_queue.h 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414
  1. #ifndef GENERATION_QUEUE_H
  2. #define GENERATION_QUEUE_H
  #include <atomic>
  #include <chrono>
  #include <condition_variable>
  #include <cstddef>
  #include <cstdint>
  #include <functional>
  #include <future>
  #include <map>
  #include <memory>
  #include <mutex>
  #include <queue>
  #include <string>
  #include <thread>
  #include <unordered_map>
  #include <vector>
  /**
   * @brief Kind of work item that can be placed on the queue.
   *
   * The numeric values are spelled out; they are the same values the
   * enumerators receive implicitly from their declaration order.
   */
  enum class JobType {
      GENERATION = 0,  ///< Image generation job
      HASHING    = 1,  ///< Model hashing job
      CONVERSION = 2   ///< Model conversion/quantization job
  };
  /**
   * @brief Lifecycle state of a queued job.
   *
   * Used by the generation, hashing and conversion result structures as
   * well as by JobInfo. QUEUED is the zero value.
   */
  enum class GenerationStatus {
      QUEUED        = 0,  ///< Request is queued and waiting to be processed
      MODEL_LOADING = 1,  ///< Model is being loaded
      PROCESSING    = 2,  ///< Request is currently being processed
      COMPLETED     = 3,  ///< Request completed successfully
      FAILED        = 4   ///< Request failed during processing
  };
  /**
   * @brief Sampling methods selectable for a generation request.
   *
   * Intended to mirror the sampler list of stable-diffusion.cpp.
   * NOTE(review): the numeric values below are simply the declaration-order
   * values; confirm against the linked stable-diffusion.cpp version before
   * casting to or from its native enum.
   */
  enum class SamplingMethod {
      EULER         = 0,
      EULER_A       = 1,
      HEUN          = 2,
      DPM2          = 3,
      DPMPP2S_A     = 4,
      DPMPP2M       = 5,
      DPMPP2MV2     = 6,
      IPNDM         = 7,
      IPNDM_V       = 8,
      LCM           = 9,
      DDIM_TRAILING = 10,
      TCD           = 11,
      DEFAULT       = 12  ///< Use model default
  };
  /**
   * @brief Schedulers selectable for a generation request.
   *
   * Intended to mirror the scheduler list of stable-diffusion.cpp.
   * NOTE(review): the numeric values below are the declaration-order values;
   * confirm against the linked stable-diffusion.cpp version before casting
   * to or from its native enum.
   */
  enum class Scheduler {
      DISCRETE    = 0,
      KARRAS      = 1,
      EXPONENTIAL = 2,
      AYS         = 3,
      GITS        = 4,
      SMOOTHSTEP  = 5,
      SGM_UNIFORM = 6,
      SIMPLE      = 7,
      DEFAULT     = 8  ///< Use model default
  };
  66. /**
  67. * @brief Generation request structure with all stable-diffusion.cpp parameters
  68. */
  69. struct GenerationRequest {
  70. // Basic parameters
  71. std::string id; ///< Unique request ID
  72. std::string modelName; ///< Name of the model to use
  73. std::string prompt; ///< Text prompt for generation
  74. std::string negativePrompt; ///< Negative prompt (optional)
  75. // Image parameters
  76. int width = 512; ///< Image width
  77. int height = 512; ///< Image height
  78. int batchCount = 1; ///< Number of images to generate
  79. // Sampling parameters
  80. int steps = 20; ///< Number of diffusion steps
  81. float cfgScale = 7.5f; ///< CFG scale
  82. SamplingMethod samplingMethod = SamplingMethod::DEFAULT; ///< Sampling method
  83. Scheduler scheduler = Scheduler::DEFAULT; ///< Scheduler
  84. // Seed control
  85. std::string seed = "42"; ///< Seed for generation ("random" for random)
  86. // Model paths (for advanced usage)
  87. std::string clipLPath; ///< Path to CLIP-L model
  88. std::string clipGPath; ///< Path to CLIP-G model
  89. std::string clipVisionPath; ///< Path to CLIP-Vision model
  90. std::string t5xxlPath; ///< Path to T5-XXL model
  91. std::string qwen2vlPath; ///< Path to Qwen2VL model
  92. std::string qwen2vlVisionPath; ///< Path to Qwen2VL Vision model
  93. std::string diffusionModelPath; ///< Path to standalone diffusion model
  94. std::string vaePath; ///< Path to VAE model
  95. std::string taesdPath; ///< Path to TAESD model
  96. std::string controlNetPath; ///< Path to ControlNet model
  97. std::string embeddingDir; ///< Path to embeddings directory
  98. std::string loraModelDir; ///< Path to LoRA model directory
  99. // Advanced parameters
  100. int clipSkip = -1; ///< CLIP skip layers
  101. std::vector<int> skipLayers = {7, 8, 9}; ///< Layers to skip for SLG
  102. float strength = 0.75f; ///< Strength for img2img
  103. float controlStrength = 0.9f; ///< ControlNet strength
  104. // Performance parameters
  105. int nThreads = -1; ///< Number of threads (-1 for auto)
  106. bool offloadParamsToCpu = false; ///< Offload parameters to CPU
  107. bool clipOnCpu = false; ///< Keep CLIP on CPU
  108. bool vaeOnCpu = false; ///< Keep VAE on CPU
  109. bool diffusionFlashAttn = false; ///< Use flash attention
  110. bool diffusionConvDirect = false; ///< Use direct convolution
  111. bool vaeConvDirect = false; ///< Use direct VAE convolution
  112. // Output parameters
  113. std::string outputPath; ///< Output path for generated images
  114. // Image-to-image parameters
  115. std::string initImagePath; ///< Path to init image for img2img (can be file path or base64)
  116. std::vector<uint8_t> initImageData; ///< Init image data (decoded)
  117. int initImageWidth = 0; ///< Init image width
  118. int initImageHeight = 0; ///< Init image height
  119. int initImageChannels = 3; ///< Init image channels
  120. // ControlNet parameters
  121. std::string controlImagePath; ///< Path to control image for ControlNet
  122. std::vector<uint8_t> controlImageData; ///< Control image data (decoded)
  123. int controlImageWidth = 0; ///< Control image width
  124. int controlImageHeight = 0; ///< Control image height
  125. int controlImageChannels = 3; ///< Control image channels
  126. // Upscaler parameters
  127. std::string esrganPath; ///< Path to ESRGAN model for upscaling
  128. uint32_t upscaleFactor = 4; ///< Upscale factor (2 or 4)
  129. // Inpainting parameters
  130. std::string maskImagePath; ///< Path to mask image for inpainting
  131. std::vector<uint8_t> maskImageData; ///< Mask image data (decoded)
  132. int maskImageWidth = 0; ///< Mask image width
  133. int maskImageHeight = 0; ///< Mask image height
  134. int maskImageChannels = 1; ///< Mask image channels (grayscale)
  135. // Request type
  136. enum class RequestType {
  137. TEXT2IMG,
  138. IMG2IMG,
  139. CONTROLNET,
  140. UPSCALER,
  141. INPAINTING
  142. } requestType = RequestType::TEXT2IMG;
  143. // Callback for completion
  144. std::function<void(const std::string&, const std::string&)> callback; ///< Callback for completion
  145. };
  146. /**
  147. * @brief Generation result structure
  148. */
  149. struct GenerationResult {
  150. std::string requestId; ///< ID of the original request
  151. GenerationStatus status; ///< Final status of the generation
  152. bool success; ///< Whether generation was successful
  153. std::vector<std::string> imagePaths; ///< Paths to generated images (multiple for batch)
  154. std::string errorMessage; ///< Error message if generation failed
  155. uint64_t generationTime; ///< Time taken for generation in milliseconds
  156. int64_t actualSeed; ///< Actual seed used for generation
  157. };
  /**
   * @brief Parameters of a model hashing job.
   */
  struct HashRequest {
      /// Unique request ID
      std::string id;

      /// Model names to hash (empty = hash all unhashed)
      std::vector<std::string> modelNames;

      /// Force rehash even if hash exists
      bool forceRehash{false};
  };
  166. /**
  167. * @brief Hash result structure
  168. */
  169. struct HashResult {
  170. std::string requestId; ///< ID of the original request
  171. GenerationStatus status; ///< Final status
  172. bool success; ///< Whether hashing was successful
  173. std::map<std::string, std::string> modelHashes; ///< Map of model names to their hashes
  174. std::string errorMessage; ///< Error message if hashing failed
  175. uint64_t hashingTime; ///< Time taken for hashing in milliseconds
  176. int modelsHashed; ///< Number of models successfully hashed
  177. };
  /**
   * @brief Parameters of a model conversion/quantization job.
   */
  struct ConversionRequest {
      /// Unique request ID
      std::string id;

      /// Model name to convert
      std::string modelName;

      /// Full path to model file
      std::string modelPath;

      /// Output path for converted model
      std::string outputPath;

      /// Quantization type (f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0, q2_K, q3_K, q4_K)
      std::string quantizationType;
  };
  188. /**
  189. * @brief Conversion result structure
  190. */
  191. struct ConversionResult {
  192. std::string requestId; ///< ID of the original request
  193. GenerationStatus status; ///< Final status
  194. bool success; ///< Whether conversion was successful
  195. std::string outputPath; ///< Path to converted model file
  196. std::string errorMessage; ///< Error message if conversion failed
  197. uint64_t conversionTime; ///< Time taken for conversion in milliseconds
  198. std::string originalSize; ///< Original model file size
  199. std::string convertedSize; ///< Converted model file size
  200. };
  201. /**
  202. * @brief Job information for queue status
  203. */
  204. struct JobInfo {
  205. std::string id; ///< Job ID
  206. JobType type; ///< Job type (generation or hashing)
  207. GenerationStatus status; ///< Current status
  208. std::string prompt; ///< Job prompt (full text for generation, or model name for hashing)
  209. std::chrono::system_clock::time_point queuedTime; ///< When job was queued
  210. std::chrono::system_clock::time_point startTime; ///< When job started processing
  211. std::chrono::system_clock::time_point endTime; ///< When job completed/failed
  212. int position; ///< Position in queue (for queued jobs)
  213. std::vector<std::string> outputFiles; ///< Paths to generated output files
  214. std::string errorMessage; ///< Error message if job failed
  215. float progress = 0.0f; ///< Overall progress (0.0 to 1.0)
  216. float modelLoadProgress = 0.0f; ///< Model loading progress (0.0 to 1.0)
  217. float generationProgress = 0.0f; ///< Generation progress (0.0 to 1.0)
  218. int currentStep = 0; ///< Current step in generation
  219. int totalSteps = 0; ///< Total steps in generation
  220. int64_t timeElapsed = 0; ///< Time elapsed in milliseconds
  221. int64_t timeRemaining = 0; ///< Estimated time remaining in milliseconds
  222. float speed = 0.0f; ///< Generation speed in steps per second
  223. bool firstGenerationCallback = true; ///< Flag to track if this is the first generation callback
  224. // Enhanced fields for repeatable generation
  225. std::string modelName; ///< Name of the model used
  226. std::string modelHash; ///< SHA256 hash of the model
  227. std::string modelPath; ///< Full path to the model file
  228. std::string negativePrompt; ///< Negative prompt used
  229. int width = 512; ///< Image width
  230. int height = 512; ///< Image height
  231. int batchCount = 1; ///< Number of images generated
  232. int steps = 20; ///< Number of diffusion steps
  233. float cfgScale = 7.5f; ///< CFG scale
  234. SamplingMethod samplingMethod = SamplingMethod::DEFAULT; ///< Sampling method used
  235. Scheduler scheduler = Scheduler::DEFAULT; ///< Scheduler used
  236. std::string seed; ///< Seed used for generation
  237. int64_t actualSeed = 0; ///< Actual seed that was used (for random seeds)
  238. std::string requestType; ///< Request type: text2img, img2img, controlnet, upscaler, inpainting
  239. float strength = 0.75f; ///< Strength for img2img
  240. float controlStrength = 0.9f; ///< ControlNet strength
  241. int clipSkip = -1; ///< CLIP skip layers
  242. int nThreads = -1; ///< Number of threads used
  243. bool offloadParamsToCpu = false; ///< Offload parameters to CPU setting
  244. bool clipOnCpu = false; ///< Keep CLIP on CPU setting
  245. bool vaeOnCpu = false; ///< Keep VAE on CPU setting
  246. bool diffusionFlashAttn = false; ///< Use flash attention setting
  247. bool diffusionConvDirect = false; ///< Use direct convolution setting
  248. bool vaeConvDirect = false; ///< Use direct VAE convolution setting
  249. uint64_t generationTime = 0; ///< Total generation time in milliseconds
  250. // Image data for complex operations (base64 encoded)
  251. std::string initImageData; ///< Init image data for img2img (base64)
  252. std::string controlImageData; ///< Control image data for ControlNet (base64)
  253. std::string maskImageData; ///< Mask image data for inpainting (base64)
  254. // Model paths for advanced usage
  255. std::string clipLPath; ///< Path to CLIP-L model
  256. std::string clipGPath; ///< Path to CLIP-G model
  257. std::string vaePath; ///< Path to VAE model
  258. std::string taesdPath; ///< Path to TAESD model
  259. std::string controlNetPath; ///< Path to ControlNet model
  260. std::string embeddingDir; ///< Path to embeddings directory
  261. std::string loraModelDir; ///< Path to LoRA model directory
  262. std::string esrganPath; ///< Path to ESRGAN model for upscaling
  263. uint32_t upscaleFactor = 4; ///< Upscale factor used
  264. };
  265. /**
  266. * @brief Generation queue class for managing image generation requests
  267. *
  268. * This class manages a queue of image generation requests, processes them
  269. * asynchronously, and provides thread-safe access to the queue and results.
  270. * Only one generation job is processed at a time as specified in requirements.
  271. */
  272. class GenerationQueue {
  273. public:
  274. /**
  275. * @brief Construct a new Generation Queue object
  276. *
  277. * @param modelManager Pointer to the model manager
  278. * @param maxConcurrentGenerations Maximum number of concurrent generations (should be 1)
  279. * @param queueDir Directory to store job persistence files
  280. * @param outputDir Directory to store generated output files
  281. */
  282. explicit GenerationQueue(class ModelManager* modelManager, int maxConcurrentGenerations = 1,
  283. const std::string& queueDir = "./queue", const std::string& outputDir = "./output");
  284. /**
  285. * @brief Destroy the Generation Queue object
  286. */
  287. virtual ~GenerationQueue();
  288. /**
  289. * @brief Add a generation request to the queue
  290. *
  291. * @param request The generation request
  292. * @return std::future<GenerationResult> Future for the generation result
  293. */
  294. std::future<GenerationResult> enqueueRequest(const GenerationRequest& request);
  295. /**
  296. * @brief Add a hash request to the queue
  297. *
  298. * @param request The hash request
  299. * @return std::future<HashResult> Future for the hash result
  300. */
  301. std::future<HashResult> enqueueHashRequest(const HashRequest& request);
  302. /**
  303. * @brief Add a conversion request to the queue
  304. *
  305. * @param request The conversion request
  306. * @return std::future<ConversionResult> Future for the conversion result
  307. */
  308. std::future<ConversionResult> enqueueConversionRequest(const ConversionRequest& request);
  309. /**
  310. * @brief Get the current queue size
  311. *
  312. * @return size_t Number of requests in the queue
  313. */
  314. size_t getQueueSize() const;
  315. /**
  316. * @brief Get the number of active generations
  317. *
  318. * @return size_t Number of currently processing requests
  319. */
  320. size_t getActiveGenerations() const;
  321. /**
  322. * @brief Get detailed queue status
  323. *
  324. * @return std::vector<JobInfo> List of all jobs with their status
  325. */
  326. std::vector<JobInfo> getQueueStatus() const;
  327. /**
  328. * @brief Get job information by ID
  329. *
  330. * @param jobId The job ID to look up
  331. * @return JobInfo Job information, or empty if not found
  332. */
  333. JobInfo getJobInfo(const std::string& jobId) const;
  334. /**
  335. * @brief Cancel a pending job
  336. *
  337. * @param jobId The job ID to cancel
  338. * @return true if job was cancelled, false if not found or already processing
  339. */
  340. bool cancelJob(const std::string& jobId);
  341. /**
  342. * @brief Clear all pending requests
  343. */
  344. void clearQueue();
  345. /**
  346. * @brief Start the queue processing thread
  347. */
  348. void start();
  349. /**
  350. * @brief Stop the queue processing thread
  351. */
  352. void stop();
  353. /**
  354. * @brief Check if the queue is running
  355. *
  356. * @return true if the queue is running, false otherwise
  357. */
  358. bool isRunning() const;
  359. /**
  360. * @brief Set the maximum number of concurrent generations
  361. *
  362. * @param maxConcurrent Maximum number of concurrent generations
  363. */
  364. void setMaxConcurrentGenerations(int maxConcurrent);
  365. private:
  366. class Impl;
  367. std::unique_ptr<Impl> pImpl; // Pimpl idiom
  368. };
  369. #endif // GENERATION_QUEUE_H