| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
7037047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036 |
- #include "generation_queue.h"
- #include "model_manager.h"
- #include "stable_diffusion_wrapper.h"
- #include "utils.h"
- #include <iostream>
- #include <random>
- #include <sstream>
- #include <iomanip>
- #include <algorithm>
- #include <fstream>
- #include <filesystem>
- #include <nlohmann/json.hpp>
- #define STB_IMAGE_WRITE_IMPLEMENTATION
- #include "../stable-diffusion.cpp-src/thirdparty/stb_image_write.h"
- #define STB_IMAGE_IMPLEMENTATION
- #include "../stable-diffusion.cpp-src/thirdparty/stb_image.h"
- class GenerationQueue::Impl {
- public:
- // Model manager reference
- ModelManager* modelManager = nullptr;
- // Thread management
- std::thread workerThread;
- std::atomic<bool> running{false};
- std::atomic<bool> stopRequested{false};
- // Queue management
- mutable std::mutex queueMutex;
- std::condition_variable queueCondition;
- std::queue<GenerationRequest> requestQueue;
- // Job tracking
- mutable std::mutex jobsMutex;
- std::unordered_map<std::string, JobInfo> activeJobs;
- std::unordered_map<std::string, std::promise<GenerationResult>> jobPromises;
- // Hash job tracking
- std::map<std::string, std::shared_ptr<std::promise<HashResult>>> hashPromises;
- std::map<std::string, HashRequest> hashRequests;
- // Conversion job tracking
- std::map<std::string, std::shared_ptr<std::promise<ConversionResult>>> conversionPromises;
- std::map<std::string, ConversionRequest> conversionRequests;
- // Configuration
- int maxConcurrentGenerations = 1;
- std::string queueDir = "./queue";
- std::string outputDir = "./output";
- // Statistics
- std::atomic<size_t> queueSize{0};
- std::atomic<size_t> activeGenerations{0};
- std::atomic<uint64_t> totalJobsProcessed{0};
- // Worker thread function
- void workerThreadFunction() {
- std::cout << "GenerationQueue worker thread started" << std::endl;
- while (running.load() && !stopRequested.load()) {
- std::unique_lock<std::mutex> lock(queueMutex);
- // Wait for a request or stop signal
- queueCondition.wait(lock, [this] {
- return !requestQueue.empty() || stopRequested.load();
- });
- if (stopRequested.load()) {
- break;
- }
- if (requestQueue.empty()) {
- continue;
- }
- // Get the next request
- GenerationRequest request = requestQueue.front();
- requestQueue.pop();
- queueSize.store(requestQueue.size());
- lock.unlock();
- // Process the request
- processRequest(request);
- }
- std::cout << "GenerationQueue worker thread stopped" << std::endl;
- }
- void processRequest(const GenerationRequest& request) {
- // Check if this is a hash job
- if (request.prompt == "HASH_JOB") {
- auto hashIt = hashRequests.find(request.id);
- if (hashIt != hashRequests.end()) {
- HashResult result = performHashJob(hashIt->second);
- auto promiseIt = hashPromises.find(request.id);
- if (promiseIt != hashPromises.end()) {
- promiseIt->second->set_value(result);
- hashPromises.erase(promiseIt);
- }
- hashRequests.erase(hashIt);
- }
- return;
- }
- // Check if this is a conversion job
- if (request.prompt == "CONVERSION_JOB") {
- auto convIt = conversionRequests.find(request.id);
- if (convIt != conversionRequests.end()) {
- ConversionResult result = performConversionJob(convIt->second);
- auto promiseIt = conversionPromises.find(request.id);
- if (promiseIt != conversionPromises.end()) {
- promiseIt->second->set_value(result);
- conversionPromises.erase(promiseIt);
- }
- conversionRequests.erase(convIt);
- }
- return;
- }
- auto startTime = std::chrono::steady_clock::now();
- // Update job status to PROCESSING
- {
- std::lock_guard<std::mutex> lock(jobsMutex);
- if (activeJobs.find(request.id) != activeJobs.end()) {
- activeJobs[request.id].status = GenerationStatus::PROCESSING;
- activeJobs[request.id].startTime = startTime;
- saveJobToFile(activeJobs[request.id]);
- }
- }
- activeGenerations.store(1); // Only one generation at a time
- std::cout << "Processing generation request: " << request.id
- << " (prompt: " << request.prompt.substr(0, 50)
- << (request.prompt.length() > 50 ? "..." : "") << ")" << std::endl;
- // Real generation logic using stable-diffusion.cpp
- GenerationResult result = performActualGeneration(request);
- auto endTime = std::chrono::steady_clock::now();
- auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(endTime - startTime);
- result.generationTime = duration.count();
- // Update job status to COMPLETED/FAILED
- {
- std::lock_guard<std::mutex> lock(jobsMutex);
- if (activeJobs.find(request.id) != activeJobs.end()) {
- activeJobs[request.id].status = result.success ? GenerationStatus::COMPLETED : GenerationStatus::FAILED;
- activeJobs[request.id].endTime = endTime;
- // Store output files and error message
- activeJobs[request.id].outputFiles = result.imagePaths;
- activeJobs[request.id].errorMessage = result.errorMessage;
- // Persist to disk
- saveJobToFile(activeJobs[request.id]);
- }
- // Set the promise value
- auto it = jobPromises.find(request.id);
- if (it != jobPromises.end()) {
- it->second.set_value(result);
- jobPromises.erase(it);
- }
- }
- activeGenerations.store(0);
- totalJobsProcessed.fetch_add(1);
- std::cout << "Completed generation request: " << request.id
- << " (success: " << (result.success ? "true" : "false")
- << ", time: " << result.generationTime << "ms)";
- if (!result.success && !result.errorMessage.empty()) {
- std::cout << " - Error: " << result.errorMessage;
- }
- std::cout << std::endl;
- }
- GenerationResult performActualGeneration(const GenerationRequest& request) {
- GenerationResult result;
- result.requestId = request.id;
- result.success = false;
- // Check if model manager is available
- if (!modelManager) {
- result.errorMessage = "Model manager not available";
- return result;
- }
- // Check if the model is loaded (DO NOT auto-load)
- if (!modelManager->isModelLoaded(request.modelName)) {
- result.errorMessage = "Model not loaded: " + request.modelName + ". Please load the model first using POST /api/models/{hash}/load";
- return result;
- }
- // Get the model wrapper from the shared model manager
- auto* modelWrapper = modelManager->getModel(request.modelName);
- if (!modelWrapper) {
- result.errorMessage = "Model not found or not loaded: " + request.modelName;
- return result;
- }
- // Prepare generation parameters
- StableDiffusionWrapper::GenerationParams params;
- params.prompt = request.prompt;
- params.negativePrompt = request.negativePrompt;
- params.width = request.width;
- params.height = request.height;
- params.batchCount = request.batchCount;
- params.steps = request.steps;
- params.cfgScale = request.cfgScale;
- params.samplingMethod = samplingMethodToString(request.samplingMethod);
- params.scheduler = schedulerToString(request.scheduler);
- params.clipSkip = request.clipSkip;
- params.strength = request.strength;
- params.controlStrength = request.controlStrength;
- params.nThreads = request.nThreads;
- params.offloadParamsToCpu = request.offloadParamsToCpu;
- params.clipOnCpu = request.clipOnCpu;
- params.vaeOnCpu = request.vaeOnCpu;
- params.diffusionFlashAttn = request.diffusionFlashAttn;
- params.diffusionConvDirect = request.diffusionConvDirect;
- params.vaeConvDirect = request.vaeConvDirect;
- // Set model paths if provided
- params.modelPath = modelManager->getModelInfo(request.modelName).path;
- params.clipLPath = request.clipLPath;
- params.clipGPath = request.clipGPath;
- params.vaePath = request.vaePath;
- params.taesdPath = request.taesdPath;
- params.controlNetPath = request.controlNetPath;
- params.embeddingDir = request.embeddingDir;
- params.loraModelDir = request.loraModelDir;
- // Parse seed
- if (request.seed == "random") {
- std::random_device rd;
- std::mt19937 gen(rd());
- std::uniform_int_distribution<int64_t> dis;
- params.seed = dis(gen);
- } else {
- try {
- params.seed = std::stoll(request.seed);
- } catch (...) {
- params.seed = 42; // Default seed
- }
- }
- result.actualSeed = params.seed;
- // Generate images based on request type
- try {
- std::vector<StableDiffusionWrapper::GeneratedImage> generatedImages;
- switch (request.requestType) {
- case GenerationRequest::RequestType::TEXT2IMG:
- generatedImages = modelWrapper->generateImage(params);
- break;
- case GenerationRequest::RequestType::IMG2IMG:
- if (request.initImageData.empty()) {
- result.errorMessage = "No init image data provided for img2img";
- return result;
- }
- generatedImages = modelWrapper->generateImageImg2Img(
- params,
- request.initImageData,
- request.initImageWidth,
- request.initImageHeight
- );
- break;
- case GenerationRequest::RequestType::CONTROLNET:
- if (request.controlImageData.empty()) {
- result.errorMessage = "No control image data provided for ControlNet";
- return result;
- }
- generatedImages = modelWrapper->generateImageControlNet(
- params,
- request.controlImageData,
- request.controlImageWidth,
- request.controlImageHeight
- );
- break;
- case GenerationRequest::RequestType::UPSCALER:
- if (request.initImageData.empty()) {
- result.errorMessage = "No input image data provided for upscaling";
- return result;
- }
- if (request.esrganPath.empty()) {
- result.errorMessage = "No ESRGAN model path provided for upscaling";
- return result;
- }
- {
- auto upscaledImage = modelWrapper->upscaleImage(
- request.esrganPath,
- request.initImageData,
- request.initImageWidth,
- request.initImageHeight,
- request.initImageChannels,
- request.upscaleFactor,
- request.nThreads,
- request.offloadParamsToCpu,
- request.diffusionConvDirect
- );
- generatedImages.push_back(upscaledImage);
- }
- break;
- default:
- result.errorMessage = "Unknown request type";
- return result;
- }
- if (generatedImages.empty()) {
- result.errorMessage = "Failed to generate images: " + modelWrapper->getLastError();
- return result;
- }
- // Save generated images to files
- for (size_t i = 0; i < generatedImages.size(); i++) {
- const auto& image = generatedImages[i];
- std::string imagePath = saveImageToFile(image, request.id, i);
- if (!imagePath.empty()) {
- result.imagePaths.push_back(imagePath);
- } else {
- result.errorMessage = "Failed to save generated image " + std::to_string(i);
- return result;
- }
- }
- result.success = true;
- result.generationTime = generatedImages.empty() ? 0 : generatedImages[0].generationTime;
- result.errorMessage = "";
- } catch (const std::exception& e) {
- result.errorMessage = "Exception during generation: " + std::string(e.what());
- }
- return result;
- }
- std::string saveImageToFile(const StableDiffusionWrapper::GeneratedImage& image, const std::string& requestId, size_t index) {
- // Create job-specific output directory
- std::string jobOutputDir = outputDir + "/" + requestId;
- std::filesystem::create_directories(jobOutputDir);
- // Generate filename
- std::stringstream ss;
- ss << jobOutputDir << "/" << requestId << "_" << index << ".png";
- std::string filename = ss.str();
- // Check if image data is valid
- if (image.data.empty() || image.width <= 0 || image.height <= 0) {
- std::cerr << "Invalid image data: width=" << image.width
- << ", height=" << image.height
- << ", data_size=" << image.data.size() << std::endl;
- return "";
- }
- // Write PNG file using stb_image_write
- int result = stbi_write_png(
- filename.c_str(),
- image.width,
- image.height,
- image.channels,
- image.data.data(),
- image.width * image.channels // stride in bytes
- );
- if (result == 0) {
- std::cerr << "Failed to write PNG file: " << filename << std::endl;
- return "";
- }
- std::cout << "Saved generated image to: " << filename
- << " (" << image.width << "x" << image.height
- << ", " << image.channels << " channels, "
- << image.data.size() << " bytes)" << std::endl;
- return filename;
- }
- std::string samplingMethodToString(SamplingMethod method) {
- switch (method) {
- case SamplingMethod::EULER: return "euler";
- case SamplingMethod::EULER_A: return "euler_a";
- case SamplingMethod::HEUN: return "heun";
- case SamplingMethod::DPM2: return "dpm2";
- case SamplingMethod::DPMPP2S_A: return "dpmpp2s_a";
- case SamplingMethod::DPMPP2M: return "dpmpp2m";
- case SamplingMethod::DPMPP2MV2: return "dpmpp2mv2";
- case SamplingMethod::IPNDM: return "ipndm";
- case SamplingMethod::IPNDM_V: return "ipndm_v";
- case SamplingMethod::LCM: return "lcm";
- case SamplingMethod::DDIM_TRAILING: return "ddim_trailing";
- case SamplingMethod::TCD: return "tcd";
- default: return "euler";
- }
- }
- std::string schedulerToString(Scheduler scheduler) {
- switch (scheduler) {
- case Scheduler::DISCRETE: return "discrete";
- case Scheduler::KARRAS: return "karras";
- case Scheduler::EXPONENTIAL: return "exponential";
- case Scheduler::AYS: return "ays";
- case Scheduler::GITS: return "gits";
- case Scheduler::SMOOTHSTEP: return "smoothstep";
- case Scheduler::SGM_UNIFORM: return "sgm_uniform";
- case Scheduler::SIMPLE: return "simple";
- default: return "default";
- }
- }
- std::string jobStatusToString(GenerationStatus status) {
- switch (status) {
- case GenerationStatus::QUEUED: return "queued";
- case GenerationStatus::PROCESSING: return "processing";
- case GenerationStatus::COMPLETED: return "completed";
- case GenerationStatus::FAILED: return "failed";
- default: return "unknown";
- }
- }
- GenerationStatus stringToJobStatus(const std::string& status) {
- if (status == "queued") return GenerationStatus::QUEUED;
- if (status == "processing") return GenerationStatus::PROCESSING;
- if (status == "completed") return GenerationStatus::COMPLETED;
- if (status == "failed") return GenerationStatus::FAILED;
- return GenerationStatus::QUEUED;
- }
- std::string jobTypeToString(JobType type) {
- switch (type) {
- case JobType::GENERATION: return "generation";
- case JobType::HASHING: return "hashing";
- default: return "unknown";
- }
- }
- JobType stringToJobType(const std::string& type) {
- if (type == "generation") return JobType::GENERATION;
- if (type == "hashing") return JobType::HASHING;
- return JobType::GENERATION;
- }
- void saveJobToFile(const JobInfo& job) {
- try {
- // Create queue directory if it doesn't exist
- std::filesystem::create_directories(queueDir);
- // Create JSON object
- nlohmann::json jobJson;
- jobJson["id"] = job.id;
- jobJson["type"] = jobTypeToString(job.type);
- jobJson["status"] = jobStatusToString(job.status);
- jobJson["prompt"] = job.prompt;
- jobJson["position"] = job.position;
- // Convert time points to milliseconds since epoch
- auto queuedMs = std::chrono::duration_cast<std::chrono::milliseconds>(
- job.queuedTime.time_since_epoch()).count();
- jobJson["queued_time"] = queuedMs;
- if (job.status != GenerationStatus::QUEUED) {
- auto startMs = std::chrono::duration_cast<std::chrono::milliseconds>(
- job.startTime.time_since_epoch()).count();
- jobJson["start_time"] = startMs;
- }
- if (job.status == GenerationStatus::COMPLETED || job.status == GenerationStatus::FAILED) {
- auto endMs = std::chrono::duration_cast<std::chrono::milliseconds>(
- job.endTime.time_since_epoch()).count();
- jobJson["end_time"] = endMs;
- }
- jobJson["output_files"] = job.outputFiles;
- jobJson["error_message"] = job.errorMessage;
- // Write to file
- std::string filename = queueDir + "/" + job.id + ".json";
- std::ofstream file(filename);
- if (file.is_open()) {
- file << jobJson.dump(2);
- file.close();
- }
- } catch (const std::exception& e) {
- std::cerr << "Error saving job to file: " << e.what() << std::endl;
- }
- }
- void loadJobsFromDisk() {
- try {
- if (!std::filesystem::exists(queueDir)) {
- return;
- }
- std::cout << "Loading persisted jobs from: " << queueDir << std::endl;
- int loadedCount = 0;
- for (const auto& entry : std::filesystem::directory_iterator(queueDir)) {
- if (entry.path().extension() != ".json") {
- continue;
- }
- try {
- std::ifstream file(entry.path());
- if (!file.is_open()) {
- continue;
- }
- nlohmann::json jobJson = nlohmann::json::parse(file);
- file.close();
- // Reconstruct JobInfo
- JobInfo job;
- job.id = jobJson["id"];
- job.type = stringToJobType(jobJson["type"]);
- job.status = stringToJobStatus(jobJson["status"]);
- job.prompt = jobJson["prompt"];
- job.position = jobJson["position"];
- // Reconstruct time points
- auto queuedMs = jobJson["queued_time"].get<int64_t>();
- job.queuedTime = std::chrono::steady_clock::time_point(
- std::chrono::milliseconds(queuedMs));
- if (jobJson.contains("start_time")) {
- auto startMs = jobJson["start_time"].get<int64_t>();
- job.startTime = std::chrono::steady_clock::time_point(
- std::chrono::milliseconds(startMs));
- }
- if (jobJson.contains("end_time")) {
- auto endMs = jobJson["end_time"].get<int64_t>();
- job.endTime = std::chrono::steady_clock::time_point(
- std::chrono::milliseconds(endMs));
- }
- if (jobJson.contains("output_files")) {
- job.outputFiles = jobJson["output_files"].get<std::vector<std::string>>();
- }
- if (jobJson.contains("error_message")) {
- job.errorMessage = jobJson["error_message"];
- }
- // Add to active jobs
- std::lock_guard<std::mutex> lock(jobsMutex);
- activeJobs[job.id] = job;
- loadedCount++;
- } catch (const std::exception& e) {
- std::cerr << "Error loading job from " << entry.path() << ": " << e.what() << std::endl;
- }
- }
- if (loadedCount > 0) {
- std::cout << "Loaded " << loadedCount << " persisted job(s)" << std::endl;
- }
- } catch (const std::exception& e) {
- std::cerr << "Error loading jobs from disk: " << e.what() << std::endl;
- }
- }
- HashResult performHashJob(const HashRequest& request) {
- HashResult result;
- result.requestId = request.id;
- result.success = false;
- result.modelsHashed = 0;
- auto startTime = std::chrono::steady_clock::now();
- if (!modelManager) {
- result.errorMessage = "Model manager not available";
- result.status = GenerationStatus::FAILED;
- return result;
- }
- // Get list of models to hash
- std::vector<std::string> modelsToHash;
- if (request.modelNames.empty()) {
- // Hash all models without hashes
- auto allModels = modelManager->getAllModels();
- for (const auto& [name, info] : allModels) {
- if (info.sha256.empty() || request.forceRehash) {
- modelsToHash.push_back(name);
- }
- }
- } else {
- modelsToHash = request.modelNames;
- }
- std::cout << "Hashing " << modelsToHash.size() << " model(s)..." << std::endl;
- // Hash each model
- for (const auto& modelName : modelsToHash) {
- std::string hash = modelManager->ensureModelHash(modelName, request.forceRehash);
- if (!hash.empty()) {
- result.modelHashes[modelName] = hash;
- result.modelsHashed++;
- } else {
- std::cerr << "Failed to hash model: " << modelName << std::endl;
- }
- }
- auto endTime = std::chrono::steady_clock::now();
- result.hashingTime = std::chrono::duration_cast<std::chrono::milliseconds>(
- endTime - startTime).count();
- result.success = result.modelsHashed > 0;
- result.status = result.success ? GenerationStatus::COMPLETED : GenerationStatus::FAILED;
- if (!result.success) {
- result.errorMessage = "Failed to hash any models";
- }
- return result;
- }
- ConversionResult performConversionJob(const ConversionRequest& request) {
- ConversionResult result;
- result.requestId = request.id;
- result.success = false;
- auto startTime = std::chrono::steady_clock::now();
- std::cout << "Starting model conversion: " << request.modelName << std::endl;
- std::cout << " Input: " << request.modelPath << std::endl;
- std::cout << " Output: " << request.outputPath << std::endl;
- std::cout << " Quantization: " << request.quantizationType << std::endl;
- // Check if input file exists
- namespace fs = std::filesystem;
- if (!fs::exists(request.modelPath)) {
- result.errorMessage = "Input model file not found: " + request.modelPath;
- result.status = GenerationStatus::FAILED;
- return result;
- }
- // Get original file size
- try {
- auto originalSize = fs::file_size(request.modelPath);
- result.originalSize = formatFileSize(originalSize);
- } catch (const std::exception& e) {
- result.originalSize = "Unknown";
- }
- // Build conversion command
- // Get the sd binary path from the CMake installation directory
- std::string sdBinaryPath = "../build/stable-diffusion.cpp-install/bin/sd";
- std::stringstream cmd;
- cmd << sdBinaryPath << " --mode convert";
- cmd << " -m \"" << request.modelPath << "\"";
- cmd << " -o \"" << request.outputPath << "\"";
- cmd << " --type " << request.quantizationType;
- cmd << " 2>&1"; // Capture stderr
- std::cout << "Executing: " << cmd.str() << std::endl;
- // Execute conversion
- FILE* pipe = popen(cmd.str().c_str(), "r");
- if (!pipe) {
- result.errorMessage = "Failed to execute conversion command";
- result.status = GenerationStatus::FAILED;
- return result;
- }
- // Read command output
- char buffer[256];
- std::string output;
- while (fgets(buffer, sizeof(buffer), pipe) != nullptr) {
- output += buffer;
- std::cout << buffer; // Print progress
- }
- int exitCode = pclose(pipe);
- auto endTime = std::chrono::steady_clock::now();
- result.conversionTime = std::chrono::duration_cast<std::chrono::milliseconds>(
- endTime - startTime).count();
- if (exitCode != 0) {
- result.errorMessage = "Conversion failed with exit code " + std::to_string(exitCode);
- if (!output.empty()) {
- result.errorMessage += "\nOutput: " + output;
- }
- result.status = GenerationStatus::FAILED;
- return result;
- }
- // Check if output file was created
- if (!fs::exists(request.outputPath)) {
- result.errorMessage = "Output file was not created: " + request.outputPath;
- result.status = GenerationStatus::FAILED;
- return result;
- }
- // Get converted file size
- try {
- auto convertedSize = fs::file_size(request.outputPath);
- result.convertedSize = formatFileSize(convertedSize);
- } catch (const std::exception& e) {
- result.convertedSize = "Unknown";
- }
- result.success = true;
- result.status = GenerationStatus::COMPLETED;
- result.outputPath = request.outputPath;
- std::cout << "Conversion completed successfully!" << std::endl;
- std::cout << " Original size: " << result.originalSize << std::endl;
- std::cout << " Converted size: " << result.convertedSize << std::endl;
- std::cout << " Time: " << result.conversionTime << "ms" << std::endl;
- // Trigger model rescan after successful conversion
- if (modelManager) {
- std::cout << "Triggering model rescan..." << std::endl;
- modelManager->scanModelsDirectory();
- }
- return result;
- }
- std::string formatFileSize(size_t bytes) {
- const char* units[] = {"B", "KB", "MB", "GB", "TB"};
- int unitIndex = 0;
- double size = static_cast<double>(bytes);
- while (size >= 1024.0 && unitIndex < 4) {
- size /= 1024.0;
- unitIndex++;
- }
- std::stringstream ss;
- ss << std::fixed << std::setprecision(2) << size << " " << units[unitIndex];
- return ss.str();
- }
- };
- GenerationQueue::GenerationQueue(ModelManager* modelManager, int maxConcurrentGenerations,
- const std::string& queueDir, const std::string& outputDir)
- : pImpl(std::make_unique<Impl>()) {
- pImpl->modelManager = modelManager;
- pImpl->maxConcurrentGenerations = maxConcurrentGenerations;
- pImpl->queueDir = queueDir;
- pImpl->outputDir = outputDir;
- std::cout << "GenerationQueue initialized" << std::endl;
- std::cout << " Max concurrent generations: " << maxConcurrentGenerations << std::endl;
- std::cout << " Queue directory: " << queueDir << std::endl;
- std::cout << " Output directory: " << outputDir << std::endl;
- // Load any existing jobs from disk
- pImpl->loadJobsFromDisk();
- }
// Destructor: stops the worker thread and resolves outstanding generation
// promises (via stop()) before the pImpl is destroyed.
GenerationQueue::~GenerationQueue() {
    stop();
}
- std::future<GenerationResult> GenerationQueue::enqueueRequest(const GenerationRequest& request) {
- std::cout << "Enqueuing generation request: " << request.id << std::endl;
- std::cout << " Prompt: " << request.prompt.substr(0, 100)
- << (request.prompt.length() > 100 ? "..." : "") << std::endl;
- std::cout << " Model: " << request.modelName << std::endl;
- std::cout << " Size: " << request.width << "x" << request.height << std::endl;
- std::cout << " Steps: " << request.steps << ", CFG: " << request.cfgScale << std::endl;
- // Create promise and future
- auto promise = std::make_shared<std::promise<GenerationResult>>();
- auto future = promise->get_future();
- // Store the promise
- {
- std::lock_guard<std::mutex> lock(pImpl->jobsMutex);
- pImpl->jobPromises[request.id] = std::move(*promise);
- }
- // Add to queue
- {
- std::lock_guard<std::mutex> lock(pImpl->queueMutex);
- // Create job info
- JobInfo jobInfo;
- jobInfo.id = request.id;
- jobInfo.type = JobType::GENERATION;
- jobInfo.status = GenerationStatus::QUEUED;
- jobInfo.prompt = request.prompt; // Store full prompt
- jobInfo.queuedTime = std::chrono::steady_clock::now();
- jobInfo.position = pImpl->requestQueue.size() + 1;
- // Store job info
- {
- std::lock_guard<std::mutex> jobsLock(pImpl->jobsMutex);
- pImpl->activeJobs[request.id] = jobInfo;
- }
- // Persist to disk
- pImpl->saveJobToFile(jobInfo);
- pImpl->requestQueue.push(request);
- pImpl->queueSize.store(pImpl->requestQueue.size());
- }
- // Notify worker thread
- pImpl->queueCondition.notify_one();
- return future;
- }
- std::future<HashResult> GenerationQueue::enqueueHashRequest(const HashRequest& request) {
- auto promise = std::make_shared<std::promise<HashResult>>();
- auto future = promise->get_future();
- std::unique_lock<std::mutex> lock(pImpl->queueMutex);
- // Create a generation request that acts as a placeholder for hash job
- GenerationRequest hashJobPlaceholder;
- hashJobPlaceholder.id = request.id;
- hashJobPlaceholder.prompt = "HASH_JOB"; // Special marker
- hashJobPlaceholder.modelName = request.modelNames.empty() ? "ALL_MODELS" : request.modelNames[0];
- // Store promise for retrieval later
- pImpl->hashPromises[request.id] = promise;
- pImpl->hashRequests[request.id] = request;
- pImpl->requestQueue.push(hashJobPlaceholder);
- pImpl->queueCondition.notify_one();
- std::cout << "Enqueued hash request: " << request.id << std::endl;
- return future;
- }
- std::future<ConversionResult> GenerationQueue::enqueueConversionRequest(const ConversionRequest& request) {
- auto promise = std::make_shared<std::promise<ConversionResult>>();
- auto future = promise->get_future();
- std::unique_lock<std::mutex> lock(pImpl->queueMutex);
- // Create a generation request that acts as a placeholder for conversion job
- GenerationRequest conversionJobPlaceholder;
- conversionJobPlaceholder.id = request.id;
- conversionJobPlaceholder.prompt = "CONVERSION_JOB"; // Special marker
- conversionJobPlaceholder.modelName = request.modelName;
- // Store promise for retrieval later
- pImpl->conversionPromises[request.id] = promise;
- pImpl->conversionRequests[request.id] = request;
- pImpl->requestQueue.push(conversionJobPlaceholder);
- pImpl->queueCondition.notify_one();
- std::cout << "Enqueued conversion request: " << request.id << " (model: " << request.modelName << ", type: " << request.quantizationType << ")" << std::endl;
- return future;
- }
// Returns the cached count of requests waiting in the queue (snapshot of an
// atomic counter; lock-free).
size_t GenerationQueue::getQueueSize() const {
    return pImpl->queueSize.load();
}
// Returns the number of generations currently executing (0 or 1 — the worker
// sets this flag around each generation job).
size_t GenerationQueue::getActiveGenerations() const {
    return pImpl->activeGenerations.load();
}
- std::vector<JobInfo> GenerationQueue::getQueueStatus() const {
- std::vector<JobInfo> jobs;
- std::lock_guard<std::mutex> lock(pImpl->jobsMutex);
- jobs.reserve(pImpl->activeJobs.size());
- for (const auto& pair : pImpl->activeJobs) {
- jobs.push_back(pair.second);
- }
- // Sort by queued time, then by status
- std::sort(jobs.begin(), jobs.end(), [](const JobInfo& a, const JobInfo& b) {
- if (a.status != b.status) {
- return static_cast<int>(a.status) < static_cast<int>(b.status);
- }
- return a.queuedTime < b.queuedTime;
- });
- return jobs;
- }
- JobInfo GenerationQueue::getJobInfo(const std::string& jobId) const {
- std::lock_guard<std::mutex> lock(pImpl->jobsMutex);
- auto it = pImpl->activeJobs.find(jobId);
- if (it != pImpl->activeJobs.end()) {
- return it->second;
- }
- return JobInfo{}; // Return empty job info if not found
- }
- bool GenerationQueue::cancelJob(const std::string& jobId) {
- std::lock_guard<std::mutex> queueLock(pImpl->queueMutex);
- std::lock_guard<std::mutex> jobsLock(pImpl->jobsMutex);
- // Check if job is still queued
- std::queue<GenerationRequest> newQueue;
- bool found = false;
- while (!pImpl->requestQueue.empty()) {
- GenerationRequest request = pImpl->requestQueue.front();
- pImpl->requestQueue.pop();
- if (request.id == jobId) {
- found = true;
- // Update job status
- auto it = pImpl->activeJobs.find(jobId);
- if (it != pImpl->activeJobs.end()) {
- it->second.status = GenerationStatus::FAILED;
- it->second.endTime = std::chrono::steady_clock::now();
- }
- // Set promise with cancellation error
- auto promiseIt = pImpl->jobPromises.find(jobId);
- if (promiseIt != pImpl->jobPromises.end()) {
- GenerationResult result;
- result.requestId = jobId;
- result.success = false;
- result.errorMessage = "Job cancelled by user";
- result.generationTime = 0;
- promiseIt->second.set_value(result);
- pImpl->jobPromises.erase(promiseIt);
- }
- } else {
- newQueue.push(request);
- }
- }
- pImpl->requestQueue = newQueue;
- pImpl->queueSize.store(pImpl->requestQueue.size());
- return found;
- }
- void GenerationQueue::clearQueue() {
- std::cout << "Clearing generation queue" << std::endl;
- std::lock_guard<std::mutex> queueLock(pImpl->queueMutex);
- std::lock_guard<std::mutex> jobsLock(pImpl->jobsMutex);
- // Cancel all queued jobs
- while (!pImpl->requestQueue.empty()) {
- GenerationRequest request = pImpl->requestQueue.front();
- pImpl->requestQueue.pop();
- // Update job status
- auto it = pImpl->activeJobs.find(request.id);
- if (it != pImpl->activeJobs.end()) {
- it->second.status = GenerationStatus::FAILED;
- it->second.endTime = std::chrono::steady_clock::now();
- }
- // Set promise with cancellation error
- auto promiseIt = pImpl->jobPromises.find(request.id);
- if (promiseIt != pImpl->jobPromises.end()) {
- GenerationResult result;
- result.requestId = request.id;
- result.success = false;
- result.errorMessage = "Queue cleared";
- result.generationTime = 0;
- promiseIt->second.set_value(result);
- pImpl->jobPromises.erase(promiseIt);
- }
- }
- pImpl->queueSize.store(0);
- }
- void GenerationQueue::start() {
- if (pImpl->running.load()) {
- std::cout << "GenerationQueue is already running" << std::endl;
- return;
- }
- pImpl->running.store(true);
- pImpl->stopRequested.store(false);
- pImpl->workerThread = std::thread(&Impl::workerThreadFunction, pImpl.get());
- std::cout << "GenerationQueue started" << std::endl;
- }
- void GenerationQueue::stop() {
- if (!pImpl->running.load()) {
- return;
- }
- std::cout << "Stopping GenerationQueue..." << std::endl;
- pImpl->stopRequested.store(true);
- pImpl->queueCondition.notify_all();
- if (pImpl->workerThread.joinable()) {
- pImpl->workerThread.join();
- }
- pImpl->running.store(false);
- // Clear any remaining promises
- std::lock_guard<std::mutex> lock(pImpl->jobsMutex);
- for (auto& pair : pImpl->jobPromises) {
- GenerationResult result;
- result.requestId = pair.first;
- result.success = false;
- result.errorMessage = "Queue stopped";
- result.generationTime = 0;
- pair.second.set_value(result);
- }
- pImpl->jobPromises.clear();
- std::cout << "GenerationQueue stopped" << std::endl;
- }
// Returns true while the queue is running (between start() and stop()).
bool GenerationQueue::isRunning() const {
    return pImpl->running.load();
}
// Updates the configured concurrency limit.
// NOTE(review): the value is stored and logged, but the single worker thread
// processes one job at a time regardless — confirm whether enforcement is
// intended elsewhere.
void GenerationQueue::setMaxConcurrentGenerations(int maxConcurrent) {
    pImpl->maxConcurrentGenerations = maxConcurrent;
    std::cout << "GenerationQueue max concurrent generations set to: " << maxConcurrent << std::endl;
}
|