generation_queue.h 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415
  1. #ifndef GENERATION_QUEUE_H
  2. #define GENERATION_QUEUE_H
#include <chrono>
#include <cstdint>
#include <functional>
#include <future>
#include <map>
#include <memory>
#include <string>
#include <vector>
  10. /**
  11. * @brief Job type enumeration
  12. */
  13. enum class JobType {
  14. GENERATION, ///< Image generation job
  15. HASHING, ///< Model hashing job
  16. CONVERSION ///< Model conversion/quantization job
  17. };
  18. /**
  19. * @brief Generation status enumeration
  20. */
  21. enum class GenerationStatus {
  22. QUEUED, ///< Request is queued and waiting to be processed
  23. MODEL_LOADING, ///< Model is being loaded
  24. PROCESSING, ///< Request is currently being processed
  25. COMPLETED, ///< Request completed successfully
  26. FAILED ///< Request failed during processing
  27. };
  28. /**
  29. * @brief Sampling method enumeration (matching stable-diffusion.cpp)
  30. */
  31. enum class SamplingMethod {
  32. EULER,
  33. EULER_A,
  34. HEUN,
  35. DPM2,
  36. DPMPP2S_A,
  37. DPMPP2M,
  38. DPMPP2MV2,
  39. IPNDM,
  40. IPNDM_V,
  41. LCM,
  42. DDIM_TRAILING,
  43. TCD,
  44. DEFAULT ///< Use model default
  45. };
  46. /**
  47. * @brief Scheduler enumeration (matching stable-diffusion.cpp)
  48. */
  49. enum class Scheduler {
  50. DISCRETE,
  51. KARRAS,
  52. EXPONENTIAL,
  53. AYS,
  54. GITS,
  55. SMOOTHSTEP,
  56. SGM_UNIFORM,
  57. SIMPLE,
  58. DEFAULT ///< Use model default
  59. };
  60. /**
  61. * @brief Generation request structure with all stable-diffusion.cpp parameters
  62. */
  63. struct GenerationRequest {
  64. // Basic parameters
  65. std::string id; ///< Unique request ID
  66. std::string modelName; ///< Name of the model to use
  67. std::string prompt; ///< Text prompt for generation
  68. std::string negativePrompt; ///< Negative prompt (optional)
  69. // Image parameters
  70. int width = 512; ///< Image width
  71. int height = 512; ///< Image height
  72. int batchCount = 1; ///< Number of images to generate
  73. // Sampling parameters
  74. int steps = 20; ///< Number of diffusion steps
  75. float cfgScale = 7.5f; ///< CFG scale
  76. SamplingMethod samplingMethod = SamplingMethod::DEFAULT; ///< Sampling method
  77. Scheduler scheduler = Scheduler::DEFAULT; ///< Scheduler
  78. // Seed control
  79. std::string seed = "42"; ///< Seed for generation ("random" for random)
  80. // Model paths (for advanced usage)
  81. std::string clipLPath; ///< Path to CLIP-L model
  82. std::string clipGPath; ///< Path to CLIP-G model
  83. std::string clipVisionPath; ///< Path to CLIP-Vision model
  84. std::string t5xxlPath; ///< Path to T5-XXL model
  85. std::string qwen2vlPath; ///< Path to Qwen2VL model
  86. std::string qwen2vlVisionPath; ///< Path to Qwen2VL Vision model
  87. std::string diffusionModelPath; ///< Path to standalone diffusion model
  88. std::string vaePath; ///< Path to VAE model
  89. std::string taesdPath; ///< Path to TAESD model
  90. std::string controlNetPath; ///< Path to ControlNet model
  91. std::string embeddingDir; ///< Path to embeddings directory
  92. std::string loraModelDir; ///< Path to LoRA model directory
  93. // Advanced parameters
  94. int clipSkip = -1; ///< CLIP skip layers
  95. std::vector<int> skipLayers = {7, 8, 9}; ///< Layers to skip for SLG
  96. float strength = 0.75f; ///< Strength for img2img
  97. float controlStrength = 0.9f; ///< ControlNet strength
  98. // Performance parameters
  99. int nThreads = -1; ///< Number of threads (-1 for auto)
  100. bool offloadParamsToCpu = false; ///< Offload parameters to CPU
  101. bool clipOnCpu = false; ///< Keep CLIP on CPU
  102. bool vaeOnCpu = false; ///< Keep VAE on CPU
  103. bool diffusionFlashAttn = false; ///< Use flash attention
  104. bool diffusionConvDirect = false; ///< Use direct convolution
  105. bool vaeConvDirect = false; ///< Use direct VAE convolution
  106. // Output parameters
  107. std::string outputPath; ///< Output path for generated images
  108. // Image-to-image parameters
  109. std::string initImagePath; ///< Path to init image for img2img (can be file path or base64)
  110. std::vector<uint8_t> initImageData; ///< Init image data (decoded)
  111. int initImageWidth = 0; ///< Init image width
  112. int initImageHeight = 0; ///< Init image height
  113. int initImageChannels = 3; ///< Init image channels
  114. // ControlNet parameters
  115. std::string controlImagePath; ///< Path to control image for ControlNet
  116. std::vector<uint8_t> controlImageData; ///< Control image data (decoded)
  117. int controlImageWidth = 0; ///< Control image width
  118. int controlImageHeight = 0; ///< Control image height
  119. int controlImageChannels = 3; ///< Control image channels
  120. // Upscaler parameters
  121. std::string esrganPath; ///< Path to ESRGAN model for upscaling
  122. uint32_t upscaleFactor = 4; ///< Upscale factor (2 or 4)
  123. // Inpainting parameters
  124. std::string maskImagePath; ///< Path to mask image for inpainting
  125. std::vector<uint8_t> maskImageData; ///< Mask image data (decoded)
  126. int maskImageWidth = 0; ///< Mask image width
  127. int maskImageHeight = 0; ///< Mask image height
  128. int maskImageChannels = 1; ///< Mask image channels (grayscale)
  129. // Request type
  130. enum class RequestType {
  131. TEXT2IMG,
  132. IMG2IMG,
  133. CONTROLNET,
  134. UPSCALER,
  135. INPAINTING
  136. } requestType = RequestType::TEXT2IMG;
  137. // Callback for completion
  138. std::function<void(const std::string&, const std::string&)> callback; ///< Callback for completion
  139. };
  140. /**
  141. * @brief Generation result structure
  142. */
  143. struct GenerationResult {
  144. std::string requestId; ///< ID of the original request
  145. GenerationStatus status; ///< Final status of the generation
  146. bool success; ///< Whether generation was successful
  147. std::vector<std::string> imagePaths; ///< Paths to generated images (multiple for batch)
  148. std::string errorMessage; ///< Error message if generation failed
  149. uint64_t generationTime; ///< Time taken for generation in milliseconds
  150. int64_t actualSeed; ///< Actual seed used for generation
  151. };
  152. /**
  153. * @brief Hash request structure for model hashing jobs
  154. */
  155. struct HashRequest {
  156. std::string id; ///< Unique request ID
  157. std::vector<std::string> modelNames; ///< Model names to hash (empty = hash all unhashed)
  158. bool forceRehash = false; ///< Force rehash even if hash exists
  159. };
  160. /**
  161. * @brief Hash result structure
  162. */
  163. struct HashResult {
  164. std::string requestId; ///< ID of the original request
  165. GenerationStatus status; ///< Final status
  166. bool success; ///< Whether hashing was successful
  167. std::map<std::string, std::string> modelHashes; ///< Map of model names to their hashes
  168. std::string errorMessage; ///< Error message if hashing failed
  169. uint64_t hashingTime; ///< Time taken for hashing in milliseconds
  170. int modelsHashed; ///< Number of models successfully hashed
  171. };
  172. /**
  173. * @brief Conversion request structure for model quantization/conversion jobs
  174. */
  175. struct ConversionRequest {
  176. std::string id; ///< Unique request ID
  177. std::string modelName; ///< Model name to convert
  178. std::string modelPath; ///< Full path to model file
  179. std::string outputPath; ///< Output path for converted model
  180. std::string quantizationType; ///< Quantization type (f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0, q2_K, q3_K, q4_K)
  181. };
  182. /**
  183. * @brief Conversion result structure
  184. */
  185. struct ConversionResult {
  186. std::string requestId; ///< ID of the original request
  187. GenerationStatus status; ///< Final status
  188. bool success; ///< Whether conversion was successful
  189. std::string outputPath; ///< Path to converted model file
  190. std::string errorMessage; ///< Error message if conversion failed
  191. uint64_t conversionTime; ///< Time taken for conversion in milliseconds
  192. std::string originalSize; ///< Original model file size
  193. std::string convertedSize; ///< Converted model file size
  194. };
  195. /**
  196. * @brief Job information for queue status
  197. */
  198. struct JobInfo {
  199. std::string id; ///< Job ID
  200. JobType type; ///< Job type (generation or hashing)
  201. GenerationStatus status; ///< Current status
  202. std::string prompt; ///< Job prompt (full text for generation, or model name for hashing)
  203. std::chrono::system_clock::time_point queuedTime; ///< When job was queued
  204. std::chrono::system_clock::time_point startTime; ///< When job started processing
  205. std::chrono::system_clock::time_point endTime; ///< When job completed/failed
  206. int position; ///< Position in queue (for queued jobs)
  207. std::vector<std::string> outputFiles; ///< Paths to generated output files
  208. std::string errorMessage; ///< Error message if job failed
  209. float progress = 0.0f; ///< Overall progress (0.0 to 1.0) - kept for backward compatibility
  210. float modelLoadProgress = 0.0f; ///< Model loading progress (0.0 to 1.0)
  211. int currentStep = 0; ///< Current step in generation
  212. int totalSteps = 0; ///< Total steps in generation
  213. int64_t timeElapsed = 0; ///< Time elapsed in milliseconds
  214. int64_t timeRemaining = 0; ///< Estimated time remaining in milliseconds
  215. float speed = 0.0f; ///< Generation speed in steps per second
  216. bool firstGenerationCallback = true; ///< Flag to track if this is the first generation callback
  217. // Enhanced fields for repeatable generation
  218. std::string modelName; ///< Name of the model used
  219. std::string modelHash; ///< SHA256 hash of the model
  220. std::string modelPath; ///< Full path to the model file
  221. std::string negativePrompt; ///< Negative prompt used
  222. int width = 512; ///< Image width
  223. int height = 512; ///< Image height
  224. int batchCount = 1; ///< Number of images generated
  225. int steps = 20; ///< Number of diffusion steps
  226. float cfgScale = 7.5f; ///< CFG scale
  227. SamplingMethod samplingMethod = SamplingMethod::DEFAULT; ///< Sampling method used
  228. Scheduler scheduler = Scheduler::DEFAULT; ///< Scheduler used
  229. std::string seed; ///< Seed used for generation
  230. int64_t actualSeed = 0; ///< Actual seed that was used (for random seeds)
  231. std::string requestType; ///< Request type: text2img, img2img, controlnet, upscaler, inpainting
  232. float strength = 0.75f; ///< Strength for img2img
  233. float controlStrength = 0.9f; ///< ControlNet strength
  234. int clipSkip = -1; ///< CLIP skip layers
  235. int nThreads = -1; ///< Number of threads used
  236. bool offloadParamsToCpu = false; ///< Offload parameters to CPU setting
  237. bool clipOnCpu = false; ///< Keep CLIP on CPU setting
  238. bool vaeOnCpu = false; ///< Keep VAE on CPU setting
  239. bool diffusionFlashAttn = false; ///< Use flash attention setting
  240. bool diffusionConvDirect = false; ///< Use direct convolution setting
  241. bool vaeConvDirect = false; ///< Use direct VAE convolution setting
  242. uint64_t generationTime = 0; ///< Total generation time in milliseconds
  243. // Image data for complex operations (base64 encoded)
  244. std::string initImageData; ///< Init image data for img2img (base64)
  245. int initImageWidth = 0; ///< Init image width
  246. int initImageHeight = 0; ///< Init image height
  247. int initImageChannels = 3; ///< Init image channels
  248. std::string controlImageData; ///< Control image data for ControlNet (base64)
  249. int controlImageWidth = 0; ///< Control image width
  250. int controlImageHeight = 0; ///< Control image height
  251. int controlImageChannels = 3; ///< Control image channels
  252. std::string maskImageData; ///< Mask image data for inpainting (base64)
  253. int maskImageWidth = 0; ///< Mask image width
  254. int maskImageHeight = 0; ///< Mask image height
  255. int maskImageChannels = 1; ///< Mask image channels (grayscale)
  256. // Model paths for advanced usage
  257. std::string clipLPath; ///< Path to CLIP-L model
  258. std::string clipGPath; ///< Path to CLIP-G model
  259. std::string vaePath; ///< Path to VAE model
  260. std::string taesdPath; ///< Path to TAESD model
  261. std::string controlNetPath; ///< Path to ControlNet model
  262. std::string embeddingDir; ///< Path to embeddings directory
  263. std::string loraModelDir; ///< Path to LoRA model directory
  264. std::string esrganPath; ///< Path to ESRGAN model for upscaling
  265. uint32_t upscaleFactor = 4; ///< Upscale factor used
  266. };
  267. /**
  268. * @brief Generation queue class for managing image generation requests
  269. *
  270. * This class manages a queue of image generation requests, processes them
  271. * asynchronously, and provides thread-safe access to the queue and results.
  272. * Only one generation job is processed at a time as specified in requirements.
  273. */
  274. class GenerationQueue {
  275. public:
  276. /**
  277. * @brief Construct a new Generation Queue object
  278. *
  279. * @param modelManager Pointer to the model manager
  280. * @param maxConcurrentGenerations Maximum number of concurrent generations (should be 1)
  281. * @param queueDir Directory to store job persistence files
  282. * @param outputDir Directory to store generated output files
  283. */
  284. explicit GenerationQueue(class ModelManager* modelManager, int maxConcurrentGenerations = 1, const std::string& queueDir = "./queue", const std::string& outputDir = "./output");
  285. /**
  286. * @brief Destroy the Generation Queue object
  287. */
  288. virtual ~GenerationQueue();
  289. /**
  290. * @brief Add a generation request to the queue
  291. *
  292. * @param request The generation request
  293. * @return std::future<GenerationResult> Future for the generation result
  294. */
  295. std::future<GenerationResult> enqueueRequest(const GenerationRequest& request);
  296. /**
  297. * @brief Add a hash request to the queue
  298. *
  299. * @param request The hash request
  300. * @return std::future<HashResult> Future for the hash result
  301. */
  302. std::future<HashResult> enqueueHashRequest(const HashRequest& request);
  303. /**
  304. * @brief Add a conversion request to the queue
  305. *
  306. * @param request The conversion request
  307. * @return std::future<ConversionResult> Future for the conversion result
  308. */
  309. std::future<ConversionResult> enqueueConversionRequest(const ConversionRequest& request);
  310. /**
  311. * @brief Get the current queue size
  312. *
  313. * @return size_t Number of requests in the queue
  314. */
  315. size_t getQueueSize() const;
  316. /**
  317. * @brief Get the number of active generations
  318. *
  319. * @return size_t Number of currently processing requests
  320. */
  321. size_t getActiveGenerations() const;
  322. /**
  323. * @brief Get detailed queue status
  324. *
  325. * @return std::vector<JobInfo> List of all jobs with their status
  326. */
  327. std::vector<JobInfo> getQueueStatus() const;
  328. /**
  329. * @brief Get job information by ID
  330. *
  331. * @param jobId The job ID to look up
  332. * @return JobInfo Job information, or empty if not found
  333. */
  334. JobInfo getJobInfo(const std::string& jobId) const;
  335. /**
  336. * @brief Cancel a pending job
  337. *
  338. * @param jobId The job ID to cancel
  339. * @return true if job was cancelled, false if not found or already processing
  340. */
  341. bool cancelJob(const std::string& jobId);
  342. /**
  343. * @brief Clear all pending requests
  344. */
  345. void clearQueue();
  346. /**
  347. * @brief Start the queue processing thread
  348. */
  349. void start();
  350. /**
  351. * @brief Stop the queue processing thread
  352. */
  353. void stop();
  354. /**
  355. * @brief Check if the queue is running
  356. *
  357. * @return true if the queue is running, false otherwise
  358. */
  359. bool isRunning() const;
  360. /**
  361. * @brief Set the maximum number of concurrent generations
  362. *
  363. * @param maxConcurrent Maximum number of concurrent generations
  364. */
  365. void setMaxConcurrentGenerations(int maxConcurrent);
  366. private:
  367. class Impl;
  368. std::unique_ptr<Impl> pImpl; // Pimpl idiom
  369. };
  370. #endif // GENERATION_QUEUE_H