// generation_queue.cpp
  1. #include "generation_queue.h"
  2. #include "model_manager.h"
  3. #include "stable_diffusion_wrapper.h"
  4. #include "utils.h"
  5. #include <iostream>
  6. #include <random>
  7. #include <sstream>
  8. #include <iomanip>
  9. #include <algorithm>
  10. #include <fstream>
  11. #include <filesystem>
  12. #include <nlohmann/json.hpp>
  13. #define STB_IMAGE_WRITE_IMPLEMENTATION
  14. #include "../stable-diffusion.cpp-src/thirdparty/stb_image_write.h"
  15. #define STB_IMAGE_IMPLEMENTATION
  16. #include "../stable-diffusion.cpp-src/thirdparty/stb_image.h"
  17. class GenerationQueue::Impl {
  18. public:
  19. // Model manager reference
  20. ModelManager* modelManager = nullptr;
  21. // Thread management
  22. std::thread workerThread;
  23. std::atomic<bool> running{false};
  24. std::atomic<bool> stopRequested{false};
  25. // Queue management
  26. mutable std::mutex queueMutex;
  27. std::condition_variable queueCondition;
  28. std::queue<GenerationRequest> requestQueue;
  29. // Job tracking
  30. mutable std::mutex jobsMutex;
  31. std::unordered_map<std::string, JobInfo> activeJobs;
  32. std::unordered_map<std::string, std::promise<GenerationResult>> jobPromises;
  33. // Hash job tracking
  34. std::map<std::string, std::shared_ptr<std::promise<HashResult>>> hashPromises;
  35. std::map<std::string, HashRequest> hashRequests;
  36. // Conversion job tracking
  37. std::map<std::string, std::shared_ptr<std::promise<ConversionResult>>> conversionPromises;
  38. std::map<std::string, ConversionRequest> conversionRequests;
  39. // Configuration
  40. int maxConcurrentGenerations = 1;
  41. std::string queueDir = "./queue";
  42. std::string outputDir = "./output";
  43. // Statistics
  44. std::atomic<size_t> queueSize{0};
  45. std::atomic<size_t> activeGenerations{0};
  46. std::atomic<uint64_t> totalJobsProcessed{0};
  47. // Worker thread function
  48. void workerThreadFunction() {
  49. std::cout << "GenerationQueue worker thread started" << std::endl;
  50. while (running.load() && !stopRequested.load()) {
  51. std::unique_lock<std::mutex> lock(queueMutex);
  52. // Wait for a request or stop signal
  53. queueCondition.wait(lock, [this] {
  54. return !requestQueue.empty() || stopRequested.load();
  55. });
  56. if (stopRequested.load()) {
  57. break;
  58. }
  59. if (requestQueue.empty()) {
  60. continue;
  61. }
  62. // Get the next request
  63. GenerationRequest request = requestQueue.front();
  64. requestQueue.pop();
  65. queueSize.store(requestQueue.size());
  66. lock.unlock();
  67. // Process the request
  68. processRequest(request);
  69. }
  70. std::cout << "GenerationQueue worker thread stopped" << std::endl;
  71. }
  72. void processRequest(const GenerationRequest& request) {
  73. // Check if this is a hash job
  74. if (request.prompt == "HASH_JOB") {
  75. auto hashIt = hashRequests.find(request.id);
  76. if (hashIt != hashRequests.end()) {
  77. HashResult result = performHashJob(hashIt->second);
  78. auto promiseIt = hashPromises.find(request.id);
  79. if (promiseIt != hashPromises.end()) {
  80. promiseIt->second->set_value(result);
  81. hashPromises.erase(promiseIt);
  82. }
  83. hashRequests.erase(hashIt);
  84. }
  85. return;
  86. }
  87. // Check if this is a conversion job
  88. if (request.prompt == "CONVERSION_JOB") {
  89. auto convIt = conversionRequests.find(request.id);
  90. if (convIt != conversionRequests.end()) {
  91. ConversionResult result = performConversionJob(convIt->second);
  92. auto promiseIt = conversionPromises.find(request.id);
  93. if (promiseIt != conversionPromises.end()) {
  94. promiseIt->second->set_value(result);
  95. conversionPromises.erase(promiseIt);
  96. }
  97. conversionRequests.erase(convIt);
  98. }
  99. return;
  100. }
  101. auto startTime = std::chrono::steady_clock::now();
  102. // Update job status to PROCESSING
  103. {
  104. std::lock_guard<std::mutex> lock(jobsMutex);
  105. if (activeJobs.find(request.id) != activeJobs.end()) {
  106. activeJobs[request.id].status = GenerationStatus::PROCESSING;
  107. activeJobs[request.id].startTime = startTime;
  108. saveJobToFile(activeJobs[request.id]);
  109. }
  110. }
  111. activeGenerations.store(1); // Only one generation at a time
  112. std::cout << "Processing generation request: " << request.id
  113. << " (prompt: " << request.prompt.substr(0, 50)
  114. << (request.prompt.length() > 50 ? "..." : "") << ")" << std::endl;
  115. // Real generation logic using stable-diffusion.cpp
  116. GenerationResult result = performActualGeneration(request);
  117. auto endTime = std::chrono::steady_clock::now();
  118. auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(endTime - startTime);
  119. result.generationTime = duration.count();
  120. // Update job status to COMPLETED/FAILED
  121. {
  122. std::lock_guard<std::mutex> lock(jobsMutex);
  123. if (activeJobs.find(request.id) != activeJobs.end()) {
  124. activeJobs[request.id].status = result.success ? GenerationStatus::COMPLETED : GenerationStatus::FAILED;
  125. activeJobs[request.id].endTime = endTime;
  126. // Store output files and error message
  127. activeJobs[request.id].outputFiles = result.imagePaths;
  128. activeJobs[request.id].errorMessage = result.errorMessage;
  129. // Persist to disk
  130. saveJobToFile(activeJobs[request.id]);
  131. }
  132. // Set the promise value
  133. auto it = jobPromises.find(request.id);
  134. if (it != jobPromises.end()) {
  135. it->second.set_value(result);
  136. jobPromises.erase(it);
  137. }
  138. }
  139. activeGenerations.store(0);
  140. totalJobsProcessed.fetch_add(1);
  141. std::cout << "Completed generation request: " << request.id
  142. << " (success: " << (result.success ? "true" : "false")
  143. << ", time: " << result.generationTime << "ms)";
  144. if (!result.success && !result.errorMessage.empty()) {
  145. std::cout << " - Error: " << result.errorMessage;
  146. }
  147. std::cout << std::endl;
  148. }
  149. GenerationResult performActualGeneration(const GenerationRequest& request) {
  150. GenerationResult result;
  151. result.requestId = request.id;
  152. result.success = false;
  153. // Check if model manager is available
  154. if (!modelManager) {
  155. result.errorMessage = "Model manager not available";
  156. return result;
  157. }
  158. // Check if the model is loaded (DO NOT auto-load)
  159. if (!modelManager->isModelLoaded(request.modelName)) {
  160. result.errorMessage = "Model not loaded: " + request.modelName + ". Please load the model first using POST /api/models/{hash}/load";
  161. return result;
  162. }
  163. // Get the model wrapper from the shared model manager
  164. auto* modelWrapper = modelManager->getModel(request.modelName);
  165. if (!modelWrapper) {
  166. result.errorMessage = "Model not found or not loaded: " + request.modelName;
  167. return result;
  168. }
  169. // Prepare generation parameters
  170. StableDiffusionWrapper::GenerationParams params;
  171. params.prompt = request.prompt;
  172. params.negativePrompt = request.negativePrompt;
  173. params.width = request.width;
  174. params.height = request.height;
  175. params.batchCount = request.batchCount;
  176. params.steps = request.steps;
  177. params.cfgScale = request.cfgScale;
  178. params.samplingMethod = samplingMethodToString(request.samplingMethod);
  179. params.scheduler = schedulerToString(request.scheduler);
  180. params.clipSkip = request.clipSkip;
  181. params.strength = request.strength;
  182. params.controlStrength = request.controlStrength;
  183. params.nThreads = request.nThreads;
  184. params.offloadParamsToCpu = request.offloadParamsToCpu;
  185. params.clipOnCpu = request.clipOnCpu;
  186. params.vaeOnCpu = request.vaeOnCpu;
  187. params.diffusionFlashAttn = request.diffusionFlashAttn;
  188. params.diffusionConvDirect = request.diffusionConvDirect;
  189. params.vaeConvDirect = request.vaeConvDirect;
  190. // Set model paths if provided
  191. params.modelPath = modelManager->getModelInfo(request.modelName).path;
  192. params.clipLPath = request.clipLPath;
  193. params.clipGPath = request.clipGPath;
  194. params.vaePath = request.vaePath;
  195. params.taesdPath = request.taesdPath;
  196. params.controlNetPath = request.controlNetPath;
  197. params.embeddingDir = request.embeddingDir;
  198. params.loraModelDir = request.loraModelDir;
  199. // Parse seed
  200. if (request.seed == "random") {
  201. std::random_device rd;
  202. std::mt19937 gen(rd());
  203. std::uniform_int_distribution<int64_t> dis;
  204. params.seed = dis(gen);
  205. } else {
  206. try {
  207. params.seed = std::stoll(request.seed);
  208. } catch (...) {
  209. params.seed = 42; // Default seed
  210. }
  211. }
  212. result.actualSeed = params.seed;
  213. // Generate images based on request type
  214. try {
  215. std::vector<StableDiffusionWrapper::GeneratedImage> generatedImages;
  216. switch (request.requestType) {
  217. case GenerationRequest::RequestType::TEXT2IMG:
  218. generatedImages = modelWrapper->generateImage(params);
  219. break;
  220. case GenerationRequest::RequestType::IMG2IMG:
  221. if (request.initImageData.empty()) {
  222. result.errorMessage = "No init image data provided for img2img";
  223. return result;
  224. }
  225. generatedImages = modelWrapper->generateImageImg2Img(
  226. params,
  227. request.initImageData,
  228. request.initImageWidth,
  229. request.initImageHeight
  230. );
  231. break;
  232. case GenerationRequest::RequestType::CONTROLNET:
  233. if (request.controlImageData.empty()) {
  234. result.errorMessage = "No control image data provided for ControlNet";
  235. return result;
  236. }
  237. generatedImages = modelWrapper->generateImageControlNet(
  238. params,
  239. request.controlImageData,
  240. request.controlImageWidth,
  241. request.controlImageHeight
  242. );
  243. break;
  244. case GenerationRequest::RequestType::UPSCALER:
  245. if (request.initImageData.empty()) {
  246. result.errorMessage = "No input image data provided for upscaling";
  247. return result;
  248. }
  249. if (request.esrganPath.empty()) {
  250. result.errorMessage = "No ESRGAN model path provided for upscaling";
  251. return result;
  252. }
  253. {
  254. auto upscaledImage = modelWrapper->upscaleImage(
  255. request.esrganPath,
  256. request.initImageData,
  257. request.initImageWidth,
  258. request.initImageHeight,
  259. request.initImageChannels,
  260. request.upscaleFactor,
  261. request.nThreads,
  262. request.offloadParamsToCpu,
  263. request.diffusionConvDirect
  264. );
  265. generatedImages.push_back(upscaledImage);
  266. }
  267. break;
  268. default:
  269. result.errorMessage = "Unknown request type";
  270. return result;
  271. }
  272. if (generatedImages.empty()) {
  273. result.errorMessage = "Failed to generate images: " + modelWrapper->getLastError();
  274. return result;
  275. }
  276. // Save generated images to files
  277. for (size_t i = 0; i < generatedImages.size(); i++) {
  278. const auto& image = generatedImages[i];
  279. std::string imagePath = saveImageToFile(image, request.id, i);
  280. if (!imagePath.empty()) {
  281. result.imagePaths.push_back(imagePath);
  282. } else {
  283. result.errorMessage = "Failed to save generated image " + std::to_string(i);
  284. return result;
  285. }
  286. }
  287. result.success = true;
  288. result.generationTime = generatedImages.empty() ? 0 : generatedImages[0].generationTime;
  289. result.errorMessage = "";
  290. } catch (const std::exception& e) {
  291. result.errorMessage = "Exception during generation: " + std::string(e.what());
  292. }
  293. return result;
  294. }
  295. std::string saveImageToFile(const StableDiffusionWrapper::GeneratedImage& image, const std::string& requestId, size_t index) {
  296. // Create job-specific output directory
  297. std::string jobOutputDir = outputDir + "/" + requestId;
  298. std::filesystem::create_directories(jobOutputDir);
  299. // Generate filename
  300. std::stringstream ss;
  301. ss << jobOutputDir << "/" << requestId << "_" << index << ".png";
  302. std::string filename = ss.str();
  303. // Check if image data is valid
  304. if (image.data.empty() || image.width <= 0 || image.height <= 0) {
  305. std::cerr << "Invalid image data: width=" << image.width
  306. << ", height=" << image.height
  307. << ", data_size=" << image.data.size() << std::endl;
  308. return "";
  309. }
  310. // Write PNG file using stb_image_write
  311. int result = stbi_write_png(
  312. filename.c_str(),
  313. image.width,
  314. image.height,
  315. image.channels,
  316. image.data.data(),
  317. image.width * image.channels // stride in bytes
  318. );
  319. if (result == 0) {
  320. std::cerr << "Failed to write PNG file: " << filename << std::endl;
  321. return "";
  322. }
  323. std::cout << "Saved generated image to: " << filename
  324. << " (" << image.width << "x" << image.height
  325. << ", " << image.channels << " channels, "
  326. << image.data.size() << " bytes)" << std::endl;
  327. return filename;
  328. }
  329. std::string samplingMethodToString(SamplingMethod method) {
  330. switch (method) {
  331. case SamplingMethod::EULER: return "euler";
  332. case SamplingMethod::EULER_A: return "euler_a";
  333. case SamplingMethod::HEUN: return "heun";
  334. case SamplingMethod::DPM2: return "dpm2";
  335. case SamplingMethod::DPMPP2S_A: return "dpmpp2s_a";
  336. case SamplingMethod::DPMPP2M: return "dpmpp2m";
  337. case SamplingMethod::DPMPP2MV2: return "dpmpp2mv2";
  338. case SamplingMethod::IPNDM: return "ipndm";
  339. case SamplingMethod::IPNDM_V: return "ipndm_v";
  340. case SamplingMethod::LCM: return "lcm";
  341. case SamplingMethod::DDIM_TRAILING: return "ddim_trailing";
  342. case SamplingMethod::TCD: return "tcd";
  343. default: return "euler";
  344. }
  345. }
  346. std::string schedulerToString(Scheduler scheduler) {
  347. switch (scheduler) {
  348. case Scheduler::DISCRETE: return "discrete";
  349. case Scheduler::KARRAS: return "karras";
  350. case Scheduler::EXPONENTIAL: return "exponential";
  351. case Scheduler::AYS: return "ays";
  352. case Scheduler::GITS: return "gits";
  353. case Scheduler::SMOOTHSTEP: return "smoothstep";
  354. case Scheduler::SGM_UNIFORM: return "sgm_uniform";
  355. case Scheduler::SIMPLE: return "simple";
  356. default: return "default";
  357. }
  358. }
  359. std::string jobStatusToString(GenerationStatus status) {
  360. switch (status) {
  361. case GenerationStatus::QUEUED: return "queued";
  362. case GenerationStatus::PROCESSING: return "processing";
  363. case GenerationStatus::COMPLETED: return "completed";
  364. case GenerationStatus::FAILED: return "failed";
  365. default: return "unknown";
  366. }
  367. }
  368. GenerationStatus stringToJobStatus(const std::string& status) {
  369. if (status == "queued") return GenerationStatus::QUEUED;
  370. if (status == "processing") return GenerationStatus::PROCESSING;
  371. if (status == "completed") return GenerationStatus::COMPLETED;
  372. if (status == "failed") return GenerationStatus::FAILED;
  373. return GenerationStatus::QUEUED;
  374. }
  375. std::string jobTypeToString(JobType type) {
  376. switch (type) {
  377. case JobType::GENERATION: return "generation";
  378. case JobType::HASHING: return "hashing";
  379. default: return "unknown";
  380. }
  381. }
  382. JobType stringToJobType(const std::string& type) {
  383. if (type == "generation") return JobType::GENERATION;
  384. if (type == "hashing") return JobType::HASHING;
  385. return JobType::GENERATION;
  386. }
  387. void saveJobToFile(const JobInfo& job) {
  388. try {
  389. // Create queue directory if it doesn't exist
  390. std::filesystem::create_directories(queueDir);
  391. // Create JSON object
  392. nlohmann::json jobJson;
  393. jobJson["id"] = job.id;
  394. jobJson["type"] = jobTypeToString(job.type);
  395. jobJson["status"] = jobStatusToString(job.status);
  396. jobJson["prompt"] = job.prompt;
  397. jobJson["position"] = job.position;
  398. // Convert time points to milliseconds since epoch
  399. auto queuedMs = std::chrono::duration_cast<std::chrono::milliseconds>(
  400. job.queuedTime.time_since_epoch()).count();
  401. jobJson["queued_time"] = queuedMs;
  402. if (job.status != GenerationStatus::QUEUED) {
  403. auto startMs = std::chrono::duration_cast<std::chrono::milliseconds>(
  404. job.startTime.time_since_epoch()).count();
  405. jobJson["start_time"] = startMs;
  406. }
  407. if (job.status == GenerationStatus::COMPLETED || job.status == GenerationStatus::FAILED) {
  408. auto endMs = std::chrono::duration_cast<std::chrono::milliseconds>(
  409. job.endTime.time_since_epoch()).count();
  410. jobJson["end_time"] = endMs;
  411. }
  412. jobJson["output_files"] = job.outputFiles;
  413. jobJson["error_message"] = job.errorMessage;
  414. // Write to file
  415. std::string filename = queueDir + "/" + job.id + ".json";
  416. std::ofstream file(filename);
  417. if (file.is_open()) {
  418. file << jobJson.dump(2);
  419. file.close();
  420. }
  421. } catch (const std::exception& e) {
  422. std::cerr << "Error saving job to file: " << e.what() << std::endl;
  423. }
  424. }
  425. void loadJobsFromDisk() {
  426. try {
  427. if (!std::filesystem::exists(queueDir)) {
  428. return;
  429. }
  430. std::cout << "Loading persisted jobs from: " << queueDir << std::endl;
  431. int loadedCount = 0;
  432. for (const auto& entry : std::filesystem::directory_iterator(queueDir)) {
  433. if (entry.path().extension() != ".json") {
  434. continue;
  435. }
  436. try {
  437. std::ifstream file(entry.path());
  438. if (!file.is_open()) {
  439. continue;
  440. }
  441. nlohmann::json jobJson = nlohmann::json::parse(file);
  442. file.close();
  443. // Reconstruct JobInfo
  444. JobInfo job;
  445. job.id = jobJson["id"];
  446. job.type = stringToJobType(jobJson["type"]);
  447. job.status = stringToJobStatus(jobJson["status"]);
  448. job.prompt = jobJson["prompt"];
  449. job.position = jobJson["position"];
  450. // Reconstruct time points
  451. auto queuedMs = jobJson["queued_time"].get<int64_t>();
  452. job.queuedTime = std::chrono::steady_clock::time_point(
  453. std::chrono::milliseconds(queuedMs));
  454. if (jobJson.contains("start_time")) {
  455. auto startMs = jobJson["start_time"].get<int64_t>();
  456. job.startTime = std::chrono::steady_clock::time_point(
  457. std::chrono::milliseconds(startMs));
  458. }
  459. if (jobJson.contains("end_time")) {
  460. auto endMs = jobJson["end_time"].get<int64_t>();
  461. job.endTime = std::chrono::steady_clock::time_point(
  462. std::chrono::milliseconds(endMs));
  463. }
  464. if (jobJson.contains("output_files")) {
  465. job.outputFiles = jobJson["output_files"].get<std::vector<std::string>>();
  466. }
  467. if (jobJson.contains("error_message")) {
  468. job.errorMessage = jobJson["error_message"];
  469. }
  470. // Add to active jobs
  471. std::lock_guard<std::mutex> lock(jobsMutex);
  472. activeJobs[job.id] = job;
  473. loadedCount++;
  474. } catch (const std::exception& e) {
  475. std::cerr << "Error loading job from " << entry.path() << ": " << e.what() << std::endl;
  476. }
  477. }
  478. if (loadedCount > 0) {
  479. std::cout << "Loaded " << loadedCount << " persisted job(s)" << std::endl;
  480. }
  481. } catch (const std::exception& e) {
  482. std::cerr << "Error loading jobs from disk: " << e.what() << std::endl;
  483. }
  484. }
  485. HashResult performHashJob(const HashRequest& request) {
  486. HashResult result;
  487. result.requestId = request.id;
  488. result.success = false;
  489. result.modelsHashed = 0;
  490. auto startTime = std::chrono::steady_clock::now();
  491. if (!modelManager) {
  492. result.errorMessage = "Model manager not available";
  493. result.status = GenerationStatus::FAILED;
  494. return result;
  495. }
  496. // Get list of models to hash
  497. std::vector<std::string> modelsToHash;
  498. if (request.modelNames.empty()) {
  499. // Hash all models without hashes
  500. auto allModels = modelManager->getAllModels();
  501. for (const auto& [name, info] : allModels) {
  502. if (info.sha256.empty() || request.forceRehash) {
  503. modelsToHash.push_back(name);
  504. }
  505. }
  506. } else {
  507. modelsToHash = request.modelNames;
  508. }
  509. std::cout << "Hashing " << modelsToHash.size() << " model(s)..." << std::endl;
  510. // Hash each model
  511. for (const auto& modelName : modelsToHash) {
  512. std::string hash = modelManager->ensureModelHash(modelName, request.forceRehash);
  513. if (!hash.empty()) {
  514. result.modelHashes[modelName] = hash;
  515. result.modelsHashed++;
  516. } else {
  517. std::cerr << "Failed to hash model: " << modelName << std::endl;
  518. }
  519. }
  520. auto endTime = std::chrono::steady_clock::now();
  521. result.hashingTime = std::chrono::duration_cast<std::chrono::milliseconds>(
  522. endTime - startTime).count();
  523. result.success = result.modelsHashed > 0;
  524. result.status = result.success ? GenerationStatus::COMPLETED : GenerationStatus::FAILED;
  525. if (!result.success) {
  526. result.errorMessage = "Failed to hash any models";
  527. }
  528. return result;
  529. }
  530. ConversionResult performConversionJob(const ConversionRequest& request) {
  531. ConversionResult result;
  532. result.requestId = request.id;
  533. result.success = false;
  534. auto startTime = std::chrono::steady_clock::now();
  535. std::cout << "Starting model conversion: " << request.modelName << std::endl;
  536. std::cout << " Input: " << request.modelPath << std::endl;
  537. std::cout << " Output: " << request.outputPath << std::endl;
  538. std::cout << " Quantization: " << request.quantizationType << std::endl;
  539. // Check if input file exists
  540. namespace fs = std::filesystem;
  541. if (!fs::exists(request.modelPath)) {
  542. result.errorMessage = "Input model file not found: " + request.modelPath;
  543. result.status = GenerationStatus::FAILED;
  544. return result;
  545. }
  546. // Get original file size
  547. try {
  548. auto originalSize = fs::file_size(request.modelPath);
  549. result.originalSize = formatFileSize(originalSize);
  550. } catch (const std::exception& e) {
  551. result.originalSize = "Unknown";
  552. }
  553. // Build conversion command
  554. // Get the sd binary path from the CMake installation directory
  555. std::string sdBinaryPath = "../build/stable-diffusion.cpp-install/bin/sd";
  556. std::stringstream cmd;
  557. cmd << sdBinaryPath << " --mode convert";
  558. cmd << " -m \"" << request.modelPath << "\"";
  559. cmd << " -o \"" << request.outputPath << "\"";
  560. cmd << " --type " << request.quantizationType;
  561. cmd << " 2>&1"; // Capture stderr
  562. std::cout << "Executing: " << cmd.str() << std::endl;
  563. // Execute conversion
  564. FILE* pipe = popen(cmd.str().c_str(), "r");
  565. if (!pipe) {
  566. result.errorMessage = "Failed to execute conversion command";
  567. result.status = GenerationStatus::FAILED;
  568. return result;
  569. }
  570. // Read command output
  571. char buffer[256];
  572. std::string output;
  573. while (fgets(buffer, sizeof(buffer), pipe) != nullptr) {
  574. output += buffer;
  575. std::cout << buffer; // Print progress
  576. }
  577. int exitCode = pclose(pipe);
  578. auto endTime = std::chrono::steady_clock::now();
  579. result.conversionTime = std::chrono::duration_cast<std::chrono::milliseconds>(
  580. endTime - startTime).count();
  581. if (exitCode != 0) {
  582. result.errorMessage = "Conversion failed with exit code " + std::to_string(exitCode);
  583. if (!output.empty()) {
  584. result.errorMessage += "\nOutput: " + output;
  585. }
  586. result.status = GenerationStatus::FAILED;
  587. return result;
  588. }
  589. // Check if output file was created
  590. if (!fs::exists(request.outputPath)) {
  591. result.errorMessage = "Output file was not created: " + request.outputPath;
  592. result.status = GenerationStatus::FAILED;
  593. return result;
  594. }
  595. // Get converted file size
  596. try {
  597. auto convertedSize = fs::file_size(request.outputPath);
  598. result.convertedSize = formatFileSize(convertedSize);
  599. } catch (const std::exception& e) {
  600. result.convertedSize = "Unknown";
  601. }
  602. result.success = true;
  603. result.status = GenerationStatus::COMPLETED;
  604. result.outputPath = request.outputPath;
  605. std::cout << "Conversion completed successfully!" << std::endl;
  606. std::cout << " Original size: " << result.originalSize << std::endl;
  607. std::cout << " Converted size: " << result.convertedSize << std::endl;
  608. std::cout << " Time: " << result.conversionTime << "ms" << std::endl;
  609. // Trigger model rescan after successful conversion
  610. if (modelManager) {
  611. std::cout << "Triggering model rescan..." << std::endl;
  612. modelManager->scanModelsDirectory();
  613. }
  614. return result;
  615. }
  616. std::string formatFileSize(size_t bytes) {
  617. const char* units[] = {"B", "KB", "MB", "GB", "TB"};
  618. int unitIndex = 0;
  619. double size = static_cast<double>(bytes);
  620. while (size >= 1024.0 && unitIndex < 4) {
  621. size /= 1024.0;
  622. unitIndex++;
  623. }
  624. std::stringstream ss;
  625. ss << std::fixed << std::setprecision(2) << size << " " << units[unitIndex];
  626. return ss.str();
  627. }
  628. };
  629. GenerationQueue::GenerationQueue(ModelManager* modelManager, int maxConcurrentGenerations,
  630. const std::string& queueDir, const std::string& outputDir)
  631. : pImpl(std::make_unique<Impl>()) {
  632. pImpl->modelManager = modelManager;
  633. pImpl->maxConcurrentGenerations = maxConcurrentGenerations;
  634. pImpl->queueDir = queueDir;
  635. pImpl->outputDir = outputDir;
  636. std::cout << "GenerationQueue initialized" << std::endl;
  637. std::cout << " Max concurrent generations: " << maxConcurrentGenerations << std::endl;
  638. std::cout << " Queue directory: " << queueDir << std::endl;
  639. std::cout << " Output directory: " << outputDir << std::endl;
  640. // Load any existing jobs from disk
  641. pImpl->loadJobsFromDisk();
  642. }
