generation_queue.cpp 46 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196
  1. #include "generation_queue.h"
  2. #include "model_manager.h"
  3. #include "stable_diffusion_wrapper.h"
  4. #include "utils.h"
  5. #include <iostream>
  6. #include <random>
  7. #include <sstream>
  8. #include <iomanip>
  9. #include <algorithm>
  10. #include <fstream>
  11. #include <filesystem>
  12. #include <nlohmann/json.hpp>
  13. #define STB_IMAGE_WRITE_IMPLEMENTATION
  14. #include "../stable-diffusion.cpp-src/thirdparty/stb_image_write.h"
  15. #define STB_IMAGE_IMPLEMENTATION
  16. #include "../stable-diffusion.cpp-src/thirdparty/stb_image.h"
class GenerationQueue::Impl {
public:
    // Non-owning pointer to the shared model manager; must outlive this Impl.
    ModelManager* modelManager = nullptr;

    // Thread management: a single worker thread drains requestQueue.
    std::thread workerThread;
    std::atomic<bool> running{false};        // worker loop is alive
    std::atomic<bool> stopRequested{false};  // asks the worker to exit

    // Queue management: requestQueue is guarded by queueMutex and
    // signalled via queueCondition when work is enqueued or stop is requested.
    mutable std::mutex queueMutex;
    std::condition_variable queueCondition;
    std::queue<GenerationRequest> requestQueue;

    // Job tracking: activeJobs and jobPromises are guarded by jobsMutex.
    mutable std::mutex jobsMutex;
    std::unordered_map<std::string, JobInfo> activeJobs;
    std::unordered_map<std::string, std::promise<GenerationResult>> jobPromises;

    // Hash job tracking.
    // NOTE(review): these maps are read/erased in processRequest without a
    // dedicated lock — confirm the submitting side synchronizes access.
    std::map<std::string, std::shared_ptr<std::promise<HashResult>>> hashPromises;
    std::map<std::string, HashRequest> hashRequests;

    // Conversion job tracking (same synchronization caveat as hash jobs).
    std::map<std::string, std::shared_ptr<std::promise<ConversionResult>>> conversionPromises;
    std::map<std::string, ConversionRequest> conversionRequests;

    // Configuration
    int maxConcurrentGenerations = 1;    // the worker processes one job at a time
    std::string queueDir = "./queue";    // persisted job JSON snapshots
    std::string outputDir = "./output";  // generated image files

