| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203120412051206120712081209 |
- #include "stable_diffusion_wrapper.h"
- #include <algorithm>
- #include <chrono>
- #include <cstring>
- #include <filesystem>
- #include <thread>
- #include "logger.h"
- #include "model_detector.h"
- extern "C" {
- #include "stable-diffusion.h"
- }
- class StableDiffusionWrapper::Impl {
- public:
- sd_ctx_t* sdContext = nullptr;
- std::string lastError;
- std::mutex contextMutex;
- bool verbose = false;
- std::string currentModelPath;
- StableDiffusionWrapper::GenerationParams currentModelParams;
- Impl() {
- // Initialize any required resources
- }
- ~Impl() {
- unloadModel();
- }
- bool loadModel(const std::string& modelPath, const StableDiffusionWrapper::GenerationParams& params) {
- std::lock_guard<std::mutex> lock(contextMutex);
- // Store verbose flag for use in other functions
- verbose = params.verbose;
- // Unload any existing model
- if (sdContext) {
- free_sd_ctx(sdContext);
- sdContext = nullptr;
- }
- // Initialize context parameters
- sd_ctx_params_t ctxParams;
- sd_ctx_params_init(&ctxParams);
- ctxParams.free_params_immediately = false; // avoid segfault when reusing
- // Get absolute path for logging
- std::filesystem::path absModelPath = std::filesystem::absolute(modelPath);
- if (params.verbose) {
- LOG_DEBUG("Loading model from absolute path: " + std::filesystem::absolute(modelPath).string());
- }
- // Create persistent string copies to fix lifetime issues
- // These strings will remain valid for the entire lifetime of the context
- std::string persistentModelPath = modelPath;
- std::string persistentClipLPath = params.clipLPath;
- std::string persistentClipGPath = params.clipGPath;
- std::string persistentVaePath = params.vaePath;
- std::string persistentTaesdPath = params.taesdPath;
- std::string persistentControlNetPath = params.controlNetPath;
- std::string persistentLoraModelDir = params.loraModelDir;
- std::string persistentEmbeddingDir = params.embeddingDir;
- // Use folder-based path selection with enhanced logic
- bool useDiffusionModelPath = false;
- std::string detectionSource = "folder";
- // Check if model is in diffusion_models directory by examining the path
- std::filesystem::path modelFilePath(modelPath);
- std::filesystem::path parentDir = modelFilePath.parent_path();
- std::string parentDirName = parentDir.filename().string();
- std::string modelFileName = modelFilePath.filename().string();
- // Convert to lowercase for comparison
- std::transform(parentDirName.begin(), parentDirName.end(), parentDirName.begin(), ::tolower);
- std::transform(modelFileName.begin(), modelFileName.end(), modelFileName.begin(), ::tolower);
- // Variables for fallback detection
- ModelDetectionResult detectionResult;
- bool detectionSuccessful = false;
- bool isQwenModel = false;
- // Check if this is a Qwen model based on filename
- if (modelFileName.find("qwen") != std::string::npos) {
- isQwenModel = true;
- if (params.verbose) {
- LOG_DEBUG("Detected Qwen model from filename: " + modelFileName);
- }
- }
- // Enhanced path selection logic
- if (parentDirName == "diffusion_models" || parentDirName == "diffusion") {
- useDiffusionModelPath = true;
- if (params.verbose) {
- LOG_DEBUG("Model is in " + parentDirName + " directory, using diffusion_model_path");
- }
- } else if (parentDirName == "checkpoints" || parentDirName == "stable-diffusion") {
- useDiffusionModelPath = false;
- if (params.verbose) {
- LOG_DEBUG("Model is in " + parentDirName + " directory, using model_path");
- }
- } else if (parentDirName == "sd_models" || parentDirName.empty()) {
- // Handle models in root /data/SD_MODELS/ directory
- if (isQwenModel) {
- // Qwen models should use diffusion_model_path regardless of directory
- useDiffusionModelPath = true;
- detectionSource = "qwen_root_detection";
- if (params.verbose) {
- LOG_DEBUG("Qwen model in root directory, preferring diffusion_model_path");
- }
- } else {
- // For non-Qwen models in root, try architecture detection
- if (params.verbose) {
- LOG_DEBUG("Model is in root directory '" + parentDirName + "', attempting architecture detection");
- }
- detectionSource = "architecture_fallback";
- try {
- detectionResult = ModelDetector::detectModel(modelPath);
- detectionSuccessful = true;
- if (params.verbose) {
- LOG_DEBUG("Architecture detection found: " + detectionResult.architectureName);
- }
- } catch (const std::exception& e) {
- LOG_ERROR("Warning: Architecture detection failed: " + std::string(e.what()) + ". Using default loading method.");
- detectionResult.architecture = ModelArchitecture::UNKNOWN;
- detectionResult.architectureName = "Unknown";
- }
- if (detectionSuccessful) {
- switch (detectionResult.architecture) {
- case ModelArchitecture::FLUX_SCHNELL:
- case ModelArchitecture::FLUX_DEV:
- case ModelArchitecture::FLUX_CHROMA:
- case ModelArchitecture::SD_3:
- case ModelArchitecture::QWEN2VL:
- // Modern architectures use diffusion_model_path
- useDiffusionModelPath = true;
- break;
- case ModelArchitecture::SD_1_5:
- case ModelArchitecture::SD_2_1:
- case ModelArchitecture::SDXL_BASE:
- case ModelArchitecture::SDXL_REFINER:
- // Traditional SD models use model_path
- useDiffusionModelPath = false;
- break;
- case ModelArchitecture::UNKNOWN:
- default:
- // Unknown architectures fall back to model_path for backward compatibility
- useDiffusionModelPath = false;
- if (params.verbose) {
- LOG_WARNING("Warning: Unknown model architecture detected, using default model_path for backward compatibility");
- }
- break;
- }
- } else {
- useDiffusionModelPath = false; // Default fallback
- detectionSource = "default_fallback";
- }
- }
- } else {
- // Unknown directory - try architecture detection
- if (params.verbose) {
- LOG_DEBUG("Model is in unknown directory '" + parentDirName + "', attempting architecture detection as fallback");
- }
- detectionSource = "architecture_fallback";
- try {
- detectionResult = ModelDetector::detectModel(modelPath);
- detectionSuccessful = true;
- if (params.verbose) {
- LOG_DEBUG("Fallback detection found architecture: " + detectionResult.architectureName);
- }
- } catch (const std::exception& e) {
- LOG_ERROR("Warning: Fallback model detection failed: " + std::string(e.what()) + ". Using default loading method.");
- detectionResult.architecture = ModelArchitecture::UNKNOWN;
- detectionResult.architectureName = "Unknown";
- }
- if (detectionSuccessful) {
- switch (detectionResult.architecture) {
- case ModelArchitecture::FLUX_SCHNELL:
- case ModelArchitecture::FLUX_DEV:
- case ModelArchitecture::FLUX_CHROMA:
- case ModelArchitecture::SD_3:
- case ModelArchitecture::QWEN2VL:
- // Modern architectures use diffusion_model_path
- useDiffusionModelPath = true;
- break;
- case ModelArchitecture::SD_1_5:
- case ModelArchitecture::SD_2_1:
- case ModelArchitecture::SDXL_BASE:
- case ModelArchitecture::SDXL_REFINER:
- // Traditional SD models use model_path
- useDiffusionModelPath = false;
- break;
- case ModelArchitecture::UNKNOWN:
- default:
- // Unknown architectures fall back to model_path for backward compatibility
- useDiffusionModelPath = false;
- if (params.verbose) {
- LOG_WARNING("Warning: Unknown model architecture detected, using default model_path for backward compatibility");
- }
- break;
- }
- } else {
- useDiffusionModelPath = false; // Default fallback
- detectionSource = "default_fallback";
- }
- }
- // Set the appropriate model path based on folder location or fallback detection
- if (useDiffusionModelPath) {
- ctxParams.diffusion_model_path = persistentModelPath.c_str();
- ctxParams.model_path = nullptr; // Clear the traditional path
- if (params.verbose) {
- LOG_DEBUG("Using diffusion_model_path (source: " + detectionSource + ")");
- }
- } else {
- ctxParams.model_path = persistentModelPath.c_str();
- ctxParams.diffusion_model_path = nullptr; // Clear the modern path
- if (params.verbose) {
- LOG_DEBUG("Using model_path (source: " + detectionSource + ")");
- }
- }
- // Set optional model paths using persistent strings to fix lifetime issues
- if (!persistentClipLPath.empty()) {
- ctxParams.clip_l_path = persistentClipLPath.c_str();
- if (params.verbose) {
- LOG_DEBUG("Using CLIP-L path: " + std::filesystem::absolute(persistentClipLPath).string());
- }
- }
- if (!persistentClipGPath.empty()) {
- ctxParams.clip_g_path = persistentClipGPath.c_str();
- if (params.verbose) {
- LOG_DEBUG("Using CLIP-G path: " + std::filesystem::absolute(persistentClipGPath).string());
- }
- }
- if (!persistentVaePath.empty()) {
- // Check if VAE file exists before setting it
- if (std::filesystem::exists(persistentVaePath)) {
- ctxParams.vae_path = persistentVaePath.c_str();
- if (params.verbose) {
- LOG_DEBUG("Using VAE path: " + std::filesystem::absolute(persistentVaePath).string());
- }
- } else {
- if (params.verbose) {
- LOG_DEBUG("VAE file not found: " + std::filesystem::absolute(persistentVaePath).string() + " - continuing without VAE");
- }
- ctxParams.vae_path = nullptr;
- }
- }
- if (!persistentTaesdPath.empty()) {
- ctxParams.taesd_path = persistentTaesdPath.c_str();
- if (params.verbose) {
- LOG_DEBUG("Using TAESD path: " + std::filesystem::absolute(persistentTaesdPath).string());
- }
- }
- if (!persistentControlNetPath.empty()) {
- ctxParams.control_net_path = persistentControlNetPath.c_str();
- if (params.verbose) {
- LOG_DEBUG("Using ControlNet path: " + std::filesystem::absolute(persistentControlNetPath).string());
- }
- }
- if (!persistentLoraModelDir.empty()) {
- ctxParams.lora_model_dir = persistentLoraModelDir.c_str();
- if (params.verbose) {
- LOG_DEBUG("Using LoRA model directory: " + std::filesystem::absolute(persistentLoraModelDir).string());
- }
- }
- if (!persistentEmbeddingDir.empty()) {
- ctxParams.embedding_dir = persistentEmbeddingDir.c_str();
- if (params.verbose) {
- LOG_DEBUG("Using embedding directory: " + std::filesystem::absolute(persistentEmbeddingDir).string());
- }
- }
- // Set performance parameters
- ctxParams.n_threads = params.nThreads;
- ctxParams.offload_params_to_cpu = params.offloadParamsToCpu;
- ctxParams.keep_clip_on_cpu = params.clipOnCpu;
- ctxParams.keep_vae_on_cpu = params.vaeOnCpu;
- ctxParams.diffusion_flash_attn = params.diffusionFlashAttn;
- ctxParams.diffusion_conv_direct = params.diffusionConvDirect;
- ctxParams.vae_conv_direct = params.vaeConvDirect;
- // Set model type
- ctxParams.wtype = StableDiffusionWrapper::stringToModelType(params.modelType);
- // Create the stable-diffusion context
- if (params.verbose) {
- LOG_DEBUG("Attempting to create stable-diffusion context with selected parameters...");
- }
- sdContext = new_sd_ctx(&ctxParams);
- if (!sdContext) {
- lastError = "Failed to create stable-diffusion context";
- LOG_ERROR("Error: " + lastError + " with initial attempt");
- // If we used diffusion_model_path and it failed, try fallback to model_path
- if (useDiffusionModelPath) {
- if (params.verbose) {
- LOG_WARNING("Warning: Failed to load with diffusion_model_path. Attempting fallback to model_path...");
- }
- // Re-initialize context parameters
- sd_ctx_params_init(&ctxParams);
- // Set fallback model path using persistent string
- ctxParams.model_path = persistentModelPath.c_str();
- ctxParams.diffusion_model_path = nullptr;
- // Re-apply other parameters using persistent strings
- if (!persistentClipLPath.empty()) {
- ctxParams.clip_l_path = persistentClipLPath.c_str();
- }
- if (!persistentClipGPath.empty()) {
- ctxParams.clip_g_path = persistentClipGPath.c_str();
- }
- if (!persistentVaePath.empty()) {
- // Check if VAE file exists before setting it
- if (std::filesystem::exists(persistentVaePath)) {
- ctxParams.vae_path = persistentVaePath.c_str();
- } else {
- ctxParams.vae_path = nullptr;
- }
- }
- if (!persistentTaesdPath.empty()) {
- ctxParams.taesd_path = persistentTaesdPath.c_str();
- }
- if (!persistentControlNetPath.empty()) {
- ctxParams.control_net_path = persistentControlNetPath.c_str();
- }
- if (!persistentLoraModelDir.empty()) {
- ctxParams.lora_model_dir = persistentLoraModelDir.c_str();
- }
- if (!persistentEmbeddingDir.empty()) {
- ctxParams.embedding_dir = persistentEmbeddingDir.c_str();
- }
- // Re-apply performance parameters
- ctxParams.n_threads = params.nThreads;
- ctxParams.offload_params_to_cpu = params.offloadParamsToCpu;
- ctxParams.keep_clip_on_cpu = params.clipOnCpu;
- ctxParams.keep_vae_on_cpu = params.vaeOnCpu;
- ctxParams.diffusion_flash_attn = params.diffusionFlashAttn;
- ctxParams.diffusion_conv_direct = params.diffusionConvDirect;
- ctxParams.vae_conv_direct = params.vaeConvDirect;
- // Re-apply model type
- ctxParams.wtype = StableDiffusionWrapper::stringToModelType(params.modelType);
- if (params.verbose) {
- LOG_DEBUG("Attempting to create context with fallback model_path...");
- }
- // Try creating context again with fallback
- sdContext = new_sd_ctx(&ctxParams);
- if (!sdContext) {
- lastError = "Failed to create stable-diffusion context with both diffusion_model_path and model_path fallback";
- LOG_ERROR("Error: " + lastError);
- // Additional fallback: try with minimal parameters for GGUF models
- if (modelFileName.find(".gguf") != std::string::npos || modelFileName.find(".ggml") != std::string::npos) {
- if (params.verbose) {
- LOG_DEBUG("Detected GGUF/GGML model, attempting minimal parameter fallback...");
- }
- // Re-initialize with minimal parameters
- sd_ctx_params_init(&ctxParams);
- ctxParams.model_path = persistentModelPath.c_str();
- ctxParams.diffusion_model_path = nullptr;
- // Set only essential parameters for GGUF
- ctxParams.n_threads = params.nThreads;
- ctxParams.wtype = StableDiffusionWrapper::stringToModelType(params.modelType);
- if (params.verbose) {
- LOG_DEBUG("Attempting to create context with minimal GGUF parameters...");
- }
- sdContext = new_sd_ctx(&ctxParams);
- if (!sdContext) {
- lastError = "Failed to create stable-diffusion context even with minimal GGUF parameters";
- LOG_ERROR("Error: " + lastError);
- return false;
- }
- if (params.verbose) {
- LOG_DEBUG("Successfully loaded GGUF model with minimal parameters: " + absModelPath.string());
- }
- } else {
- return false;
- }
- } else {
- if (params.verbose) {
- LOG_DEBUG("Successfully loaded model with fallback to model_path: " + absModelPath.string());
- }
- }
- } else {
- // Try minimal fallback for non-diffusion_model_path failures
- if (modelFileName.find(".gguf") != std::string::npos || modelFileName.find(".ggml") != std::string::npos) {
- if (params.verbose) {
- LOG_DEBUG("Detected GGUF/GGML model, attempting minimal parameter fallback...");
- }
- // Re-initialize with minimal parameters
- sd_ctx_params_init(&ctxParams);
- ctxParams.model_path = persistentModelPath.c_str();
- // Set only essential parameters for GGUF
- ctxParams.n_threads = params.nThreads;
- ctxParams.wtype = StableDiffusionWrapper::stringToModelType(params.modelType);
- if (params.verbose) {
- LOG_DEBUG("Attempting to create context with minimal GGUF parameters...");
- }
- sdContext = new_sd_ctx(&ctxParams);
- if (!sdContext) {
- lastError = "Failed to create stable-diffusion context even with minimal GGUF parameters";
- LOG_ERROR("Error: " + lastError);
- return false;
- }
- if (params.verbose) {
- LOG_DEBUG("Successfully loaded GGUF model with minimal parameters: " + absModelPath.string());
- }
- } else {
- LOG_ERROR("Error: " + lastError);
- return false;
- }
- }
- }
- // Log successful loading with detection information
- if (params.verbose) {
- LOG_DEBUG("Successfully loaded model: " + absModelPath.string());
- LOG_DEBUG(" Detection source: " + detectionSource);
- LOG_DEBUG(" Loading method: " + std::string(useDiffusionModelPath ? "diffusion_model_path" : "model_path"));
- LOG_DEBUG(" Parent directory: " + parentDirName);
- LOG_DEBUG(" Model filename: " + modelFileName);
- }
- // Log additional model properties if architecture detection was performed
- if (detectionSuccessful && params.verbose) {
- LOG_DEBUG(" Architecture: " + detectionResult.architectureName);
- if (detectionResult.textEncoderDim > 0) {
- LOG_DEBUG(" Text encoder dimension: " + std::to_string(detectionResult.textEncoderDim));
- }
- if (detectionResult.needsVAE) {
- LOG_DEBUG(" Requires VAE: " + (detectionResult.recommendedVAE.empty() ? std::string("Yes") : detectionResult.recommendedVAE));
- }
- }
- // Store current model info for potential reload after upscaling
- currentModelPath = modelPath;
- currentModelParams = params;
- return true;
- }
- void unloadModel() {
- std::lock_guard<std::mutex> lock(contextMutex);
- if (sdContext) {
- free_sd_ctx(sdContext);
- sdContext = nullptr;
- if (verbose) {
- LOG_DEBUG("Unloaded stable-diffusion model");
- }
- }
- // Clear stored model info
- currentModelPath.clear();
- currentModelParams = StableDiffusionWrapper::GenerationParams();
- }
- bool isModelLoaded() const {
- return sdContext != nullptr;
- }
- std::vector<StableDiffusionWrapper::GeneratedImage> generateImage(
- const StableDiffusionWrapper::GenerationParams& params,
- StableDiffusionWrapper::ProgressCallback progressCallback,
- void* userData) {
- std::vector<StableDiffusionWrapper::GeneratedImage> results;
- if (!sdContext) {
- lastError = "No model loaded";
- return results;
- }
- auto startTime = std::chrono::high_resolution_clock::now();
- // Initialize generation parameters
- sd_img_gen_params_t genParams;
- sd_img_gen_params_init(&genParams);
- // Set basic parameters
- genParams.prompt = params.prompt.c_str();
- genParams.negative_prompt = params.negativePrompt.c_str();
- genParams.width = params.width;
- genParams.height = params.height;
- genParams.sample_params.sample_steps = params.steps;
- genParams.seed = params.seed;
- genParams.batch_count = params.batchCount;
- // Set sampling parameters
- genParams.sample_params.sample_method = StableDiffusionWrapper::stringToSamplingMethod(params.samplingMethod);
- genParams.sample_params.scheduler = StableDiffusionWrapper::stringToScheduler(params.scheduler);
- genParams.sample_params.guidance.txt_cfg = params.cfgScale;
- // Set advanced parameters
- genParams.clip_skip = params.clipSkip;
- genParams.strength = params.strength;
- // Set progress callback if provided
- // Track callback data to ensure proper cleanup
- std::pair<StableDiffusionWrapper::ProgressCallback, void*>* callbackData = nullptr;
- if (progressCallback) {
- callbackData = new std::pair<StableDiffusionWrapper::ProgressCallback, void*>(progressCallback, userData);
- sd_set_progress_callback([](int step, int steps, float time, void* data) {
- auto* callbackData = static_cast<std::pair<StableDiffusionWrapper::ProgressCallback, void*>*>(data);
- if (callbackData) {
- callbackData->first(step, steps, time, callbackData->second);
- }
- },
- callbackData);
- }
- // Generate the image
- LOG_DEBUG("[TIMING_ANALYSIS] Starting generate_image() call");
- auto generationCallStart = std::chrono::high_resolution_clock::now();
- sd_image_t* sdImages = generate_image(sdContext, &genParams);
- auto generationCallEnd = std::chrono::high_resolution_clock::now();
- auto generationCallTime = std::chrono::duration_cast<std::chrono::milliseconds>(generationCallEnd - generationCallStart).count();
- LOG_DEBUG("[TIMING_ANALYSIS] generate_image() call completed in " + std::to_string(generationCallTime) + "ms");
- // Clear and clean up progress callback - FIX: Wait for any pending callbacks
- sd_set_progress_callback(nullptr, nullptr);
- // Add a small delay to ensure any in-flight callbacks complete before cleanup
- std::this_thread::sleep_for(std::chrono::milliseconds(10));
- if (callbackData) {
- delete callbackData;
- callbackData = nullptr;
- }
- auto endTime = std::chrono::high_resolution_clock::now();
- auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(endTime - startTime);
- if (!sdImages) {
- lastError = "Failed to generate image";
- return results;
- }
- // Convert stable-diffusion images to our format
- for (int i = 0; i < params.batchCount; i++) {
- StableDiffusionWrapper::GeneratedImage image;
- image.width = sdImages[i].width;
- image.height = sdImages[i].height;
- image.channels = sdImages[i].channel;
- image.seed = params.seed;
- image.generationTime = duration.count();
- // Copy image data
- if (sdImages[i].data && sdImages[i].width > 0 && sdImages[i].height > 0 && sdImages[i].channel > 0) {
- size_t dataSize = sdImages[i].width * sdImages[i].height * sdImages[i].channel;
- image.data.resize(dataSize);
- std::memcpy(image.data.data(), sdImages[i].data, dataSize);
- }
- results.push_back(image);
- }
- // Free the generated images
- // Clean up each image's data array
- for (int i = 0; i < params.batchCount; i++) {
- if (sdImages[i].data) {
- free(sdImages[i].data);
- sdImages[i].data = nullptr;
- }
- }
- // Free the image array itself
- free(sdImages);
- return results;
- }
- std::vector<StableDiffusionWrapper::GeneratedImage> generateImageImg2Img(
- const StableDiffusionWrapper::GenerationParams& params,
- const std::vector<uint8_t>& inputData,
- int inputWidth,
- int inputHeight,
- StableDiffusionWrapper::ProgressCallback progressCallback,
- void* userData) {
- std::vector<StableDiffusionWrapper::GeneratedImage> results;
- if (!sdContext) {
- lastError = "No model loaded";
- return results;
- }
- auto startTime = std::chrono::high_resolution_clock::now();
- // Initialize generation parameters
- sd_img_gen_params_t genParams;
- sd_img_gen_params_init(&genParams);
- // Set basic parameters
- genParams.prompt = params.prompt.c_str();
- genParams.negative_prompt = params.negativePrompt.c_str();
- genParams.width = params.width;
- genParams.height = params.height;
- genParams.sample_params.sample_steps = params.steps;
- genParams.seed = params.seed;
- genParams.batch_count = params.batchCount;
- genParams.strength = params.strength;
- // Set sampling parameters
- genParams.sample_params.sample_method = StableDiffusionWrapper::stringToSamplingMethod(params.samplingMethod);
- genParams.sample_params.scheduler = StableDiffusionWrapper::stringToScheduler(params.scheduler);
- genParams.sample_params.guidance.txt_cfg = params.cfgScale;
- // Set advanced parameters
- genParams.clip_skip = params.clipSkip;
- // Set input image
- sd_image_t initImage;
- initImage.width = inputWidth;
- initImage.height = inputHeight;
- initImage.channel = 3; // RGB
- initImage.data = const_cast<uint8_t*>(inputData.data());
- genParams.init_image = initImage;
- // Set progress callback if provided
- // Track callback data to ensure proper cleanup
- std::pair<StableDiffusionWrapper::ProgressCallback, void*>* callbackData = nullptr;
- if (progressCallback) {
- callbackData = new std::pair<StableDiffusionWrapper::ProgressCallback, void*>(progressCallback, userData);
- sd_set_progress_callback([](int step, int steps, float time, void* data) {
- auto* callbackData = static_cast<std::pair<StableDiffusionWrapper::ProgressCallback, void*>*>(data);
- if (callbackData) {
- callbackData->first(step, steps, time, callbackData->second);
- }
- },
- callbackData);
- }
- // Generate the image
- sd_image_t* sdImages = generate_image(sdContext, &genParams);
- // Clear and clean up progress callback - FIX: Wait for any pending callbacks
- sd_set_progress_callback(nullptr, nullptr);
- // Add a small delay to ensure any in-flight callbacks complete before cleanup
- std::this_thread::sleep_for(std::chrono::milliseconds(10));
- if (callbackData) {
- delete callbackData;
- callbackData = nullptr;
- }
- auto endTime = std::chrono::high_resolution_clock::now();
- auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(endTime - startTime);
- if (!sdImages) {
- lastError = "Failed to generate image";
- return results;
- }
- // Convert stable-diffusion images to our format
- for (int i = 0; i < params.batchCount; i++) {
- StableDiffusionWrapper::GeneratedImage image;
- image.width = sdImages[i].width;
- image.height = sdImages[i].height;
- image.channels = sdImages[i].channel;
- image.seed = params.seed;
- image.generationTime = duration.count();
- // Copy image data
- if (sdImages[i].data && sdImages[i].width > 0 && sdImages[i].height > 0 && sdImages[i].channel > 0) {
- size_t dataSize = sdImages[i].width * sdImages[i].height * sdImages[i].channel;
- image.data.resize(dataSize);
- std::memcpy(image.data.data(), sdImages[i].data, dataSize);
- }
- results.push_back(image);
- }
- // Free the generated images
- // Clean up each image's data array
- for (int i = 0; i < params.batchCount; i++) {
- if (sdImages[i].data) {
- free(sdImages[i].data);
- sdImages[i].data = nullptr;
- }
- }
- // Free the image array itself
- free(sdImages);
- return results;
- }
- std::vector<StableDiffusionWrapper::GeneratedImage> generateImageControlNet(
- const StableDiffusionWrapper::GenerationParams& params,
- const std::vector<uint8_t>& controlData,
- int controlWidth,
- int controlHeight,
- StableDiffusionWrapper::ProgressCallback progressCallback,
- void* userData) {
- std::vector<StableDiffusionWrapper::GeneratedImage> results;
- if (!sdContext) {
- lastError = "No model loaded";
- return results;
- }
- auto startTime = std::chrono::high_resolution_clock::now();
- // Initialize generation parameters
- sd_img_gen_params_t genParams;
- sd_img_gen_params_init(&genParams);
- // Set basic parameters
- genParams.prompt = params.prompt.c_str();
- genParams.negative_prompt = params.negativePrompt.c_str();
- genParams.width = params.width;
- genParams.height = params.height;
- genParams.sample_params.sample_steps = params.steps;
- genParams.seed = params.seed;
- genParams.batch_count = params.batchCount;
- genParams.control_strength = params.controlStrength;
- // Set sampling parameters
- genParams.sample_params.sample_method = StableDiffusionWrapper::stringToSamplingMethod(params.samplingMethod);
- genParams.sample_params.scheduler = StableDiffusionWrapper::stringToScheduler(params.scheduler);
- genParams.sample_params.guidance.txt_cfg = params.cfgScale;
- // Set advanced parameters
- genParams.clip_skip = params.clipSkip;
- // Set control image
- sd_image_t controlImage;
- controlImage.width = controlWidth;
- controlImage.height = controlHeight;
- controlImage.channel = 3; // RGB
- controlImage.data = const_cast<uint8_t*>(controlData.data());
- genParams.control_image = controlImage;
- // Set progress callback if provided
- // Track callback data to ensure proper cleanup
- std::pair<StableDiffusionWrapper::ProgressCallback, void*>* callbackData = nullptr;
- if (progressCallback) {
- callbackData = new std::pair<StableDiffusionWrapper::ProgressCallback, void*>(progressCallback, userData);
- sd_set_progress_callback([](int step, int steps, float time, void* data) {
- auto* callbackData = static_cast<std::pair<StableDiffusionWrapper::ProgressCallback, void*>*>(data);
- if (callbackData) {
- callbackData->first(step, steps, time, callbackData->second);
- }
- },
- callbackData);
- }
- // Generate the image
- sd_image_t* sdImages = generate_image(sdContext, &genParams);
- // Clear and clean up progress callback - FIX: Wait for any pending callbacks
- sd_set_progress_callback(nullptr, nullptr);
- // Add a small delay to ensure any in-flight callbacks complete before cleanup
- std::this_thread::sleep_for(std::chrono::milliseconds(10));
- if (callbackData) {
- delete callbackData;
- callbackData = nullptr;
- }
- auto endTime = std::chrono::high_resolution_clock::now();
- auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(endTime - startTime);
- if (!sdImages) {
- lastError = "Failed to generate image";
- return results;
- }
- // Convert stable-diffusion images to our format
- for (int i = 0; i < params.batchCount; i++) {
- StableDiffusionWrapper::GeneratedImage image;
- image.width = sdImages[i].width;
- image.height = sdImages[i].height;
- image.channels = sdImages[i].channel;
- image.seed = params.seed;
- image.generationTime = duration.count();
- // Copy image data
- if (sdImages[i].data && sdImages[i].width > 0 && sdImages[i].height > 0 && sdImages[i].channel > 0) {
- size_t dataSize = sdImages[i].width * sdImages[i].height * sdImages[i].channel;
- image.data.resize(dataSize);
- std::memcpy(image.data.data(), sdImages[i].data, dataSize);
- }
- results.push_back(image);
- }
- // Free the generated images
- // Clean up each image's data array
- for (int i = 0; i < params.batchCount; i++) {
- if (sdImages[i].data) {
- free(sdImages[i].data);
- sdImages[i].data = nullptr;
- }
- }
- // Free the image array itself
- free(sdImages);
- return results;
- }
- std::vector<StableDiffusionWrapper::GeneratedImage> generateImageInpainting(
- const StableDiffusionWrapper::GenerationParams& params,
- const std::vector<uint8_t>& inputData,
- int inputWidth,
- int inputHeight,
- const std::vector<uint8_t>& maskData,
- int maskWidth,
- int maskHeight,
- StableDiffusionWrapper::ProgressCallback progressCallback,
- void* userData) {
- std::vector<StableDiffusionWrapper::GeneratedImage> results;
- if (!sdContext) {
- lastError = "No model loaded";
- return results;
- }
- auto startTime = std::chrono::high_resolution_clock::now();
- // Initialize generation parameters
- sd_img_gen_params_t genParams;
- sd_img_gen_params_init(&genParams);
- // Set basic parameters
- genParams.prompt = params.prompt.c_str();
- genParams.negative_prompt = params.negativePrompt.c_str();
- genParams.width = params.width;
- genParams.height = params.height;
- genParams.sample_params.sample_steps = params.steps;
- genParams.seed = params.seed;
- genParams.batch_count = params.batchCount;
- genParams.strength = params.strength;
- // Set sampling parameters
- genParams.sample_params.sample_method = StableDiffusionWrapper::stringToSamplingMethod(params.samplingMethod);
- genParams.sample_params.scheduler = StableDiffusionWrapper::stringToScheduler(params.scheduler);
- genParams.sample_params.guidance.txt_cfg = params.cfgScale;
- // Set advanced parameters
- genParams.clip_skip = params.clipSkip;
- // Set input image
- sd_image_t initImage;
- initImage.width = inputWidth;
- initImage.height = inputHeight;
- initImage.channel = 3; // RGB
- initImage.data = const_cast<uint8_t*>(inputData.data());
- genParams.init_image = initImage;
- // Set mask image
- sd_image_t maskImage;
- maskImage.width = maskWidth;
- maskImage.height = maskHeight;
- maskImage.channel = 1; // Grayscale mask
- maskImage.data = const_cast<uint8_t*>(maskData.data());
- genParams.mask_image = maskImage;
- // Set progress callback if provided
- // Track callback data to ensure proper cleanup
- std::pair<StableDiffusionWrapper::ProgressCallback, void*>* callbackData = nullptr;
- if (progressCallback) {
- callbackData = new std::pair<StableDiffusionWrapper::ProgressCallback, void*>(progressCallback, userData);
- sd_set_progress_callback([](int step, int steps, float time, void* data) {
- auto* callbackData = static_cast<std::pair<StableDiffusionWrapper::ProgressCallback, void*>*>(data);
- if (callbackData) {
- callbackData->first(step, steps, time, callbackData->second);
- }
- },
- callbackData);
- }
- // Generate the image
- sd_image_t* sdImages = generate_image(sdContext, &genParams);
- // Clear and clean up progress callback - FIX: Wait for any pending callbacks
- sd_set_progress_callback(nullptr, nullptr);
- // Add a small delay to ensure any in-flight callbacks complete before cleanup
- std::this_thread::sleep_for(std::chrono::milliseconds(10));
- if (callbackData) {
- delete callbackData;
- callbackData = nullptr;
- }
- auto endTime = std::chrono::high_resolution_clock::now();
- auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(endTime - startTime);
- if (!sdImages) {
- lastError = "Failed to generate image";
- return results;
- }
- // Convert stable-diffusion images to our format
- for (int i = 0; i < params.batchCount; i++) {
- StableDiffusionWrapper::GeneratedImage image;
- image.width = sdImages[i].width;
- image.height = sdImages[i].height;
- image.channels = sdImages[i].channel;
- image.seed = params.seed;
- image.generationTime = duration.count();
- // Copy image data
- if (sdImages[i].data && sdImages[i].width > 0 && sdImages[i].height > 0 && sdImages[i].channel > 0) {
- size_t dataSize = sdImages[i].width * sdImages[i].height * sdImages[i].channel;
- image.data.resize(dataSize);
- std::memcpy(image.data.data(), sdImages[i].data, dataSize);
- }
- results.push_back(image);
- }
- // Free the generated images
- // Clean up each image's data array
- for (int i = 0; i < params.batchCount; i++) {
- if (sdImages[i].data) {
- free(sdImages[i].data);
- sdImages[i].data = nullptr;
- }
- }
- // Free the image array itself
- free(sdImages);
- return results;
- }
- StableDiffusionWrapper::GeneratedImage upscaleImage(
- const std::string& esrganPath,
- const std::vector<uint8_t>& inputData,
- int inputWidth,
- int inputHeight,
- int inputChannels,
- uint32_t upscaleFactor,
- int nThreads,
- bool offloadParamsToCpu,
- bool direct) {
- StableDiffusionWrapper::GeneratedImage result;
- auto startTime = std::chrono::high_resolution_clock::now();
- // Unload stable diffusion checkpoint before loading upscaler to prevent memory conflicts
- {
- std::lock_guard<std::mutex> lock(contextMutex);
- if (sdContext) {
- if (verbose) {
- LOG_DEBUG("Unloading stable diffusion checkpoint before loading upscaler model");
- }
- free_sd_ctx(sdContext);
- sdContext = nullptr;
- }
- }
- // Create upscaler context
- upscaler_ctx_t* upscalerCtx = new_upscaler_ctx(
- esrganPath.c_str(),
- offloadParamsToCpu,
- direct,
- nThreads);
- if (!upscalerCtx) {
- lastError = "Failed to create upscaler context";
- return result;
- }
- // Prepare input image
- sd_image_t inputImage;
- inputImage.width = inputWidth;
- inputImage.height = inputHeight;
- inputImage.channel = inputChannels;
- inputImage.data = const_cast<uint8_t*>(inputData.data());
- // Perform upscaling
- sd_image_t upscaled = upscale(upscalerCtx, inputImage, upscaleFactor);
- auto endTime = std::chrono::high_resolution_clock::now();
- auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(endTime - startTime);
- if (!upscaled.data) {
- lastError = "Failed to upscale image";
- free_upscaler_ctx(upscalerCtx);
- return result;
- }
- // Convert to our format
- result.width = upscaled.width;
- result.height = upscaled.height;
- result.channels = upscaled.channel;
- result.seed = 0; // No seed for upscaling
- result.generationTime = duration.count();
- // Copy image data
- if (upscaled.data && upscaled.width > 0 && upscaled.height > 0 && upscaled.channel > 0) {
- size_t dataSize = upscaled.width * upscaled.height * upscaled.channel;
- result.data.resize(dataSize);
- std::memcpy(result.data.data(), upscaled.data, dataSize);
- }
- // Clean up
- free_upscaler_ctx(upscalerCtx);
- return result;
- }
- std::string getLastError() const {
- return lastError;
- }
- };
- // Static helper functions
- sample_method_t StableDiffusionWrapper::stringToSamplingMethod(const std::string& method) {
- std::string lowerMethod = method;
- std::transform(lowerMethod.begin(), lowerMethod.end(), lowerMethod.begin(), ::tolower);
- if (lowerMethod == "euler") {
- return EULER;
- } else if (lowerMethod == "euler_a") {
- return EULER_A;
- } else if (lowerMethod == "heun") {
- return HEUN;
- } else if (lowerMethod == "dpm2") {
- return DPM2;
- } else if (lowerMethod == "dpmpp2s_a") {
- return DPMPP2S_A;
- } else if (lowerMethod == "dpmpp2m") {
- return DPMPP2M;
- } else if (lowerMethod == "dpmpp2mv2") {
- return DPMPP2Mv2;
- } else if (lowerMethod == "ipndm") {
- return IPNDM;
- } else if (lowerMethod == "ipndm_v") {
- return IPNDM_V;
- } else if (lowerMethod == "lcm") {
- return LCM;
- } else if (lowerMethod == "ddim_trailing") {
- return DDIM_TRAILING;
- } else if (lowerMethod == "tcd") {
- return TCD;
- } else {
- return SAMPLE_METHOD_DEFAULT;
- }
- }
- scheduler_t StableDiffusionWrapper::stringToScheduler(const std::string& scheduler) {
- std::string lowerScheduler = scheduler;
- std::transform(lowerScheduler.begin(), lowerScheduler.end(), lowerScheduler.begin(), ::tolower);
- if (lowerScheduler == "discrete") {
- return DISCRETE;
- } else if (lowerScheduler == "karras") {
- return KARRAS;
- } else if (lowerScheduler == "exponential") {
- return EXPONENTIAL;
- } else if (lowerScheduler == "ays") {
- return AYS;
- } else if (lowerScheduler == "gits") {
- return GITS;
- } else if (lowerScheduler == "smoothstep") {
- return SMOOTHSTEP;
- } else if (lowerScheduler == "sgm_uniform") {
- return SGM_UNIFORM;
- } else if (lowerScheduler == "simple") {
- return SIMPLE;
- } else {
- return DEFAULT;
- }
- }
- sd_type_t StableDiffusionWrapper::stringToModelType(const std::string& type) {
- std::string lowerType = type;
- std::transform(lowerType.begin(), lowerType.end(), lowerType.begin(), ::tolower);
- if (lowerType == "f32") {
- return SD_TYPE_F32;
- } else if (lowerType == "f16") {
- return SD_TYPE_F16;
- } else if (lowerType == "q4_0") {
- return SD_TYPE_Q4_0;
- } else if (lowerType == "q4_1") {
- return SD_TYPE_Q4_1;
- } else if (lowerType == "q5_0") {
- return SD_TYPE_Q5_0;
- } else if (lowerType == "q5_1") {
- return SD_TYPE_Q5_1;
- } else if (lowerType == "q8_0") {
- return SD_TYPE_Q8_0;
- } else if (lowerType == "q8_1") {
- return SD_TYPE_Q8_1;
- } else if (lowerType == "q2_k") {
- return SD_TYPE_Q2_K;
- } else if (lowerType == "q3_k") {
- return SD_TYPE_Q3_K;
- } else if (lowerType == "q4_k") {
- return SD_TYPE_Q4_K;
- } else if (lowerType == "q5_k") {
- return SD_TYPE_Q5_K;
- } else if (lowerType == "q6_k") {
- return SD_TYPE_Q6_K;
- } else if (lowerType == "q8_k") {
- return SD_TYPE_Q8_K;
- } else {
- return SD_TYPE_F16; // Default to F16
- }
- }
- // Public interface implementation
- StableDiffusionWrapper::StableDiffusionWrapper() : pImpl(std::make_unique<Impl>()) {
- // wrapperMutex is automatically initialized by std::mutex default constructor
- }
- StableDiffusionWrapper::~StableDiffusionWrapper() = default;
- bool StableDiffusionWrapper::loadModel(const std::string& modelPath, const GenerationParams& params) {
- std::lock_guard<std::mutex> lock(wrapperMutex);
- return pImpl->loadModel(modelPath, params);
- }
- void StableDiffusionWrapper::unloadModel() {
- std::lock_guard<std::mutex> lock(wrapperMutex);
- pImpl->unloadModel();
- }
- bool StableDiffusionWrapper::isModelLoaded() const {
- std::lock_guard<std::mutex> lock(wrapperMutex);
- return pImpl->isModelLoaded();
- }
- std::vector<StableDiffusionWrapper::GeneratedImage> StableDiffusionWrapper::generateImage(
- const GenerationParams& params,
- ProgressCallback progressCallback,
- void* userData) {
- std::lock_guard<std::mutex> lock(wrapperMutex);
- return pImpl->generateImage(params, progressCallback, userData);
- }
- std::vector<StableDiffusionWrapper::GeneratedImage> StableDiffusionWrapper::generateImageImg2Img(
- const GenerationParams& params,
- const std::vector<uint8_t>& inputData,
- int inputWidth,
- int inputHeight,
- ProgressCallback progressCallback,
- void* userData) {
- std::lock_guard<std::mutex> lock(wrapperMutex);
- return pImpl->generateImageImg2Img(params, inputData, inputWidth, inputHeight, progressCallback, userData);
- }
- std::vector<StableDiffusionWrapper::GeneratedImage> StableDiffusionWrapper::generateImageControlNet(
- const GenerationParams& params,
- const std::vector<uint8_t>& controlData,
- int controlWidth,
- int controlHeight,
- ProgressCallback progressCallback,
- void* userData) {
- std::lock_guard<std::mutex> lock(wrapperMutex);
- return pImpl->generateImageControlNet(params, controlData, controlWidth, controlHeight, progressCallback, userData);
- }
- std::vector<StableDiffusionWrapper::GeneratedImage> StableDiffusionWrapper::generateImageInpainting(
- const GenerationParams& params,
- const std::vector<uint8_t>& inputData,
- int inputWidth,
- int inputHeight,
- const std::vector<uint8_t>& maskData,
- int maskWidth,
- int maskHeight,
- ProgressCallback progressCallback,
- void* userData) {
- std::lock_guard<std::mutex> lock(wrapperMutex);
- return pImpl->generateImageInpainting(params, inputData, inputWidth, inputHeight, maskData, maskWidth, maskHeight, progressCallback, userData);
- }
- StableDiffusionWrapper::GeneratedImage StableDiffusionWrapper::upscaleImage(
- const std::string& esrganPath,
- const std::vector<uint8_t>& inputData,
- int inputWidth,
- int inputHeight,
- int inputChannels,
- uint32_t upscaleFactor,
- int nThreads,
- bool offloadParamsToCpu,
- bool direct) {
- std::lock_guard<std::mutex> lock(wrapperMutex);
- return pImpl->upscaleImage(esrganPath, inputData, inputWidth, inputHeight, inputChannels, upscaleFactor, nThreads, offloadParamsToCpu, direct);
- }
- std::string StableDiffusionWrapper::getLastError() const {
- std::lock_guard<std::mutex> lock(wrapperMutex);
- return pImpl->getLastError();
- }
|