// Destructor: delegates to stop() to shut the worker down before pImpl
// (and the mutexes/queues it owns) is destroyed. NOTE(review): stop() is
// defined elsewhere in this file — presumed to join the worker thread.
GenerationQueue::~GenerationQueue() {
stop();
}
  646. std::future<GenerationResult> GenerationQueue::enqueueRequest(const GenerationRequest& request) {
  647. std::cout << "Enqueuing generation request: " << request.id << std::endl;
  648. std::cout << " Prompt: " << request.prompt.substr(0, 100)
  649. << (request.prompt.length() > 100 ? "..." : "") << std::endl;
  650. std::cout << " Model: " << request.modelName << std::endl;
  651. std::cout << " Size: " << request.width << "x" << request.height << std::endl;
  652. std::cout << " Steps: " << request.steps << ", CFG: " << request.cfgScale << std::endl;
  653. // Create promise and future
  654. auto promise = std::make_shared<std::promise<GenerationResult>>();
  655. auto future = promise->get_future();
  656. // Store the promise
  657. {
  658. std::lock_guard<std::mutex> lock(pImpl->jobsMutex);
  659. pImpl->jobPromises[request.id] = std::move(*promise);
  660. }
  661. // Add to queue
  662. {
  663. std::lock_guard<std::mutex> lock(pImpl->queueMutex);
  664. // Create job info
  665. JobInfo jobInfo;
  666. jobInfo.id = request.id;
  667. jobInfo.type = JobType::GENERATION;
  668. jobInfo.status = GenerationStatus::QUEUED;
  669. jobInfo.prompt = request.prompt; // Store full prompt
  670. jobInfo.queuedTime = std::chrono::steady_clock::now();
  671. jobInfo.position = pImpl->requestQueue.size() + 1;
  672. // Store job info
  673. {
  674. std::lock_guard<std::mutex> jobsLock(pImpl->jobsMutex);
  675. pImpl->activeJobs[request.id] = jobInfo;
  676. }
  677. // Persist to disk
  678. pImpl->saveJobToFile(jobInfo);
  679. pImpl->requestQueue.push(request);
  680. pImpl->queueSize.store(pImpl->requestQueue.size());
  681. }
  682. // Notify worker thread
  683. pImpl->queueCondition.notify_one();
  684. return future;
  685. }
  686. std::future<HashResult> GenerationQueue::enqueueHashRequest(const HashRequest& request) {
  687. auto promise = std::make_shared<std::promise<HashResult>>();
  688. auto future = promise->get_future();
  689. std::unique_lock<std::mutex> lock(pImpl->queueMutex);
  690. // Create a generation request that acts as a placeholder for hash job
  691. GenerationRequest hashJobPlaceholder;
  692. hashJobPlaceholder.id = request.id;
  693. hashJobPlaceholder.prompt = "HASH_JOB"; // Special marker
  694. hashJobPlaceholder.modelName = request.modelNames.empty() ? "ALL_MODELS" : request.modelNames[0];
  695. // Store promise for retrieval later
  696. pImpl->hashPromises[request.id] = promise;
  697. pImpl->hashRequests[request.id] = request;
  698. pImpl->requestQueue.push(hashJobPlaceholder);
  699. pImpl->queueCondition.notify_one();
  700. std::cout << "Enqueued hash request: " << request.id << std::endl;
  701. return future;
  702. }
  703. std::future<ConversionResult> GenerationQueue::enqueueConversionRequest(const ConversionRequest& request) {
  704. auto promise = std::make_shared<std::promise<ConversionResult>>();
  705. auto future = promise->get_future();
  706. std::unique_lock<std::mutex> lock(pImpl->queueMutex);
  707. // Create a generation request that acts as a placeholder for conversion job
  708. GenerationRequest conversionJobPlaceholder;
  709. conversionJobPlaceholder.id = request.id;
  710. conversionJobPlaceholder.prompt = "CONVERSION_JOB"; // Special marker
  711. conversionJobPlaceholder.modelName = request.modelName;
  712. // Store promise for retrieval later
  713. pImpl->conversionPromises[request.id] = promise;
  714. pImpl->conversionRequests[request.id] = request;
  715. pImpl->requestQueue.push(conversionJobPlaceholder);
  716. pImpl->queueCondition.notify_one();
  717. std::cout << "Enqueued conversion request: " << request.id << " (model: " << request.modelName << ", type: " << request.quantizationType << ")" << std::endl;
  718. return future;
  719. }
  720. size_t GenerationQueue::getQueueSize() const {
  721. return pImpl->queueSize.load();
  722. }
  723. size_t GenerationQueue::getActiveGenerations() const {
  724. return pImpl->activeGenerations.load();
  725. }