    // Statistics — atomics so status queries need no lock.
    std::atomic<size_t> queueSize{0};
    std::atomic<size_t> activeGenerations{0};
    std::atomic<uint64_t> totalJobsProcessed{0};
  47. // Worker thread function
  48. void workerThreadFunction() {
  49. std::cout << "GenerationQueue worker thread started" << std::endl;
  50. while (running.load() && !stopRequested.load()) {
  51. std::unique_lock<std::mutex> lock(queueMutex);
  52. // Wait for a request or stop signal
  53. queueCondition.wait(lock, [this] {
  54. return !requestQueue.empty() || stopRequested.load();
  55. });
  56. if (stopRequested.load()) {
  57. break;
  58. }
  59. if (requestQueue.empty()) {
  60. continue;
  61. }
  62. // Get the next request
  63. GenerationRequest request = requestQueue.front();
  64. requestQueue.pop();
  65. queueSize.store(requestQueue.size());
  66. lock.unlock();
  67. // Process the request
  68. processRequest(request);
  69. }
  70. std::cout << "GenerationQueue worker thread stopped" << std::endl;
  71. }
  72. void processRequest(const GenerationRequest& request) {
  73. // Check if this is a hash job
  74. if (request.prompt == "HASH_JOB") {
  75. auto hashIt = hashRequests.find(request.id);
  76. if (hashIt != hashRequests.end()) {
  77. HashResult result = performHashJob(hashIt->second);
  78. auto promiseIt = hashPromises.find(request.id);
  79. if (promiseIt != hashPromises.end()) {
  80. promiseIt->second->set_value(result);
  81. hashPromises.erase(promiseIt);
  82. }
  83. hashRequests.erase(hashIt);
  84. }
  85. return;
  86. }
  87. // Check if this is a conversion job
  88. if (request.prompt == "CONVERSION_JOB") {
  89. auto convIt = conversionRequests.find(request.id);
  90. if (convIt != conversionRequests.end()) {
  91. ConversionResult result = performConversionJob(convIt->second);
  92. auto promiseIt = conversionPromises.find(request.id);
  93. if (promiseIt != conversionPromises.end()) {
  94. promiseIt->second->set_value(result);
  95. conversionPromises.erase(promiseIt);
  96. }
  97. conversionRequests.erase(convIt);
  98. }
  99. return;
  100. }
  101. auto startTime = std::chrono::steady_clock::now();
  102. // Update job status to PROCESSING
  103. {
  104. std::lock_guard<std::mutex> lock(jobsMutex);
  105. if (activeJobs.find(request.id) != activeJobs.end()) {
  106. activeJobs[request.id].status = GenerationStatus::PROCESSING;
  107. activeJobs[request.id].startTime = startTime;
  108. activeJobs[request.id].progress = 0.0f;
  109. activeJobs[request.id].currentStep = 0;
  110. activeJobs[request.id].totalSteps = 0;
  111. activeJobs[request.id].timeElapsed = 0;
  112. activeJobs[request.id].timeRemaining = 0;
  113. activeJobs[request.id].speed = 0.0f;
  114. saveJobToFile(activeJobs[request.id]);
  115. }
  116. }
  117. activeGenerations.store(1); // Only one generation at a time
  118. std::cout << "Processing generation request: " << request.id
  119. << " (prompt: " << request.prompt.substr(0, 50)
  120. << (request.prompt.length() > 50 ? "..." : "") << ")" << std::endl;
  121. // Real generation logic using stable-diffusion.cpp with progress tracking
  122. GenerationResult result = performActualGeneration(request, request.id);
  123. auto endTime = std::chrono::steady_clock::now();
  124. auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(endTime - startTime);
  125. result.generationTime = duration.count();
  126. // Update job status to COMPLETED/FAILED
  127. {
  128. std::lock_guard<std::mutex> lock(jobsMutex);
  129. if (activeJobs.find(request.id) != activeJobs.end()) {
  130. activeJobs[request.id].status = result.success ? GenerationStatus::COMPLETED : GenerationStatus::FAILED;
  131. activeJobs[request.id].endTime = endTime;
  132. // Set final progress to 100% if successful
  133. if (result.success) {
  134. activeJobs[request.id].progress = 1.0f;
  135. if (activeJobs[request.id].totalSteps > 0) {
  136. activeJobs[request.id].currentStep = activeJobs[request.id].totalSteps;
  137. }
  138. }
  139. // Store output files and error message
  140. activeJobs[request.id].outputFiles = result.imagePaths;
  141. activeJobs[request.id].errorMessage = result.errorMessage;
  142. // Persist to disk
  143. saveJobToFile(activeJobs[request.id]);
  144. }
  145. // Set the promise value
  146. auto it = jobPromises.find(request.id);
  147. if (it != jobPromises.end()) {
  148. it->second.set_value(result);
  149. jobPromises.erase(it);
  150. }
  151. }
  152. activeGenerations.store(0);
  153. totalJobsProcessed.fetch_add(1);
  154. std::cout << "Completed generation request: " << request.id
  155. << " (success: " << (result.success ? "true" : "false")
  156. << ", time: " << result.generationTime << "ms)";
  157. if (!result.success && !result.errorMessage.empty()) {
  158. std::cout << " - Error: " << result.errorMessage;
  159. }
  160. std::cout << std::endl;
  161. }
  162. // Progress callback that updates the job info
  163. void updateJobProgress(const std::string& jobId, int step, int totalSteps, float progress, uint64_t timeElapsed) {
  164. std::lock_guard<std::mutex> lock(jobsMutex);
  165. auto it = activeJobs.find(jobId);
  166. if (it != activeJobs.end()) {
  167. it->second.progress = progress;
  168. it->second.currentStep = step;
  169. it->second.totalSteps = totalSteps;
  170. it->second.timeElapsed = static_cast<int64_t>(timeElapsed);
  171. // Calculate time remaining and speed
  172. if (step > 0 && timeElapsed > 0) {
  173. double avgStepTime = static_cast<double>(timeElapsed) / step;
  174. int remainingSteps = totalSteps - step;
  175. it->second.timeRemaining = static_cast<int64_t>(avgStepTime * remainingSteps);
  176. it->second.speed = 1000.0 / avgStepTime; // steps per second
  177. }
  178. // Save progress to file periodically (every 10 steps or on significant progress changes)
  179. if (step % 10 == 0 || progress >= 0.99f) {
  180. saveJobToFile(it->second);
  181. }
  182. }
  183. }
// Runs one generation request synchronously against an already-loaded model
// and returns a populated GenerationResult (success flag, image paths,
// actual seed, error message). Dispatches on request.requestType: txt2img,
// img2img, ControlNet, upscale, or inpainting. Models are deliberately NOT
// auto-loaded here. Wrapper exceptions are caught and reported as a failed
// result rather than propagated.
GenerationResult performActualGeneration(const GenerationRequest& request, const std::string& jobId) {
    GenerationResult result;
    result.requestId = request.id;
    result.success = false;
    // Check if model manager is available
    if (!modelManager) {
        result.errorMessage = "Model manager not available";
        return result;
    }
    // Check if the model is loaded (DO NOT auto-load)
    if (!modelManager->isModelLoaded(request.modelName)) {
        result.errorMessage = "Model not loaded: " + request.modelName + ". Please load the model first using POST /api/models/{hash}/load";
        return result;
    }
    // Get the model wrapper from the shared model manager
    auto* modelWrapper = modelManager->getModel(request.modelName);
    if (!modelWrapper) {
        result.errorMessage = "Model not found or not loaded: " + request.modelName;
        return result;
    }
    // Copy all request fields into the wrapper's parameter struct.
    StableDiffusionWrapper::GenerationParams params;
    params.prompt = request.prompt;
    params.negativePrompt = request.negativePrompt;
    params.width = request.width;
    params.height = request.height;
    params.batchCount = request.batchCount;
    params.steps = request.steps;
    params.cfgScale = request.cfgScale;
    // Enums become the string names stable-diffusion.cpp expects.
    params.samplingMethod = samplingMethodToString(request.samplingMethod);
    params.scheduler = schedulerToString(request.scheduler);
    params.clipSkip = request.clipSkip;
    params.strength = request.strength;
    params.controlStrength = request.controlStrength;
    params.nThreads = request.nThreads;
    params.offloadParamsToCpu = request.offloadParamsToCpu;
    params.clipOnCpu = request.clipOnCpu;
    params.vaeOnCpu = request.vaeOnCpu;
    params.diffusionFlashAttn = request.diffusionFlashAttn;
    params.diffusionConvDirect = request.diffusionConvDirect;
    params.vaeConvDirect = request.vaeConvDirect;
    // Set model paths if provided
    params.modelPath = modelManager->getModelInfo(request.modelName).path;
    params.clipLPath = request.clipLPath;
    params.clipGPath = request.clipGPath;
    params.vaePath = request.vaePath;
    params.taesdPath = request.taesdPath;
    params.controlNetPath = request.controlNetPath;
    params.embeddingDir = request.embeddingDir;
    params.loraModelDir = request.loraModelDir;
    // Parse seed: "random" draws a fresh value; any other string is parsed
    // as an integer, falling back to a fixed default if invalid.
    if (request.seed == "random") {
        std::random_device rd;
        std::mt19937 gen(rd());
        std::uniform_int_distribution<int64_t> dis;
        params.seed = dis(gen);
    } else {
        try {
            params.seed = std::stoll(request.seed);
        } catch (...) {
            params.seed = 42; // Default seed
        }
    }
    // Report the seed actually used so callers can reproduce the output.
    result.actualSeed = params.seed;
    // Generate images based on request type with progress tracking
    try {
        std::vector<StableDiffusionWrapper::GeneratedImage> generatedImages;
        // Progress callback: userData carries a pointer to the generation
        // start time so elapsed milliseconds can be computed per step.
        auto progressCallback = [this, jobId](int step, int totalSteps, float progress, void* userData) {
            // Calculate time elapsed from start time (stored in userData)
            auto startTime = userData ? *static_cast<std::chrono::steady_clock::time_point*>(userData) : std::chrono::steady_clock::now();
            auto currentTime = std::chrono::steady_clock::now();
            uint64_t timeElapsed = std::chrono::duration_cast<std::chrono::milliseconds>(currentTime - startTime).count();
            updateJobProgress(jobId, step, totalSteps, progress, timeElapsed);
        };
        // Store start time to pass as user data. It outlives every wrapper
        // call below, so taking its address is safe.
        auto generationStartTime = std::chrono::steady_clock::now();
        switch (request.requestType) {
        case GenerationRequest::RequestType::TEXT2IMG:
            generatedImages = modelWrapper->generateImage(params, progressCallback, &generationStartTime);
            break;
        case GenerationRequest::RequestType::IMG2IMG:
            if (request.initImageData.empty()) {
                result.errorMessage = "No init image data provided for img2img";
                return result;
            }
            generatedImages = modelWrapper->generateImageImg2Img(
                params,
                request.initImageData,
                request.initImageWidth,
                request.initImageHeight,
                progressCallback,
                &generationStartTime
            );
            break;
        case GenerationRequest::RequestType::CONTROLNET:
            if (request.controlImageData.empty()) {
                result.errorMessage = "No control image data provided for ControlNet";
                return result;
            }
            generatedImages = modelWrapper->generateImageControlNet(
                params,
                request.controlImageData,
                request.controlImageWidth,
                request.controlImageHeight,
                progressCallback,
                &generationStartTime
            );
            break;
        case GenerationRequest::RequestType::UPSCALER:
            if (request.initImageData.empty()) {
                result.errorMessage = "No input image data provided for upscaling";
                return result;
            }
            if (request.esrganPath.empty()) {
                result.errorMessage = "No ESRGAN model path provided for upscaling";
                return result;
            }
            {
                // Upscaling yields a single image and takes no progress callback.
                auto upscaledImage = modelWrapper->upscaleImage(
                    request.esrganPath,
                    request.initImageData,
                    request.initImageWidth,
                    request.initImageHeight,
                    request.initImageChannels,
                    request.upscaleFactor,
                    request.nThreads,
                    request.offloadParamsToCpu,
                    request.diffusionConvDirect
                );
                generatedImages.push_back(upscaledImage);
            }
            break;
        case GenerationRequest::RequestType::INPAINTING:
            if (request.initImageData.empty()) {
                result.errorMessage = "No source image data provided for inpainting";
                return result;
            }
            if (request.maskImageData.empty()) {
                result.errorMessage = "No mask image data provided for inpainting";
                return result;
            }
            generatedImages = modelWrapper->generateImageInpainting(
                params,
                request.initImageData,
                request.initImageWidth,
                request.initImageHeight,
                request.maskImageData,
                request.maskImageWidth,
                request.maskImageHeight,
                progressCallback,
                &generationStartTime
            );
            break;
        default:
            result.errorMessage = "Unknown request type";
            return result;
        }
        if (generatedImages.empty()) {
            result.errorMessage = "Failed to generate images: " + modelWrapper->getLastError();
            return result;
        }
        // Save generated images to files; a single failed save fails the job.
        for (size_t i = 0; i < generatedImages.size(); i++) {
            const auto& image = generatedImages[i];
            std::string imagePath = saveImageToFile(image, request.id, i);
            if (!imagePath.empty()) {
                result.imagePaths.push_back(imagePath);
            } else {
                result.errorMessage = "Failed to save generated image " + std::to_string(i);
                return result;
            }
        }
        result.success = true;
        // NOTE: processRequest overwrites generationTime with the wall-clock
        // duration after this function returns, so this value is provisional.
        result.generationTime = generatedImages.empty() ? 0 : generatedImages[0].generationTime;
        result.errorMessage = "";
    } catch (const std::exception& e) {
        result.errorMessage = "Exception during generation: " + std::string(e.what());
    }
    return result;
}
  365. std::string saveImageToFile(const StableDiffusionWrapper::GeneratedImage& image, const std::string& requestId, size_t index) {
  366. // Create job-specific output directory
  367. std::string jobOutputDir = outputDir + "/" + requestId;
  368. std::error_code ec;
  369. std::filesystem::create_directories(jobOutputDir, ec);
  370. if (ec) {
  371. std::cerr << "Failed to create output directory " << jobOutputDir
  372. << ": " << ec.message() << std::endl;
  373. return "";
  374. }
  375. // Generate filename
  376. std::stringstream ss;
  377. ss << jobOutputDir << "/" << requestId << "_" << index << ".png";
  378. std::string filename = ss.str();
  379. std::cout << "Attempting to save image to: " << filename << std::endl;
  380. // Check if image data is valid
  381. if (image.data.empty() || image.width <= 0 || image.height <= 0) {
  382. std::cerr << "Invalid image data for " << requestId << "_" << index
  383. << ": width=" << image.width
  384. << ", height=" << image.height
  385. << ", channels=" << image.channels
  386. << ", data_size=" << image.data.size() << std::endl;
  387. return "";
  388. }
  389. // Validate image data integrity
  390. const size_t expectedDataSize = static_cast<size_t>(image.width) * image.height * image.channels;
  391. if (image.data.size() != expectedDataSize) {
  392. std::cerr << "Image data size mismatch for " << requestId << "_" << index
  393. << ": expected=" << expectedDataSize
  394. << ", actual=" << image.data.size() << std::endl;
  395. // Continue anyway, but log the warning
  396. }
  397. // Check if we can write to the directory
  398. std::ofstream testFile(filename + ".test");
  399. if (!testFile.is_open()) {
  400. std::cerr << "Cannot write to directory " << jobOutputDir
  401. << ": permission denied or disk full" << std::endl;
  402. return "";
  403. }
  404. testFile.close();
  405. std::filesystem::remove(filename + ".test");
  406. // Write PNG file using stb_image_write with detailed error logging
  407. std::cout << "Writing PNG file: " << filename
  408. << " (size: " << image.width << "x" << image.height
  409. << "x" << image.channels << ")" << std::endl;
  410. int result = stbi_write_png(
  411. filename.c_str(),
  412. image.width,
  413. image.height,
  414. image.channels,
  415. image.data.data(),
  416. image.width * image.channels // stride in bytes
  417. );
  418. if (result == 0) {
  419. std::cerr << "stbi_write_png failed for " << filename << std::endl;
  420. // Try to get more detailed error information
  421. std::cerr << "Image details:" << std::endl;
  422. std::cerr << " Dimensions: " << image.width << "x" << image.height << std::endl;
  423. std::cerr << " Channels: " << image.channels << std::endl;
  424. std::cerr << " Data size: " << image.data.size() << " bytes" << std::endl;
  425. std::cerr << " Expected size: " << expectedDataSize << " bytes" << std::endl;
  426. std::cerr << " Stride: " << (image.width * image.channels) << " bytes" << std::endl;
  427. // Check if file was created but is empty
  428. if (std::filesystem::exists(filename)) {
  429. auto fileSize = std::filesystem::file_size(filename);
  430. std::cerr << " File exists but size is: " << fileSize << " bytes" << std::endl;
  431. if (fileSize == 0) {
  432. std::cerr << " ERROR: Zero-byte file created - stbi_write_png returned false but file exists" << std::endl;
  433. }
  434. } else {
  435. std::cerr << " File was not created" << std::endl;
  436. }
  437. // Check disk space
  438. try {
  439. auto space = std::filesystem::space(jobOutputDir);
  440. std::cerr << " Available disk space: " << (space.available / (1024 * 1024)) << " MB" << std::endl;
  441. } catch (const std::exception& e) {
  442. std::cerr << " Could not check disk space: " << e.what() << std::endl;
  443. }
  444. return "";
  445. }
  446. // Verify the file was created successfully and has content
  447. if (!std::filesystem::exists(filename)) {
  448. std::cerr << "ERROR: stbi_write_png returned success but file does not exist: " << filename << std::endl;
  449. return "";
  450. }
  451. auto fileSize = std::filesystem::file_size(filename);
  452. if (fileSize == 0) {
  453. std::cerr << "ERROR: stbi_write_png returned success but created zero-byte file: " << filename << std::endl;
  454. return "";
  455. }
  456. std::cout << "Successfully saved generated image to: " << filename
  457. << " (" << image.width << "x" << image.height
  458. << ", " << image.channels << " channels, "
  459. << image.data.size() << " data bytes, "
  460. << fileSize << " file bytes)" << std::endl;
  461. return filename;
  462. }
  463. std::string samplingMethodToString(SamplingMethod method) {
  464. switch (method) {
  465. case SamplingMethod::EULER: return "euler";
  466. case SamplingMethod::EULER_A: return "euler_a";
  467. case SamplingMethod::HEUN: return "heun";
  468. case SamplingMethod::DPM2: return "dpm2";
  469. case SamplingMethod::DPMPP2S_A: return "dpmpp2s_a";
  470. case SamplingMethod::DPMPP2M: return "dpmpp2m";
  471. case SamplingMethod::DPMPP2MV2: return "dpmpp2mv2";
  472. case SamplingMethod::IPNDM: return "ipndm";
  473. case SamplingMethod::IPNDM_V: return "ipndm_v";
  474. case SamplingMethod::LCM: return "lcm";
  475. case SamplingMethod::DDIM_TRAILING: return "ddim_trailing";
  476. case SamplingMethod::TCD: return "tcd";
  477. default: return "euler";
  478. }
  479. }
  480. std::string schedulerToString(Scheduler scheduler) {
  481. switch (scheduler) {
  482. case Scheduler::DISCRETE: return "discrete";
  483. case Scheduler::KARRAS: return "karras";
  484. case Scheduler::EXPONENTIAL: return "exponential";
  485. case Scheduler::AYS: return "ays";
  486. case Scheduler::GITS: return "gits";
  487. case Scheduler::SMOOTHSTEP: return "smoothstep";
  488. case Scheduler::SGM_UNIFORM: return "sgm_uniform";
  489. case Scheduler::SIMPLE: return "simple";
  490. default: return "default";
  491. }
  492. }
  493. std::string jobStatusToString(GenerationStatus status) {
  494. switch (status) {
  495. case GenerationStatus::QUEUED: return "queued";
  496. case GenerationStatus::PROCESSING: return "processing";
  497. case GenerationStatus::COMPLETED: return "completed";
  498. case GenerationStatus::FAILED: return "failed";
  499. default: return "unknown";
  500. }
  501. }
  502. GenerationStatus stringToJobStatus(const std::string& status) {
  503. if (status == "queued") return GenerationStatus::QUEUED;
  504. if (status == "processing") return GenerationStatus::PROCESSING;
  505. if (status == "completed") return GenerationStatus::COMPLETED;
  506. if (status == "failed") return GenerationStatus::FAILED;
  507. return GenerationStatus::QUEUED;
  508. }
  509. std::string jobTypeToString(JobType type) {
  510. switch (type) {
  511. case JobType::GENERATION: return "generation";
  512. case JobType::HASHING: return "hashing";
  513. default: return "unknown";
  514. }
  515. }
  516. JobType stringToJobType(const std::string& type) {
  517. if (type == "generation") return JobType::GENERATION;
  518. if (type == "hashing") return JobType::HASHING;
  519. return JobType::GENERATION;
  520. }
  521. void saveJobToFile(const JobInfo& job) {
  522. try {
  523. // Create queue directory if it doesn't exist
  524. std::filesystem::create_directories(queueDir);
  525. // Create JSON object
  526. nlohmann::json jobJson;
  527. jobJson["id"] = job.id;
  528. jobJson["type"] = jobTypeToString(job.type);
  529. jobJson["status"] = jobStatusToString(job.status);
  530. jobJson["prompt"] = job.prompt;
  531. jobJson["position"] = job.position;
  532. // Convert time points to milliseconds since epoch
  533. auto queuedMs = std::chrono::duration_cast<std::chrono::milliseconds>(
  534. job.queuedTime.time_since_epoch()).count();
  535. jobJson["queued_time"] = queuedMs;
  536. if (job.status != GenerationStatus::QUEUED) {
  537. auto startMs = std::chrono::duration_cast<std::chrono::milliseconds>(
  538. job.startTime.time_since_epoch()).count();
  539. jobJson["start_time"] = startMs;
  540. }
  541. if (job.status == GenerationStatus::COMPLETED || job.status == GenerationStatus::FAILED) {
  542. auto endMs = std::chrono::duration_cast<std::chrono::milliseconds>(
  543. job.endTime.time_since_epoch()).count();
  544. jobJson["end_time"] = endMs;
  545. }
  546. jobJson["output_files"] = job.outputFiles;
  547. jobJson["error_message"] = job.errorMessage;
  548. // Write to file
  549. std::string filename = queueDir + "/" + job.id + ".json";
  550. std::ofstream file(filename);
  551. if (file.is_open()) {
  552. file << jobJson.dump(2);
  553. file.close();
  554. }
  555. } catch (const std::exception& e) {
  556. std::cerr << "Error saving job to file: " << e.what() << std::endl;
  557. }
  558. }
// Restores persisted job records (<queueDir>/*.json) into activeJobs at
// startup. Jobs left in PROCESSING by a previous run are marked FAILED,
// since their worker no longer exists, and re-persisted. Per-file parse
// errors are logged and the file skipped; nothing is thrown to the caller.
void loadJobsFromDisk() {
    try {
        if (!std::filesystem::exists(queueDir)) {
            return;  // nothing persisted yet
        }
        std::cout << "Loading persisted jobs from: " << queueDir << std::endl;
        int loadedCount = 0;
        for (const auto& entry : std::filesystem::directory_iterator(queueDir)) {
            if (entry.path().extension() != ".json") {
                continue;
            }
            try {
                std::ifstream file(entry.path());
                if (!file.is_open()) {
                    continue;
                }
                nlohmann::json jobJson = nlohmann::json::parse(file);
                file.close();
                // Reconstruct JobInfo
                JobInfo job;
                job.id = jobJson["id"];
                job.type = stringToJobType(jobJson["type"]);
                job.status = stringToJobStatus(jobJson["status"]);
                job.prompt = jobJson["prompt"];
                job.position = jobJson["position"];
                // Reconstruct time points from stored milliseconds.
                // NOTE(review): the values were saved from steady_clock,
                // whose epoch is not stable across restarts — elapsed-time
                // math on restored jobs may be meaningless; confirm usage.
                auto queuedMs = jobJson["queued_time"].get<int64_t>();
                job.queuedTime = std::chrono::steady_clock::time_point(
                    std::chrono::milliseconds(queuedMs));
                if (jobJson.contains("start_time")) {
                    auto startMs = jobJson["start_time"].get<int64_t>();
                    job.startTime = std::chrono::steady_clock::time_point(
                        std::chrono::milliseconds(startMs));
                }
                if (jobJson.contains("end_time")) {
                    auto endMs = jobJson["end_time"].get<int64_t>();
                    job.endTime = std::chrono::steady_clock::time_point(
                        std::chrono::milliseconds(endMs));
                }
                if (jobJson.contains("output_files")) {
                    job.outputFiles = jobJson["output_files"].get<std::vector<std::string>>();
                }
                if (jobJson.contains("error_message")) {
                    job.errorMessage = jobJson["error_message"];
                }
                // Clean up stale processing jobs from server restart
                if (job.status == GenerationStatus::PROCESSING) {
                    job.status = GenerationStatus::FAILED;
                    job.errorMessage = "Server restarted while job was processing";
                    job.endTime = std::chrono::steady_clock::now();
                    std::cout << "Marked stale job as failed: " << job.id << std::endl;
                    // Persist updated status to disk
                    saveJobToFile(job);
                }
                // Add to active jobs (lock held until the end of this
                // loop iteration's inner try scope).
                std::lock_guard<std::mutex> lock(jobsMutex);
                activeJobs[job.id] = job;
                loadedCount++;
            } catch (const std::exception& e) {
                std::cerr << "Error loading job from " << entry.path() << ": " << e.what() << std::endl;
            }
        }
        if (loadedCount > 0) {
            std::cout << "Loaded " << loadedCount << " persisted job(s)" << std::endl;
        }
    } catch (const std::exception& e) {
        std::cerr << "Error loading jobs from disk: " << e.what() << std::endl;
    }
}
  628. HashResult performHashJob(const HashRequest& request) {
  629. HashResult result;
  630. result.requestId = request.id;
  631. result.success = false;
  632. result.modelsHashed = 0;
  633. auto startTime = std::chrono::steady_clock::now();
  634. if (!modelManager) {
  635. result.errorMessage = "Model manager not available";
  636. result.status = GenerationStatus::FAILED;
  637. return result;
  638. }
  639. // Get list of models to hash
  640. std::vector<std::string> modelsToHash;
  641. if (request.modelNames.empty()) {
  642. // Hash all models without hashes
  643. auto allModels = modelManager->getAllModels();
  644. for (const auto& [name, info] : allModels) {
  645. if (info.sha256.empty() || request.forceRehash) {
  646. modelsToHash.push_back(name);
  647. }
  648. }
  649. } else {
  650. modelsToHash = request.modelNames;
  651. }
  652. std::cout << "Hashing " << modelsToHash.size() << " model(s)..." << std::endl;
  653. // Hash each model
  654. for (const auto& modelName : modelsToHash) {
  655. std::string hash = modelManager->ensureModelHash(modelName, request.forceRehash);
  656. if (!hash.empty()) {
  657. result.modelHashes[modelName] = hash;
  658. result.modelsHashed++;
  659. } else {
  660. std::cerr << "Failed to hash model: " << modelName << std::endl;
  661. }
  662. }
  663. auto endTime = std::chrono::steady_clock::now();
  664. result.hashingTime = std::chrono::duration_cast<std::chrono::milliseconds>(
  665. endTime - startTime).count();
  666. result.success = result.modelsHashed > 0;
  667. result.status = result.success ? GenerationStatus::COMPLETED : GenerationStatus::FAILED;
  668. if (!result.success) {
  669. result.errorMessage = "Failed to hash any models";
  670. }
  671. return result;
  672. }
  673. ConversionResult performConversionJob(const ConversionRequest& request) {
  674. ConversionResult result;
  675. result.requestId = request.id;
  676. result.success = false;
  677. auto startTime = std::chrono::steady_clock::now();
  678. std::cout << "Starting model conversion: " << request.modelName << std::endl;
  679. std::cout << " Input: " << request.modelPath << std::endl;
  680. std::cout << " Output: " << request.outputPath << std::endl;
  681. std::cout << " Quantization: " << request.quantizationType << std::endl;
  682. // Check if input file exists
  683. namespace fs = std::filesystem;
  684. if (!fs::exists(request.modelPath)) {
  685. result.errorMessage = "Input model file not found: " + request.modelPath;
  686. result.status = GenerationStatus::FAILED;
  687. return result;
  688. }
  689. // Get original file size
  690. try {
  691. auto originalSize = fs::file_size(request.modelPath);
  692. result.originalSize = formatFileSize(originalSize);
  693. } catch (const std::exception& e) {
  694. result.originalSize = "Unknown";
  695. }
  696. // Build conversion command
  697. // Get the sd binary path from the CMake installation directory
  698. std::string sdBinaryPath = "../build/stable-diffusion.cpp-install/bin/sd";
  699. std::stringstream cmd;
  700. cmd << sdBinaryPath << " --mode convert";
  701. cmd << " -m \"" << request.modelPath << "\"";
  702. cmd << " -o \"" << request.outputPath << "\"";
  703. cmd << " --type " << request.quantizationType;
  704. cmd << " 2>&1"; // Capture stderr
  705. std::cout << "Executing: " << cmd.str() << std::endl;
  706. // Execute conversion
  707. FILE* pipe = popen(cmd.str().c_str(), "r");
  708. if (!pipe) {
  709. result.errorMessage = "Failed to execute conversion command";
  710. result.status = GenerationStatus::FAILED;
  711. return result;
  712. }
  713. // Read command output
  714. char buffer[256];
  715. std::string output;
  716. while (fgets(buffer, sizeof(buffer), pipe) != nullptr) {
  717. output += buffer;
  718. std::cout << buffer; // Print progress
  719. }
  720. int exitCode = pclose(pipe);
  721. auto endTime = std::chrono::steady_clock::now();
  722. result.conversionTime = std::chrono::duration_cast<std::chrono::milliseconds>(
  723. endTime - startTime).count();
  724. if (exitCode != 0) {
  725. result.errorMessage = "Conversion failed with exit code " + std::to_string(exitCode);
  726. if (!output.empty()) {
  727. result.errorMessage += "\nOutput: " + output;
  728. }
  729. result.status = GenerationStatus::FAILED;
  730. return result;
  731. }
  732. // Check if output file was created
  733. if (!fs::exists(request.outputPath)) {
  734. result.errorMessage = "Output file was not created: " + request.outputPath;
  735. result.status = GenerationStatus::FAILED;
  736. return result;
  737. }
  738. // Get converted file size
  739. try {
  740. auto convertedSize = fs::file_size(request.outputPath);
  741. result.convertedSize = formatFileSize(convertedSize);
  742. } catch (const std::exception& e) {
  743. result.convertedSize = "Unknown";
  744. }
  745. result.success = true;
  746. result.status = GenerationStatus::COMPLETED;
  747. result.outputPath = request.outputPath;
  748. std::cout << "Conversion completed successfully!" << std::endl;
  749. std::cout << " Original size: " << result.originalSize << std::endl;
  750. std::cout << " Converted size: " << result.convertedSize << std::endl;
  751. std::cout << " Time: " << result.conversionTime << "ms" << std::endl;
  752. // Trigger model rescan after successful conversion
  753. if (modelManager) {
  754. std::cout << "Triggering model rescan..." << std::endl;
  755. modelManager->scanModelsDirectory();
  756. }
  757. return result;
  758. }
  759. std::string formatFileSize(size_t bytes) {
  760. const char* units[] = {"B", "KB", "MB", "GB", "TB"};
  761. int unitIndex = 0;
  762. double size = static_cast<double>(bytes);
  763. while (size >= 1024.0 && unitIndex < 4) {
  764. size /= 1024.0;
  765. unitIndex++;
  766. }
  767. std::stringstream ss;
  768. ss << std::fixed << std::setprecision(2) << size << " " << units[unitIndex];
  769. return ss.str();
  770. }
  771. };
  772. GenerationQueue::GenerationQueue(ModelManager* modelManager, int maxConcurrentGenerations,
  773. const std::string& queueDir, const std::string& outputDir)
  774. : pImpl(std::make_unique<Impl>()) {
  775. pImpl->modelManager = modelManager;
  776. pImpl->maxConcurrentGenerations = maxConcurrentGenerations;
  777. pImpl->queueDir = queueDir;
  778. pImpl->outputDir = outputDir;
  779. std::cout << "GenerationQueue initialized" << std::endl;
  780. std::cout << " Max concurrent generations: " << maxConcurrentGenerations << std::endl;
  781. std::cout << " Queue directory: " << queueDir << std::endl;
  782. std::cout << " Output directory: " << outputDir << std::endl;
  783. // Load any existing jobs from disk
  784. pImpl->loadJobsFromDisk();
  785. }
  786. GenerationQueue::~GenerationQueue() {
  787. stop();
  788. }
  789. std::future<GenerationResult> GenerationQueue::enqueueRequest(const GenerationRequest& request) {
  790. std::cout << "Enqueuing generation request: " << request.id << std::endl;
  791. std::cout << " Prompt: " << request.prompt.substr(0, 100)
  792. << (request.prompt.length() > 100 ? "..." : "") << std::endl;
  793. std::cout << " Model: " << request.modelName << std::endl;
  794. std::cout << " Size: " << request.width << "x" << request.height << std::endl;
  795. std::cout << " Steps: " << request.steps << ", CFG: " << request.cfgScale << std::endl;
  796. // Create promise and future
  797. auto promise = std::make_shared<std::promise<GenerationResult>>();
  798. auto future = promise->get_future();
  799. // Store the promise
  800. {
  801. std::lock_guard<std::mutex> lock(pImpl->jobsMutex);
  802. pImpl->jobPromises[request.id] = std::move(*promise);
  803. }
  804. // Add to queue
  805. {
  806. std::lock_guard<std::mutex> lock(pImpl->queueMutex);
  807. // Create job info
  808. JobInfo jobInfo;
  809. jobInfo.id = request.id;
  810. jobInfo.type = JobType::GENERATION;
  811. jobInfo.status = GenerationStatus::QUEUED;
  812. jobInfo.prompt = request.prompt; // Store full prompt
  813. jobInfo.queuedTime = std::chrono::steady_clock::now();
  814. jobInfo.position = pImpl->requestQueue.size() + 1;
  815. // Store job info
  816. {
  817. std::lock_guard<std::mutex> jobsLock(pImpl->jobsMutex);
  818. pImpl->activeJobs[request.id] = jobInfo;
  819. }
  820. // Persist to disk
  821. pImpl->saveJobToFile(jobInfo);
  822. pImpl->requestQueue.push(request);
  823. pImpl->queueSize.store(pImpl->requestQueue.size());
  824. }
  825. // Notify worker thread
  826. pImpl->queueCondition.notify_one();
  827. return future;
  828. }
  829. std::future<HashResult> GenerationQueue::enqueueHashRequest(const HashRequest& request) {
  830. auto promise = std::make_shared<std::promise<HashResult>>();
  831. auto future = promise->get_future();
  832. std::unique_lock<std::mutex> lock(pImpl->queueMutex);
  833. // Create a generation request that acts as a placeholder for hash job
  834. GenerationRequest hashJobPlaceholder;
  835. hashJobPlaceholder.id = request.id;
  836. hashJobPlaceholder.prompt = "HASH_JOB"; // Special marker
  837. hashJobPlaceholder.modelName = request.modelNames.empty() ? "ALL_MODELS" : request.modelNames[0];
  838. // Store promise for retrieval later
  839. pImpl->hashPromises[request.id] = promise;
  840. pImpl->hashRequests[request.id] = request;
  841. pImpl->requestQueue.push(hashJobPlaceholder);
  842. pImpl->queueCondition.notify_one();
  843. std::cout << "Enqueued hash request: " << request.id << std::endl;
  844. return future;
  845. }
  846. std::future<ConversionResult> GenerationQueue::enqueueConversionRequest(const ConversionRequest& request) {
  847. auto promise = std::make_shared<std::promise<ConversionResult>>();
  848. auto future = promise->get_future();
  849. std::unique_lock<std::mutex> lock(pImpl->queueMutex);
  850. // Create a generation request that acts as a placeholder for conversion job
  851. GenerationRequest conversionJobPlaceholder;
  852. conversionJobPlaceholder.id = request.id;
  853. conversionJobPlaceholder.prompt = "CONVERSION_JOB"; // Special marker
  854. conversionJobPlaceholder.modelName = request.modelName;
  855. // Store promise for retrieval later
  856. pImpl->conversionPromises[request.id] = promise;
  857. pImpl->conversionRequests[request.id] = request;
  858. pImpl->requestQueue.push(conversionJobPlaceholder);
  859. pImpl->queueCondition.notify_one();
  860. std::cout << "Enqueued conversion request: " << request.id << " (model: " << request.modelName << ", type: " << request.quantizationType << ")" << std::endl;
  861. return future;
  862. }
  863. size_t GenerationQueue::getQueueSize() const {
  864. return pImpl->queueSize.load();
  865. }
  866. size_t GenerationQueue::getActiveGenerations() const {
  867. return pImpl->activeGenerations.load();
  868. }
  869. std::vector<JobInfo> GenerationQueue::getQueueStatus() const {
  870. std::vector<JobInfo> jobs;
  871. std::lock_guard<std::mutex> lock(pImpl->jobsMutex);
  872. jobs.reserve(pImpl->activeJobs.size());
  873. for (const auto& pair : pImpl->activeJobs) {
  874. jobs.push_back(pair.second);
  875. }
  876. // Sort by queued time, then by status
  877. std::sort(jobs.begin(), jobs.end(), [](const JobInfo& a, const JobInfo& b) {
  878. if (a.status != b.status) {
  879. return static_cast<int>(a.status) < static_cast<int>(b.status);
  880. }
  881. return a.queuedTime < b.queuedTime;
  882. });
  883. return jobs;
  884. }
  885. JobInfo GenerationQueue::getJobInfo(const std::string& jobId) const {
  886. std::lock_guard<std::mutex> lock(pImpl->jobsMutex);
  887. auto it = pImpl->activeJobs.find(jobId);
  888. if (it != pImpl->activeJobs.end()) {
  889. return it->second;
  890. }
  891. return JobInfo{}; // Return empty job info if not found
  892. }
  893. bool GenerationQueue::cancelJob(const std::string& jobId) {
  894. std::lock_guard<std::mutex> queueLock(pImpl->queueMutex);
  895. std::lock_guard<std::mutex> jobsLock(pImpl->jobsMutex);
  896. // Check if job is still queued
  897. std::queue<GenerationRequest> newQueue;
  898. bool found = false;
  899. while (!pImpl->requestQueue.empty()) {
  900. GenerationRequest request = pImpl->requestQueue.front();
  901. pImpl->requestQueue.pop();
  902. if (request.id == jobId) {
  903. found = true;
  904. // Update job status
  905. auto it = pImpl->activeJobs.find(jobId);
  906. if (it != pImpl->activeJobs.end()) {
  907. it->second.status = GenerationStatus::FAILED;
  908. it->second.endTime = std::chrono::steady_clock::now();
  909. }
  910. // Set promise with cancellation error
  911. auto promiseIt = pImpl->jobPromises.find(jobId);
  912. if (promiseIt != pImpl->jobPromises.end()) {
  913. GenerationResult result;
  914. result.requestId = jobId;
  915. result.success = false;
  916. result.errorMessage = "Job cancelled by user";
  917. result.generationTime = 0;
  918. promiseIt->second.set_value(result);
  919. pImpl->jobPromises.erase(promiseIt);
  920. }
  921. } else {
  922. newQueue.push(request);
  923. }
  924. }
  925. pImpl->requestQueue = newQueue;
  926. pImpl->queueSize.store(pImpl->requestQueue.size());
  927. return found;
  928. }
  929. void GenerationQueue::clearQueue() {
  930. std::cout << "Clearing generation queue" << std::endl;
  931. std::lock_guard<std::mutex> queueLock(pImpl->queueMutex);
  932. std::lock_guard<std::mutex> jobsLock(pImpl->jobsMutex);
  933. // Cancel all queued jobs
  934. while (!pImpl->requestQueue.empty()) {
  935. GenerationRequest request = pImpl->requestQueue.front();
  936. pImpl->requestQueue.pop();
  937. // Update job status
  938. auto it = pImpl->activeJobs.find(request.id);
  939. if (it != pImpl->activeJobs.end()) {
  940. it->second.status = GenerationStatus::FAILED;
  941. it->second.endTime = std::chrono::steady_clock::now();
  942. }
  943. // Set promise with cancellation error
  944. auto promiseIt = pImpl->jobPromises.find(request.id);
  945. if (promiseIt != pImpl->jobPromises.end()) {
  946. GenerationResult result;
  947. result.requestId = request.id;
  948. result.success = false;
  949. result.errorMessage = "Queue cleared";
  950. result.generationTime = 0;
  951. promiseIt->second.set_value(result);
  952. pImpl->jobPromises.erase(promiseIt);
  953. }
  954. }
  955. pImpl->queueSize.store(0);
  956. }
  957. void GenerationQueue::start() {
  958. if (pImpl->running.load()) {
  959. std::cout << "GenerationQueue is already running" << std::endl;
  960. return;
  961. }
  962. pImpl->running.store(true);
  963. pImpl->stopRequested.store(false);
  964. pImpl->workerThread = std::thread(&Impl::workerThreadFunction, pImpl.get());
  965. std::cout << "GenerationQueue started" << std::endl;
  966. }
  967. void GenerationQueue::stop() {
  968. if (!pImpl->running.load()) {
  969. return;
  970. }
  971. std::cout << "Stopping GenerationQueue..." << std::endl;
  972. pImpl->stopRequested.store(true);
  973. pImpl->queueCondition.notify_all();
  974. if (pImpl->workerThread.joinable()) {
  975. pImpl->workerThread.join();
  976. }
  977. pImpl->running.store(false);
  978. // Clear any remaining promises
  979. std::lock_guard<std::mutex> lock(pImpl->jobsMutex);
  980. for (auto& pair : pImpl->jobPromises) {
  981. GenerationResult result;
  982. result.requestId = pair.first;
  983. result.success = false;
  984. result.errorMessage = "Queue stopped";
  985. result.generationTime = 0;
  986. pair.second.set_value(result);
  987. }
  988. pImpl->jobPromises.clear();
  989. std::cout << "GenerationQueue stopped" << std::endl;
  990. }
  991. bool GenerationQueue::isRunning() const {
  992. return pImpl->running.load();
  993. }
  994. void GenerationQueue::setMaxConcurrentGenerations(int maxConcurrent) {
  995. pImpl->maxConcurrentGenerations = maxConcurrent;
  996. std::cout << "GenerationQueue max concurrent generations set to: " << maxConcurrent << std::endl;
  997. }