// generation_queue.h
  #ifndef GENERATION_QUEUE_H
  #define GENERATION_QUEUE_H

  #include <atomic>
  #include <chrono>
  #include <condition_variable>
  #include <cstddef>
  #include <cstdint>
  #include <functional>
  #include <future>
  #include <map>
  #include <memory>
  #include <mutex>
  #include <queue>
  #include <string>
  #include <thread>
  #include <unordered_map>
  #include <vector>
  /**
   * @brief Category of a job tracked by the queue (see JobInfo::type).
   *
   * Enumerator values are written out explicitly; they are identical to the
   * implicit positional values (0, 1, 2).
   */
  enum class JobType {
      GENERATION = 0, ///< Image generation job
      HASHING = 1,    ///< Model hashing job
      CONVERSION = 2  ///< Model conversion/quantization job
  };
  /**
   * @brief Lifecycle state of a job.
   *
   * Shared by generation, hashing and conversion jobs: GenerationResult,
   * HashResult, ConversionResult and JobInfo all store their status with this
   * type. Values are spelled out and equal the implicit positional ones.
   */
  enum class GenerationStatus {
      QUEUED = 0,     ///< Waiting in the queue, not yet picked up
      PROCESSING = 1, ///< Currently being worked on
      COMPLETED = 2,  ///< Finished successfully
      FAILED = 3      ///< Finished with an error
  };
  /**
   * @brief Sampler selection, named after the samplers of stable-diffusion.cpp.
   *
   * The numeric values below are local to this header (they are the implicit
   * positional values, written out). NOTE(review): converting to the library's
   * own enum should be done by name rather than by casting the value —
   * confirm in the implementation.
   */
  enum class SamplingMethod {
      EULER = 0,
      EULER_A = 1,
      HEUN = 2,
      DPM2 = 3,
      DPMPP2S_A = 4,
      DPMPP2M = 5,
      DPMPP2MV2 = 6,
      IPNDM = 7,
      IPNDM_V = 8,
      LCM = 9,
      DDIM_TRAILING = 10,
      TCD = 11,
      DEFAULT = 12 ///< Let the model choose its default sampler
  };
  /**
   * @brief Noise-schedule selection, named after the schedulers of
   *        stable-diffusion.cpp.
   *
   * The numeric values below are local to this header (the implicit
   * positional values, written out). NOTE(review): map to the library's enum
   * by name, not by casting the value — confirm in the implementation.
   */
  enum class Scheduler {
      DISCRETE = 0,
      KARRAS = 1,
      EXPONENTIAL = 2,
      AYS = 3,
      GITS = 4,
      SMOOTHSTEP = 5,
      SGM_UNIFORM = 6,
      SIMPLE = 7,
      DEFAULT = 8 ///< Let the model choose its default scheduler
  };
  65. /**
  66. * @brief Generation request structure with all stable-diffusion.cpp parameters
  67. */
  68. struct GenerationRequest {
  69. // Basic parameters
  70. std::string id; ///< Unique request ID
  71. std::string modelName; ///< Name of the model to use
  72. std::string prompt; ///< Text prompt for generation
  73. std::string negativePrompt; ///< Negative prompt (optional)
  74. // Image parameters
  75. int width = 512; ///< Image width
  76. int height = 512; ///< Image height
  77. int batchCount = 1; ///< Number of images to generate
  78. // Sampling parameters
  79. int steps = 20; ///< Number of diffusion steps
  80. float cfgScale = 7.5f; ///< CFG scale
  81. SamplingMethod samplingMethod = SamplingMethod::DEFAULT; ///< Sampling method
  82. Scheduler scheduler = Scheduler::DEFAULT; ///< Scheduler
  83. // Seed control
  84. std::string seed = "42"; ///< Seed for generation ("random" for random)
  85. // Model paths (for advanced usage)
  86. std::string clipLPath; ///< Path to CLIP-L model
  87. std::string clipGPath; ///< Path to CLIP-G model
  88. std::string clipVisionPath; ///< Path to CLIP-Vision model
  89. std::string t5xxlPath; ///< Path to T5-XXL model
  90. std::string qwen2vlPath; ///< Path to Qwen2VL model
  91. std::string qwen2vlVisionPath; ///< Path to Qwen2VL Vision model
  92. std::string diffusionModelPath; ///< Path to standalone diffusion model
  93. std::string vaePath; ///< Path to VAE model
  94. std::string taesdPath; ///< Path to TAESD model
  95. std::string controlNetPath; ///< Path to ControlNet model
  96. std::string embeddingDir; ///< Path to embeddings directory
  97. std::string loraModelDir; ///< Path to LoRA model directory
  98. // Advanced parameters
  99. int clipSkip = -1; ///< CLIP skip layers
  100. std::vector<int> skipLayers = {7, 8, 9}; ///< Layers to skip for SLG
  101. float strength = 0.75f; ///< Strength for img2img
  102. float controlStrength = 0.9f; ///< ControlNet strength
  103. // Performance parameters
  104. int nThreads = -1; ///< Number of threads (-1 for auto)
  105. bool offloadParamsToCpu = false; ///< Offload parameters to CPU
  106. bool clipOnCpu = false; ///< Keep CLIP on CPU
  107. bool vaeOnCpu = false; ///< Keep VAE on CPU
  108. bool diffusionFlashAttn = false; ///< Use flash attention
  109. bool diffusionConvDirect = false; ///< Use direct convolution
  110. bool vaeConvDirect = false; ///< Use direct VAE convolution
  111. // Output parameters
  112. std::string outputPath; ///< Output path for generated images
  113. // Image-to-image parameters
  114. std::string initImagePath; ///< Path to init image for img2img (can be file path or base64)
  115. std::vector<uint8_t> initImageData; ///< Init image data (decoded)
  116. int initImageWidth = 0; ///< Init image width
  117. int initImageHeight = 0; ///< Init image height
  118. int initImageChannels = 3; ///< Init image channels
  119. // ControlNet parameters
  120. std::string controlImagePath; ///< Path to control image for ControlNet
  121. std::vector<uint8_t> controlImageData; ///< Control image data (decoded)
  122. int controlImageWidth = 0; ///< Control image width
  123. int controlImageHeight = 0; ///< Control image height
  124. int controlImageChannels = 3; ///< Control image channels
  125. // Upscaler parameters
  126. std::string esrganPath; ///< Path to ESRGAN model for upscaling
  127. uint32_t upscaleFactor = 4; ///< Upscale factor (2 or 4)
  128. // Request type
  129. enum class RequestType {
  130. TEXT2IMG,
  131. IMG2IMG,
  132. CONTROLNET,
  133. UPSCALER
  134. } requestType = RequestType::TEXT2IMG;
  135. // Callback for completion
  136. std::function<void(const std::string&, const std::string&)> callback; ///< Callback for completion
  137. };
  138. /**
  139. * @brief Generation result structure
  140. */
  141. struct GenerationResult {
  142. std::string requestId; ///< ID of the original request
  143. GenerationStatus status; ///< Final status of the generation
  144. bool success; ///< Whether generation was successful
  145. std::vector<std::string> imagePaths; ///< Paths to generated images (multiple for batch)
  146. std::string errorMessage; ///< Error message if generation failed
  147. uint64_t generationTime; ///< Time taken for generation in milliseconds
  148. int64_t actualSeed; ///< Actual seed used for generation
  149. };
  /**
   * @brief Request to compute hashes for one or more models.
   *
   * Leaving `modelNames` empty asks for every model that has no hash yet.
   */
  struct HashRequest {
      /// Unique request ID.
      std::string id;
      /// Models to hash; empty means "all models that are not yet hashed".
      std::vector<std::string> modelNames;
      /// Recompute the hash even when one already exists.
      bool forceRehash{false};
  };
  158. /**
  159. * @brief Hash result structure
  160. */
  161. struct HashResult {
  162. std::string requestId; ///< ID of the original request
  163. GenerationStatus status; ///< Final status
  164. bool success; ///< Whether hashing was successful
  165. std::map<std::string, std::string> modelHashes; ///< Map of model names to their hashes
  166. std::string errorMessage; ///< Error message if hashing failed
  167. uint64_t hashingTime; ///< Time taken for hashing in milliseconds
  168. int modelsHashed; ///< Number of models successfully hashed
  169. };
  /**
   * @brief Request to convert/quantize a single model file.
   */
  struct ConversionRequest {
      /// Unique request ID.
      std::string id;
      /// Name of the model to convert.
      std::string modelName;
      /// Full path to the source model file.
      std::string modelPath;
      /// Destination path for the converted model.
      std::string outputPath;
      /// Target quantization type
      /// (f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0, q2_K, q3_K, q4_K).
      std::string quantizationType;
  };
  180. /**
  181. * @brief Conversion result structure
  182. */
  183. struct ConversionResult {
  184. std::string requestId; ///< ID of the original request
  185. GenerationStatus status; ///< Final status
  186. bool success; ///< Whether conversion was successful
  187. std::string outputPath; ///< Path to converted model file
  188. std::string errorMessage; ///< Error message if conversion failed
  189. uint64_t conversionTime; ///< Time taken for conversion in milliseconds
  190. std::string originalSize; ///< Original model file size
  191. std::string convertedSize; ///< Converted model file size
  192. };
  193. /**
  194. * @brief Job information for queue status
  195. */
  196. struct JobInfo {
  197. std::string id; ///< Job ID
  198. JobType type; ///< Job type (generation or hashing)
  199. GenerationStatus status; ///< Current status
  200. std::string prompt; ///< Job prompt (full text for generation, or model name for hashing)
  201. std::chrono::steady_clock::time_point queuedTime; ///< When job was queued
  202. std::chrono::steady_clock::time_point startTime; ///< When job started processing
  203. std::chrono::steady_clock::time_point endTime; ///< When job completed/failed
  204. int position; ///< Position in queue (for queued jobs)
  205. std::vector<std::string> outputFiles; ///< Paths to generated output files
  206. std::string errorMessage; ///< Error message if job failed
  207. float progress = 0.0f; ///< Generation progress (0.0 to 1.0)
  208. };
  209. /**
  210. * @brief Generation queue class for managing image generation requests
  211. *
  212. * This class manages a queue of image generation requests, processes them
  213. * asynchronously, and provides thread-safe access to the queue and results.
  214. * Only one generation job is processed at a time as specified in requirements.
  215. */
  216. class GenerationQueue {
  217. public:
  218. /**
  219. * @brief Construct a new Generation Queue object
  220. *
  221. * @param modelManager Pointer to the model manager
  222. * @param maxConcurrentGenerations Maximum number of concurrent generations (should be 1)
  223. * @param queueDir Directory to store job persistence files
  224. * @param outputDir Directory to store generated output files
  225. */
  226. explicit GenerationQueue(class ModelManager* modelManager, int maxConcurrentGenerations = 1,
  227. const std::string& queueDir = "./queue", const std::string& outputDir = "./output");
  228. /**
  229. * @brief Destroy the Generation Queue object
  230. */
  231. virtual ~GenerationQueue();
  232. /**
  233. * @brief Add a generation request to the queue
  234. *
  235. * @param request The generation request
  236. * @return std::future<GenerationResult> Future for the generation result
  237. */
  238. std::future<GenerationResult> enqueueRequest(const GenerationRequest& request);
  239. /**
  240. * @brief Add a hash request to the queue
  241. *
  242. * @param request The hash request
  243. * @return std::future<HashResult> Future for the hash result
  244. */
  245. std::future<HashResult> enqueueHashRequest(const HashRequest& request);
  246. /**
  247. * @brief Add a conversion request to the queue
  248. *
  249. * @param request The conversion request
  250. * @return std::future<ConversionResult> Future for the conversion result
  251. */
  252. std::future<ConversionResult> enqueueConversionRequest(const ConversionRequest& request);
  253. /**
  254. * @brief Get the current queue size
  255. *
  256. * @return size_t Number of requests in the queue
  257. */
  258. size_t getQueueSize() const;
  259. /**
  260. * @brief Get the number of active generations
  261. *
  262. * @return size_t Number of currently processing requests
  263. */
  264. size_t getActiveGenerations() const;
  265. /**
  266. * @brief Get detailed queue status
  267. *
  268. * @return std::vector<JobInfo> List of all jobs with their status
  269. */
  270. std::vector<JobInfo> getQueueStatus() const;
  271. /**
  272. * @brief Get job information by ID
  273. *
  274. * @param jobId The job ID to look up
  275. * @return JobInfo Job information, or empty if not found
  276. */
  277. JobInfo getJobInfo(const std::string& jobId) const;
  278. /**
  279. * @brief Cancel a pending job
  280. *
  281. * @param jobId The job ID to cancel
  282. * @return true if job was cancelled, false if not found or already processing
  283. */
  284. bool cancelJob(const std::string& jobId);
  285. /**
  286. * @brief Clear all pending requests
  287. */
  288. void clearQueue();
  289. /**
  290. * @brief Start the queue processing thread
  291. */
  292. void start();
  293. /**
  294. * @brief Stop the queue processing thread
  295. */
  296. void stop();
  297. /**
  298. * @brief Check if the queue is running
  299. *
  300. * @return true if the queue is running, false otherwise
  301. */
  302. bool isRunning() const;
  303. /**
  304. * @brief Set the maximum number of concurrent generations
  305. *
  306. * @param maxConcurrent Maximum number of concurrent generations
  307. */
  308. void setMaxConcurrentGenerations(int maxConcurrent);
  309. private:
  310. class Impl;
  311. std::unique_ptr<Impl> pImpl; // Pimpl idiom
  312. };
  313. #endif // GENERATION_QUEUE_H