// Snapshot every tracked job (queued, processing, finished) as a sorted list.
std::vector<JobInfo> GenerationQueue::getQueueStatus() const {
    std::vector<JobInfo> jobs;
    std::lock_guard<std::mutex> lock(pImpl->jobsMutex);
    jobs.reserve(pImpl->activeJobs.size());
    for (const auto& pair : pImpl->activeJobs) {
        jobs.push_back(pair.second);
    }
    // Sort by status first (ascending enum value), then by queued time
    // within the same status — so e.g. all QUEUED jobs group together in
    // arrival order.
    std::sort(jobs.begin(), jobs.end(), [](const JobInfo& a, const JobInfo& b) {
        if (a.status != b.status) {
            return static_cast<int>(a.status) < static_cast<int>(b.status);
        }
        return a.queuedTime < b.queuedTime;
    });
    return jobs;
}
  742. JobInfo GenerationQueue::getJobInfo(const std::string& jobId) const {
  743. std::lock_guard<std::mutex> lock(pImpl->jobsMutex);
  744. auto it = pImpl->activeJobs.find(jobId);
  745. if (it != pImpl->activeJobs.end()) {
  746. return it->second;
  747. }
  748. return JobInfo{}; // Return empty job info if not found
  749. }
  750. bool GenerationQueue::cancelJob(const std::string& jobId) {
  751. std::lock_guard<std::mutex> queueLock(pImpl->queueMutex);
  752. std::lock_guard<std::mutex> jobsLock(pImpl->jobsMutex);
  753. // Check if job is still queued
  754. std::queue<GenerationRequest> newQueue;
  755. bool found = false;
  756. while (!pImpl->requestQueue.empty()) {
  757. GenerationRequest request = pImpl->requestQueue.front();
  758. pImpl->requestQueue.pop();
  759. if (request.id == jobId) {
  760. found = true;
  761. // Update job status
  762. auto it = pImpl->activeJobs.find(jobId);
  763. if (it != pImpl->activeJobs.end()) {
  764. it->second.status = GenerationStatus::FAILED;
  765. it->second.endTime = std::chrono::steady_clock::now();
  766. }
  767. // Set promise with cancellation error
  768. auto promiseIt = pImpl->jobPromises.find(jobId);
  769. if (promiseIt != pImpl->jobPromises.end()) {
  770. GenerationResult result;
  771. result.requestId = jobId;
  772. result.success = false;
  773. result.errorMessage = "Job cancelled by user";
  774. result.generationTime = 0;
  775. promiseIt->second.set_value(result);
  776. pImpl->jobPromises.erase(promiseIt);
  777. }
  778. } else {
  779. newQueue.push(request);
  780. }
  781. }
  782. pImpl->requestQueue = newQueue;
  783. pImpl->queueSize.store(pImpl->requestQueue.size());
  784. return found;
  785. }
  786. void GenerationQueue::clearQueue() {
  787. std::cout << "Clearing generation queue" << std::endl;
  788. std::lock_guard<std::mutex> queueLock(pImpl->queueMutex);
  789. std::lock_guard<std::mutex> jobsLock(pImpl->jobsMutex);
  790. // Cancel all queued jobs
  791. while (!pImpl->requestQueue.empty()) {
  792. GenerationRequest request = pImpl->requestQueue.front();
  793. pImpl->requestQueue.pop();
  794. // Update job status
  795. auto it = pImpl->activeJobs.find(request.id);
  796. if (it != pImpl->activeJobs.end()) {
  797. it->second.status = GenerationStatus::FAILED;
  798. it->second.endTime = std::chrono::steady_clock::now();
  799. }
  800. // Set promise with cancellation error
  801. auto promiseIt = pImpl->jobPromises.find(request.id);
  802. if (promiseIt != pImpl->jobPromises.end()) {
  803. GenerationResult result;
  804. result.requestId = request.id;
  805. result.success = false;
  806. result.errorMessage = "Queue cleared";
  807. result.generationTime = 0;
  808. promiseIt->second.set_value(result);
  809. pImpl->jobPromises.erase(promiseIt);
  810. }
  811. }
  812. pImpl->queueSize.store(0);
  813. }
  814. void GenerationQueue::start() {
  815. if (pImpl->running.load()) {
  816. std::cout << "GenerationQueue is already running" << std::endl;
  817. return;
  818. }
  819. pImpl->running.store(true);
  820. pImpl->stopRequested.store(false);
  821. pImpl->workerThread = std::thread(&Impl::workerThreadFunction, pImpl.get());
  822. std::cout << "GenerationQueue started" << std::endl;
  823. }
  824. void GenerationQueue::stop() {
  825. if (!pImpl->running.load()) {
  826. return;
  827. }
  828. std::cout << "Stopping GenerationQueue..." << std::endl;
  829. pImpl->stopRequested.store(true);
  830. pImpl->queueCondition.notify_all();
  831. if (pImpl->workerThread.joinable()) {
  832. pImpl->workerThread.join();
  833. }
  834. pImpl->running.store(false);
  835. // Clear any remaining promises
  836. std::lock_guard<std::mutex> lock(pImpl->jobsMutex);
  837. for (auto& pair : pImpl->jobPromises) {
  838. GenerationResult result;
  839. result.requestId = pair.first;
  840. result.success = false;
  841. result.errorMessage = "Queue stopped";
  842. result.generationTime = 0;
  843. pair.second.set_value(result);
  844. }
  845. pImpl->jobPromises.clear();
  846. std::cout << "GenerationQueue stopped" << std::endl;
  847. }
  848. bool GenerationQueue::isRunning() const {
  849. return pImpl->running.load();
  850. }
  851. void GenerationQueue::setMaxConcurrentGenerations(int maxConcurrent) {
  852. pImpl->maxConcurrentGenerations = maxConcurrent;
  853. std::cout << "GenerationQueue max concurrent generations set to: " << maxConcurrent << std::endl;
  854. }