// generation_queue.h
#ifndef GENERATION_QUEUE_H
#define GENERATION_QUEUE_H
#include <atomic>
#include <chrono>
#include <condition_variable>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <future>
#include <map>
#include <memory>
#include <mutex>
#include <queue>
#include <string>
#include <thread>
#include <unordered_map>
#include <vector>
  16. /**
  17. * @brief Job type enumeration
  18. */
  19. enum class JobType {
  20. GENERATION, ///< Image generation job
  21. HASHING, ///< Model hashing job
  22. CONVERSION ///< Model conversion/quantization job
  23. };
  24. /**
  25. * @brief Generation status enumeration
  26. */
  27. enum class GenerationStatus {
  28. QUEUED, ///< Request is queued and waiting to be processed
  29. PROCESSING, ///< Request is currently being processed
  30. COMPLETED, ///< Request completed successfully
  31. FAILED ///< Request failed during processing
  32. };
  33. /**
  34. * @brief Sampling method enumeration (matching stable-diffusion.cpp)
  35. */
  36. enum class SamplingMethod {
  37. EULER,
  38. EULER_A,
  39. HEUN,
  40. DPM2,
  41. DPMPP2S_A,
  42. DPMPP2M,
  43. DPMPP2MV2,
  44. IPNDM,
  45. IPNDM_V,
  46. LCM,
  47. DDIM_TRAILING,
  48. TCD,
  49. DEFAULT ///< Use model default
  50. };
  51. /**
  52. * @brief Scheduler enumeration (matching stable-diffusion.cpp)
  53. */
  54. enum class Scheduler {
  55. DISCRETE,
  56. KARRAS,
  57. EXPONENTIAL,
  58. AYS,
  59. GITS,
  60. SMOOTHSTEP,
  61. SGM_UNIFORM,
  62. SIMPLE,
  63. DEFAULT ///< Use model default
  64. };
  65. /**
  66. * @brief Generation request structure with all stable-diffusion.cpp parameters
  67. */
  68. struct GenerationRequest {
  69. // Basic parameters
  70. std::string id; ///< Unique request ID
  71. std::string modelName; ///< Name of the model to use
  72. std::string prompt; ///< Text prompt for generation
  73. std::string negativePrompt; ///< Negative prompt (optional)
  74. // Image parameters
  75. int width = 512; ///< Image width
  76. int height = 512; ///< Image height
  77. int batchCount = 1; ///< Number of images to generate
  78. // Sampling parameters
  79. int steps = 20; ///< Number of diffusion steps
  80. float cfgScale = 7.5f; ///< CFG scale
  81. SamplingMethod samplingMethod = SamplingMethod::DEFAULT; ///< Sampling method
  82. Scheduler scheduler = Scheduler::DEFAULT; ///< Scheduler
  83. // Seed control
  84. std::string seed = "42"; ///< Seed for generation ("random" for random)
  85. // Model paths (for advanced usage)
  86. std::string clipLPath; ///< Path to CLIP-L model
  87. std::string clipGPath; ///< Path to CLIP-G model
  88. std::string clipVisionPath; ///< Path to CLIP-Vision model
  89. std::string t5xxlPath; ///< Path to T5-XXL model
  90. std::string qwen2vlPath; ///< Path to Qwen2VL model
  91. std::string qwen2vlVisionPath; ///< Path to Qwen2VL Vision model
  92. std::string diffusionModelPath; ///< Path to standalone diffusion model
  93. std::string vaePath; ///< Path to VAE model
  94. std::string taesdPath; ///< Path to TAESD model
  95. std::string controlNetPath; ///< Path to ControlNet model
  96. std::string embeddingDir; ///< Path to embeddings directory
  97. std::string loraModelDir; ///< Path to LoRA model directory
  98. // Advanced parameters
  99. int clipSkip = -1; ///< CLIP skip layers
  100. std::vector<int> skipLayers = {7, 8, 9}; ///< Layers to skip for SLG
  101. float strength = 0.75f; ///< Strength for img2img
  102. float controlStrength = 0.9f; ///< ControlNet strength
  103. // Performance parameters
  104. int nThreads = -1; ///< Number of threads (-1 for auto)
  105. bool offloadParamsToCpu = false; ///< Offload parameters to CPU
  106. bool clipOnCpu = false; ///< Keep CLIP on CPU
  107. bool vaeOnCpu = false; ///< Keep VAE on CPU
  108. bool diffusionFlashAttn = false; ///< Use flash attention
  109. bool diffusionConvDirect = false; ///< Use direct convolution
  110. bool vaeConvDirect = false; ///< Use direct VAE convolution
  111. // Output parameters
  112. std::string outputPath; ///< Output path for generated images
  113. // Image-to-image parameters
  114. std::string initImagePath; ///< Path to init image for img2img (can be file path or base64)
  115. std::vector<uint8_t> initImageData; ///< Init image data (decoded)
  116. int initImageWidth = 0; ///< Init image width
  117. int initImageHeight = 0; ///< Init image height
  118. int initImageChannels = 3; ///< Init image channels
  119. // ControlNet parameters
  120. std::string controlImagePath; ///< Path to control image for ControlNet
  121. std::vector<uint8_t> controlImageData; ///< Control image data (decoded)
  122. int controlImageWidth = 0; ///< Control image width
  123. int controlImageHeight = 0; ///< Control image height
  124. int controlImageChannels = 3; ///< Control image channels
  125. // Upscaler parameters
  126. std::string esrganPath; ///< Path to ESRGAN model for upscaling
  127. uint32_t upscaleFactor = 4; ///< Upscale factor (2 or 4)
  128. // Inpainting parameters
  129. std::string maskImagePath; ///< Path to mask image for inpainting
  130. std::vector<uint8_t> maskImageData; ///< Mask image data (decoded)
  131. int maskImageWidth = 0; ///< Mask image width
  132. int maskImageHeight = 0; ///< Mask image height
  133. int maskImageChannels = 1; ///< Mask image channels (grayscale)
  134. // Request type
  135. enum class RequestType {
  136. TEXT2IMG,
  137. IMG2IMG,
  138. CONTROLNET,
  139. UPSCALER,
  140. INPAINTING
  141. } requestType = RequestType::TEXT2IMG;
  142. // Callback for completion
  143. std::function<void(const std::string&, const std::string&)> callback; ///< Callback for completion
  144. };
  145. /**
  146. * @brief Generation result structure
  147. */
  148. struct GenerationResult {
  149. std::string requestId; ///< ID of the original request
  150. GenerationStatus status; ///< Final status of the generation
  151. bool success; ///< Whether generation was successful
  152. std::vector<std::string> imagePaths; ///< Paths to generated images (multiple for batch)
  153. std::string errorMessage; ///< Error message if generation failed
  154. uint64_t generationTime; ///< Time taken for generation in milliseconds
  155. int64_t actualSeed; ///< Actual seed used for generation
  156. };
  157. /**
  158. * @brief Hash request structure for model hashing jobs
  159. */
  160. struct HashRequest {
  161. std::string id; ///< Unique request ID
  162. std::vector<std::string> modelNames; ///< Model names to hash (empty = hash all unhashed)
  163. bool forceRehash = false; ///< Force rehash even if hash exists
  164. };
  165. /**
  166. * @brief Hash result structure
  167. */
  168. struct HashResult {
  169. std::string requestId; ///< ID of the original request
  170. GenerationStatus status; ///< Final status
  171. bool success; ///< Whether hashing was successful
  172. std::map<std::string, std::string> modelHashes; ///< Map of model names to their hashes
  173. std::string errorMessage; ///< Error message if hashing failed
  174. uint64_t hashingTime; ///< Time taken for hashing in milliseconds
  175. int modelsHashed; ///< Number of models successfully hashed
  176. };
  177. /**
  178. * @brief Conversion request structure for model quantization/conversion jobs
  179. */
  180. struct ConversionRequest {
  181. std::string id; ///< Unique request ID
  182. std::string modelName; ///< Model name to convert
  183. std::string modelPath; ///< Full path to model file
  184. std::string outputPath; ///< Output path for converted model
  185. std::string quantizationType; ///< Quantization type (f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0, q2_K, q3_K, q4_K)
  186. };
  187. /**
  188. * @brief Conversion result structure
  189. */
  190. struct ConversionResult {
  191. std::string requestId; ///< ID of the original request
  192. GenerationStatus status; ///< Final status
  193. bool success; ///< Whether conversion was successful
  194. std::string outputPath; ///< Path to converted model file
  195. std::string errorMessage; ///< Error message if conversion failed
  196. uint64_t conversionTime; ///< Time taken for conversion in milliseconds
  197. std::string originalSize; ///< Original model file size
  198. std::string convertedSize; ///< Converted model file size
  199. };
  200. /**
  201. * @brief Job information for queue status
  202. */
  203. struct JobInfo {
  204. std::string id; ///< Job ID
  205. JobType type; ///< Job type (generation or hashing)
  206. GenerationStatus status; ///< Current status
  207. std::string prompt; ///< Job prompt (full text for generation, or model name for hashing)
  208. std::chrono::steady_clock::time_point queuedTime; ///< When job was queued
  209. std::chrono::steady_clock::time_point startTime; ///< When job started processing
  210. std::chrono::steady_clock::time_point endTime; ///< When job completed/failed
  211. int position; ///< Position in queue (for queued jobs)
  212. std::vector<std::string> outputFiles; ///< Paths to generated output files
  213. std::string errorMessage; ///< Error message if job failed
  214. float progress = 0.0f; ///< Generation progress (0.0 to 1.0)
  215. int currentStep = 0; ///< Current step in generation
  216. int totalSteps = 0; ///< Total steps in generation
  217. int64_t timeElapsed = 0; ///< Time elapsed in milliseconds
  218. int64_t timeRemaining = 0; ///< Estimated time remaining in milliseconds
  219. float speed = 0.0f; ///< Generation speed in steps per second
  220. };
  221. /**
  222. * @brief Generation queue class for managing image generation requests
  223. *
  224. * This class manages a queue of image generation requests, processes them
  225. * asynchronously, and provides thread-safe access to the queue and results.
  226. * Only one generation job is processed at a time as specified in requirements.
  227. */
  228. class GenerationQueue {
  229. public:
  230. /**
  231. * @brief Construct a new Generation Queue object
  232. *
  233. * @param modelManager Pointer to the model manager
  234. * @param maxConcurrentGenerations Maximum number of concurrent generations (should be 1)
  235. * @param queueDir Directory to store job persistence files
  236. * @param outputDir Directory to store generated output files
  237. */
  238. explicit GenerationQueue(class ModelManager* modelManager, int maxConcurrentGenerations = 1,
  239. const std::string& queueDir = "./queue", const std::string& outputDir = "./output");
  240. /**
  241. * @brief Destroy the Generation Queue object
  242. */
  243. virtual ~GenerationQueue();
  244. /**
  245. * @brief Add a generation request to the queue
  246. *
  247. * @param request The generation request
  248. * @return std::future<GenerationResult> Future for the generation result
  249. */
  250. std::future<GenerationResult> enqueueRequest(const GenerationRequest& request);
  251. /**
  252. * @brief Add a hash request to the queue
  253. *
  254. * @param request The hash request
  255. * @return std::future<HashResult> Future for the hash result
  256. */
  257. std::future<HashResult> enqueueHashRequest(const HashRequest& request);
  258. /**
  259. * @brief Add a conversion request to the queue
  260. *
  261. * @param request The conversion request
  262. * @return std::future<ConversionResult> Future for the conversion result
  263. */
  264. std::future<ConversionResult> enqueueConversionRequest(const ConversionRequest& request);
  265. /**
  266. * @brief Get the current queue size
  267. *
  268. * @return size_t Number of requests in the queue
  269. */
  270. size_t getQueueSize() const;
  271. /**
  272. * @brief Get the number of active generations
  273. *
  274. * @return size_t Number of currently processing requests
  275. */
  276. size_t getActiveGenerations() const;
  277. /**
  278. * @brief Get detailed queue status
  279. *
  280. * @return std::vector<JobInfo> List of all jobs with their status
  281. */
  282. std::vector<JobInfo> getQueueStatus() const;
  283. /**
  284. * @brief Get job information by ID
  285. *
  286. * @param jobId The job ID to look up
  287. * @return JobInfo Job information, or empty if not found
  288. */
  289. JobInfo getJobInfo(const std::string& jobId) const;
  290. /**
  291. * @brief Cancel a pending job
  292. *
  293. * @param jobId The job ID to cancel
  294. * @return true if job was cancelled, false if not found or already processing
  295. */
  296. bool cancelJob(const std::string& jobId);
  297. /**
  298. * @brief Clear all pending requests
  299. */
  300. void clearQueue();
  301. /**
  302. * @brief Start the queue processing thread
  303. */
  304. void start();
  305. /**
  306. * @brief Stop the queue processing thread
  307. */
  308. void stop();
  309. /**
  310. * @brief Check if the queue is running
  311. *
  312. * @return true if the queue is running, false otherwise
  313. */
  314. bool isRunning() const;
  315. /**
  316. * @brief Set the maximum number of concurrent generations
  317. *
  318. * @param maxConcurrent Maximum number of concurrent generations
  319. */
  320. void setMaxConcurrentGenerations(int maxConcurrent);
  321. private:
  322. class Impl;
  323. std::unique_ptr<Impl> pImpl; // Pimpl idiom
  324. };
#endif // GENERATION_QUEUE_H