// Implementation of StableDiffusionWrapper: a thin, mutex-guarded C++ facade
// over the stable-diffusion.cpp C API (model loading, txt2img, img2img,
// ControlNet, inpainting and ESRGAN upscaling).
//
// NOTE(review): the template arguments and standard-header names in this file
// were reconstructed from usage (e.g. std::vector<uint8_t> from the assignment
// to sd_image_t::data). Confirm them against stable_diffusion_wrapper.h.
#include "stable_diffusion_wrapper.h"

#include <algorithm>
#include <cctype>
#include <chrono>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <exception>
#include <filesystem>
#include <memory>
#include <mutex>
#include <string>
#include <thread>
#include <utility>
#include <vector>

#include "logger.h"
#include "model_detector.h"

extern "C" {
#include "stable-diffusion.h"
}

namespace {

using Clock = std::chrono::high_resolution_clock;

// Returns a lower-cased copy of `text`. The cast through unsigned char is
// required: std::tolower has undefined behaviour for negative char values,
// which occur for non-ASCII bytes in file paths.
std::string toLowerAscii(std::string text) {
    std::transform(text.begin(), text.end(), text.begin(),
                   [](unsigned char c) { return static_cast<char>(std::tolower(c)); });
    return text;
}

// True when the (already lower-cased) file name mentions a GGUF/GGML format.
// This is a substring match, exactly as the loader has always done.
bool looksLikeGgmlFile(const std::string& lowerFileName) {
    return lowerFileName.find(".gguf") != std::string::npos ||
           lowerFileName.find(".ggml") != std::string::npos;
}

// Local std::string copies of the optional model paths. The C strings handed
// to new_sd_ctx() point into these, so an instance must outlive every
// new_sd_ctx() call made with it (it does NOT outlive the context itself).
struct AuxiliaryPaths {
    std::string clipL;
    std::string clipG;
    std::string vae;
    std::string taesd;
    std::string controlNet;
    std::string loraModelDir;
    std::string embeddingDir;
};

// Number of bytes occupied by an image, computed in size_t so that large
// (e.g. upscaled) images cannot overflow 32-bit arithmetic.
std::size_t pixelByteCount(const sd_image_t& image) {
    return static_cast<std::size_t>(image.width) * static_cast<std::size_t>(image.height) *
           static_cast<std::size_t>(image.channel);
}

// True when the dimensions are positive but `pixels` holds fewer bytes than
// width * height * channels. Non-positive dimensions are deliberately not
// rejected here so that they reach the library exactly as before.
bool bufferTooSmall(const std::vector<uint8_t>& pixels, int width, int height, int channels) {
    if (width <= 0 || height <= 0 || channels <= 0) {
        return false;
    }
    const std::size_t required = static_cast<std::size_t>(width) * static_cast<std::size_t>(height) *
                                 static_cast<std::size_t>(channels);
    return pixels.size() < required;
}

// Builds a non-owning sd_image_t view over caller-owned pixel data. The C API
// takes a mutable pointer, hence the const_cast; this wrapper never writes
// through it.
sd_image_t borrowImage(const std::vector<uint8_t>& pixels, int width, int height, int channels) {
    sd_image_t view{};
    view.width = static_cast<uint32_t>(width);
    view.height = static_cast<uint32_t>(height);
    view.channel = static_cast<uint32_t>(channels);
    view.data = const_cast<uint8_t*>(pixels.data());
    return view;
}

// Releases the array returned by generate_image(): every per-image pixel
// buffer first, then the array itself.
struct SdImageArrayDeleter {
    int count = 0;

    void operator()(sd_image_t* images) const noexcept {
        if (!images) {
            return;
        }
        for (int i = 0; i < count; ++i) {
            std::free(images[i].data);  // free(nullptr) is a no-op
        }
        std::free(images);
    }
};

// Runs ModelDetector on `modelPath`, storing the outcome in `detection`.
// Returns false (and marks the architecture as unknown) when detection throws.
// The two prefixes select the log wording used by the calling branch.
bool tryDetectArchitecture(const std::string& modelPath,
                           bool verbose,
                           const char* foundPrefix,
                           const char* failedPrefix,
                           ModelDetectionResult& detection) {
    try {
        detection = ModelDetector::detectModel(modelPath);
        if (verbose) {
            LOG_DEBUG(std::string(foundPrefix) + detection.architectureName);
        }
        return true;
    } catch (const std::exception& e) {
        LOG_ERROR(std::string(failedPrefix) + std::string(e.what()) + ". Using default loading method.");
        detection.architecture = ModelArchitecture::UNKNOWN;
        detection.architectureName = "Unknown";
        return false;
    }
}

// Maps a detected architecture to the loader field to use: true means
// diffusion_model_path, false means the traditional model_path.
bool prefersDiffusionModelPath(const ModelDetectionResult& detection, bool verbose) {
    switch (detection.architecture) {
        case ModelArchitecture::FLUX_SCHNELL:
        case ModelArchitecture::FLUX_DEV:
        case ModelArchitecture::FLUX_CHROMA:
        case ModelArchitecture::SD_3:
        case ModelArchitecture::QWEN2VL:
            // Modern architectures use diffusion_model_path
            return true;

        case ModelArchitecture::SD_1_5:
        case ModelArchitecture::SD_2_1:
        case ModelArchitecture::SDXL_BASE:
        case ModelArchitecture::SDXL_REFINER:
            // Traditional SD models use model_path
            return false;

        case ModelArchitecture::UNKNOWN:
        default:
            // Unknown architectures fall back to model_path for backward compatibility
            if (verbose) {
                LOG_WARNING("Warning: Unknown model architecture detected, using default model_path for backward compatibility");
            }
            return false;
    }
}

// One row of a name -> enum lookup table.
template <typename Enum>
struct NamedValue {
    const char* name;
    Enum value;
};

// Returns the value whose name equals `key`, or `fallback` when none matches.
template <typename Enum, std::size_t N>
Enum lookupOr(const NamedValue<Enum> (&table)[N], const std::string& key, Enum fallback) {
    for (const auto& entry : table) {
        if (key == entry.name) {
            return entry.value;
        }
    }
    return fallback;
}

}  // namespace

// Private implementation. Public entry points of StableDiffusionWrapper take
// wrapperMutex before forwarding here; contextMutex additionally guards the
// places where sdContext is created or destroyed.
class StableDiffusionWrapper::Impl {
public:
    sd_ctx_t* sdContext = nullptr;
    std::string lastError;
    std::mutex contextMutex;
    bool verbose = false;
    std::string currentModelPath;
    StableDiffusionWrapper::GenerationParams currentModelParams;

    Impl() = default;

    ~Impl() {
        unloadModel();
    }

    // Loads the model at `modelPath`, replacing any model already loaded.
    //
    // The loader field (diffusion_model_path vs. model_path) is chosen from the
    // parent directory name, a "qwen" file-name hint, or architecture
    // detection. If context creation fails, progressively simpler parameter
    // sets are tried. Returns true on success; on failure returns false with
    // lastError set and no model loaded.
    bool loadModel(const std::string& modelPath, const StableDiffusionWrapper::GenerationParams& params) {
        std::lock_guard<std::mutex> lock(contextMutex);

        // Remember the verbosity for functions that have no params argument.
        verbose = params.verbose;

        // Unload any existing model
        if (sdContext) {
            free_sd_ctx(sdContext);
            sdContext = nullptr;
        }

        const std::string absModelPath = std::filesystem::absolute(modelPath).string();
        if (params.verbose) {
            LOG_DEBUG("Loading model from absolute path: " + absModelPath);
        }

        // Copies that back the C strings passed to new_sd_ctx(); they live
        // until this function returns, covering every creation attempt.
        const std::string persistentModelPath = modelPath;
        AuxiliaryPaths auxPaths;
        auxPaths.clipL = params.clipLPath;
        auxPaths.clipG = params.clipGPath;
        auxPaths.vae = params.vaePath;
        auxPaths.taesd = params.taesdPath;
        auxPaths.controlNet = params.controlNetPath;
        auxPaths.loraModelDir = params.loraModelDir;
        auxPaths.embeddingDir = params.embeddingDir;

        const std::filesystem::path modelFilePath(modelPath);
        const std::string parentDirName = toLowerAscii(modelFilePath.parent_path().filename().string());
        const std::string modelFileName = toLowerAscii(modelFilePath.filename().string());

        ModelDetectionResult detectionResult;
        bool detectionSuccessful = false;
        bool useDiffusionModelPath = false;
        std::string detectionSource = "folder";

        const bool isQwenModel = modelFileName.find("qwen") != std::string::npos;
        if (isQwenModel && params.verbose) {
            LOG_DEBUG("Detected Qwen model from filename: " + modelFileName);
        }

        const bool inDiffusionDir = parentDirName == "diffusion_models" || parentDirName == "diffusion";
        const bool inCheckpointDir = parentDirName == "checkpoints" || parentDirName == "stable-diffusion";
        // Models placed directly in the models root (e.g. /data/SD_MODELS/).
        const bool inRootDir = parentDirName == "sd_models" || parentDirName.empty();

        if (inDiffusionDir) {
            useDiffusionModelPath = true;
            if (params.verbose) {
                LOG_DEBUG("Model is in " + parentDirName + " directory, using diffusion_model_path");
            }
        } else if (inCheckpointDir) {
            useDiffusionModelPath = false;
            if (params.verbose) {
                LOG_DEBUG("Model is in " + parentDirName + " directory, using model_path");
            }
        } else if (inRootDir && isQwenModel) {
            // Qwen models should use diffusion_model_path regardless of directory
            useDiffusionModelPath = true;
            detectionSource = "qwen_root_detection";
            if (params.verbose) {
                LOG_DEBUG("Qwen model in root directory, preferring diffusion_model_path");
            }
        } else {
            // Root directory (non-Qwen) or an unrecognised directory: inspect
            // the file itself. Only the log wording differs between the two.
            if (params.verbose) {
                if (inRootDir) {
                    LOG_DEBUG("Model is in root directory '" + parentDirName + "', attempting architecture detection");
                } else {
                    LOG_DEBUG("Model is in unknown directory '" + parentDirName + "', attempting architecture detection as fallback");
                }
            }
            detectionSource = "architecture_fallback";

            detectionSuccessful = tryDetectArchitecture(
                modelPath, params.verbose,
                inRootDir ? "Architecture detection found: " : "Fallback detection found architecture: ",
                inRootDir ? "Warning: Architecture detection failed: " : "Warning: Fallback model detection failed: ",
                detectionResult);

            if (detectionSuccessful) {
                useDiffusionModelPath = prefersDiffusionModelPath(detectionResult, params.verbose);
            } else {
                useDiffusionModelPath = false;  // Default fallback
                detectionSource = "default_fallback";
            }
        }

        // First attempt: the selected path field plus all optional parameters.
        sd_ctx_params_t ctxParams;
        initContextParams(ctxParams, persistentModelPath.c_str(), useDiffusionModelPath);
        if (params.verbose) {
            if (useDiffusionModelPath) {
                LOG_DEBUG("Using diffusion_model_path (source: " + detectionSource + ")");
            } else {
                LOG_DEBUG("Using model_path (source: " + detectionSource + ")");
            }
        }
        applyAuxiliaryPaths(ctxParams, auxPaths, params.verbose);
        applyRuntimeOptions(ctxParams, params);

        if (params.verbose) {
            LOG_DEBUG("Attempting to create stable-diffusion context with selected parameters...");
        }
        sdContext = new_sd_ctx(&ctxParams);

        // Tracks which field actually produced the context, so the summary log
        // stays truthful after a fallback.
        bool loadedViaDiffusionPath = useDiffusionModelPath;

        if (!sdContext) {
            lastError = "Failed to create stable-diffusion context";
            LOG_ERROR("Error: " + lastError + " with initial attempt");
            loadedViaDiffusionPath = false;  // every fallback below uses model_path

            const bool ggmlModel = looksLikeGgmlFile(modelFileName);

            if (useDiffusionModelPath) {
                if (params.verbose) {
                    LOG_WARNING("Warning: Failed to load with diffusion_model_path. Attempting fallback to model_path...");
                }

                // Second attempt: same options, traditional model_path field.
                initContextParams(ctxParams, persistentModelPath.c_str(), false);
                applyAuxiliaryPaths(ctxParams, auxPaths, false);
                applyRuntimeOptions(ctxParams, params);

                if (params.verbose) {
                    LOG_DEBUG("Attempting to create context with fallback model_path...");
                }
                sdContext = new_sd_ctx(&ctxParams);

                if (sdContext) {
                    if (params.verbose) {
                        LOG_DEBUG("Successfully loaded model with fallback to model_path: " + absModelPath);
                    }
                } else {
                    lastError = "Failed to create stable-diffusion context with both diffusion_model_path and model_path fallback";
                    LOG_ERROR("Error: " + lastError);
                    if (!ggmlModel || !loadWithMinimalParams(persistentModelPath, params, absModelPath)) {
                        return false;
                    }
                }
            } else if (ggmlModel) {
                if (!loadWithMinimalParams(persistentModelPath, params, absModelPath)) {
                    return false;
                }
            } else {
                LOG_ERROR("Error: " + lastError);
                return false;
            }
        }

        // Log successful loading with detection information
        if (params.verbose) {
            LOG_DEBUG("Successfully loaded model: " + absModelPath);
            LOG_DEBUG(" Detection source: " + detectionSource);
            LOG_DEBUG(" Loading method: " + std::string(loadedViaDiffusionPath ? "diffusion_model_path" : "model_path"));
            LOG_DEBUG(" Parent directory: " + parentDirName);
            LOG_DEBUG(" Model filename: " + modelFileName);
        }

        // Log additional model properties if architecture detection was performed
        if (detectionSuccessful && params.verbose) {
            LOG_DEBUG(" Architecture: " + detectionResult.architectureName);
            if (detectionResult.textEncoderDim > 0) {
                LOG_DEBUG(" Text encoder dimension: " + std::to_string(detectionResult.textEncoderDim));
            }
            if (detectionResult.needsVAE) {
                LOG_DEBUG(" Requires VAE: " +
                          (detectionResult.recommendedVAE.empty() ? std::string("Yes") : detectionResult.recommendedVAE));
            }
        }

        // Store current model info for potential reload after upscaling
        currentModelPath = modelPath;
        currentModelParams = params;

        return true;
    }

    // Frees the current context (if any) and clears the stored model info.
    void unloadModel() {
        std::lock_guard<std::mutex> lock(contextMutex);
        if (sdContext) {
            free_sd_ctx(sdContext);
            sdContext = nullptr;
            if (verbose) {
                LOG_DEBUG("Unloaded stable-diffusion model");
            }
        }
        // Clear stored model info
        currentModelPath.clear();
        currentModelParams = StableDiffusionWrapper::GenerationParams();
    }

    bool isModelLoaded() const {
        return sdContext != nullptr;
    }

    // Text-to-image generation. Returns one image per params.batchCount, or an
    // empty vector with lastError set on failure.
    std::vector<StableDiffusionWrapper::GeneratedImage> generateImage(
        const StableDiffusionWrapper::GenerationParams& params,
        StableDiffusionWrapper::ProgressCallback progressCallback,
        void* userData) {
        if (!sdContext) {
            lastError = "No model loaded";
            return {};
        }
        const auto startTime = Clock::now();

        sd_img_gen_params_t genParams = makeBaseGenParams(params);
        genParams.strength = params.strength;

        // Only this entry point emits the [TIMING_ANALYSIS] log lines.
        return runGeneration(genParams, params, progressCallback, userData, startTime, true);
    }

    // Image-to-image generation from an RGB (3-channel) input image.
    std::vector<StableDiffusionWrapper::GeneratedImage> generateImageImg2Img(
        const StableDiffusionWrapper::GenerationParams& params,
        const std::vector<uint8_t>& inputData,
        int inputWidth,
        int inputHeight,
        StableDiffusionWrapper::ProgressCallback progressCallback,
        void* userData) {
        if (!sdContext) {
            lastError = "No model loaded";
            return {};
        }
        if (bufferTooSmall(inputData, inputWidth, inputHeight, 3)) {
            lastError = "Input image buffer is smaller than width * height * 3 bytes";
            return {};
        }
        const auto startTime = Clock::now();

        sd_img_gen_params_t genParams = makeBaseGenParams(params);
        genParams.strength = params.strength;
        genParams.init_image = borrowImage(inputData, inputWidth, inputHeight, 3);  // RGB

        return runGeneration(genParams, params, progressCallback, userData, startTime, false);
    }

    // ControlNet-guided generation from an RGB (3-channel) control image.
    std::vector<StableDiffusionWrapper::GeneratedImage> generateImageControlNet(
        const StableDiffusionWrapper::GenerationParams& params,
        const std::vector<uint8_t>& controlData,
        int controlWidth,
        int controlHeight,
        StableDiffusionWrapper::ProgressCallback progressCallback,
        void* userData) {
        if (!sdContext) {
            lastError = "No model loaded";
            return {};
        }
        if (bufferTooSmall(controlData, controlWidth, controlHeight, 3)) {
            lastError = "Control image buffer is smaller than width * height * 3 bytes";
            return {};
        }
        const auto startTime = Clock::now();

        sd_img_gen_params_t genParams = makeBaseGenParams(params);
        genParams.control_strength = params.controlStrength;
        genParams.control_image = borrowImage(controlData, controlWidth, controlHeight, 3);  // RGB

        return runGeneration(genParams, params, progressCallback, userData, startTime, false);
    }

    // Inpainting: RGB (3-channel) input image plus a 1-channel grayscale mask.
    std::vector<StableDiffusionWrapper::GeneratedImage> generateImageInpainting(
        const StableDiffusionWrapper::GenerationParams& params,
        const std::vector<uint8_t>& inputData,
        int inputWidth,
        int inputHeight,
        const std::vector<uint8_t>& maskData,
        int maskWidth,
        int maskHeight,
        StableDiffusionWrapper::ProgressCallback progressCallback,
        void* userData) {
        if (!sdContext) {
            lastError = "No model loaded";
            return {};
        }
        if (bufferTooSmall(inputData, inputWidth, inputHeight, 3)) {
            lastError = "Input image buffer is smaller than width * height * 3 bytes";
            return {};
        }
        if (bufferTooSmall(maskData, maskWidth, maskHeight, 1)) {
            lastError = "Mask image buffer is smaller than width * height bytes";
            return {};
        }
        const auto startTime = Clock::now();

        sd_img_gen_params_t genParams = makeBaseGenParams(params);
        genParams.strength = params.strength;
        genParams.init_image = borrowImage(inputData, inputWidth, inputHeight, 3);  // RGB
        genParams.mask_image = borrowImage(maskData, maskWidth, maskHeight, 1);     // Grayscale mask

        return runGeneration(genParams, params, progressCallback, userData, startTime, false);
    }

    // Upscales an image with the ESRGAN model at `esrganPath`.
    //
    // The loaded diffusion context, if any, is freed first and is NOT reloaded
    // here; currentModelPath / currentModelParams are kept so a caller can
    // reload. On failure the returned image is default-constructed and
    // lastError is set.
    StableDiffusionWrapper::GeneratedImage upscaleImage(
        const std::string& esrganPath,
        const std::vector<uint8_t>& inputData,
        int inputWidth,
        int inputHeight,
        int inputChannels,
        uint32_t upscaleFactor,
        int nThreads,
        bool offloadParamsToCpu,
        bool direct) {
        StableDiffusionWrapper::GeneratedImage result;

        // Validate before unloading so a bad request does not cost the caller
        // the currently loaded model.
        if (bufferTooSmall(inputData, inputWidth, inputHeight, inputChannels)) {
            lastError = "Input image buffer is smaller than width * height * channels bytes";
            return result;
        }

        const auto startTime = Clock::now();

        // Unload stable diffusion checkpoint before loading upscaler to prevent memory conflicts
        {
            std::lock_guard<std::mutex> lock(contextMutex);
            if (sdContext) {
                if (verbose) {
                    LOG_DEBUG("Unloading stable diffusion checkpoint before loading upscaler model");
                }
                free_sd_ctx(sdContext);
                sdContext = nullptr;
            }
        }

        // Owning handle: the upscaler context is released on every exit path.
        std::unique_ptr<upscaler_ctx_t, decltype(&free_upscaler_ctx)> upscalerCtx(
            new_upscaler_ctx(esrganPath.c_str(), offloadParamsToCpu, direct, nThreads),
            &free_upscaler_ctx);
        if (!upscalerCtx) {
            lastError = "Failed to create upscaler context";
            return result;
        }

        const sd_image_t inputImage = borrowImage(inputData, inputWidth, inputHeight, inputChannels);
        const sd_image_t upscaled = upscale(upscalerCtx.get(), inputImage, upscaleFactor);

        // upscale() returns a heap buffer owned by the caller; it was
        // previously leaked after being copied.
        const std::unique_ptr<uint8_t, decltype(&std::free)> upscaledPixels(upscaled.data, &std::free);

        const auto endTime = Clock::now();
        const auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(endTime - startTime);

        if (!upscaled.data) {
            lastError = "Failed to upscale image";
            return result;
        }

        result.width = upscaled.width;
        result.height = upscaled.height;
        result.channels = upscaled.channel;
        result.seed = 0;  // No seed for upscaling
        result.generationTime = duration.count();

        if (upscaled.width > 0 && upscaled.height > 0 && upscaled.channel > 0) {
            const std::size_t dataSize = pixelByteCount(upscaled);
            result.data.resize(dataSize);
            std::memcpy(result.data.data(), upscaled.data, dataSize);
        }

        return result;
    }

    std::string getLastError() const {
        return lastError;
    }

private:
    using CallbackBinding = std::pair<StableDiffusionWrapper::ProgressCallback, void*>;

    // Resets `ctx` to library defaults and selects which field carries the
    // model path. free_params_immediately is forced off for EVERY attempt,
    // fallbacks included, so the context can be reused across generations.
    static void initContextParams(sd_ctx_params_t& ctx, const char* modelPath, bool viaDiffusionPath) {
        sd_ctx_params_init(&ctx);
        ctx.free_params_immediately = false;  // avoid segfault when reusing
        ctx.model_path = viaDiffusionPath ? nullptr : modelPath;
        ctx.diffusion_model_path = viaDiffusionPath ? modelPath : nullptr;
    }

    // Points the optional path fields of `ctx` at the non-empty entries of
    // `paths`. A VAE path that does not exist on disk is skipped. When
    // `logPaths` is set, each decision is logged.
    static void applyAuxiliaryPaths(sd_ctx_params_t& ctx, const AuxiliaryPaths& paths, bool logPaths) {
        const auto bind = [logPaths](const std::string& path, const char*& slot, const char* label) {
            if (path.empty()) {
                return;
            }
            slot = path.c_str();
            if (logPaths) {
                LOG_DEBUG(std::string(label) + std::filesystem::absolute(path).string());
            }
        };

        bind(paths.clipL, ctx.clip_l_path, "Using CLIP-L path: ");
        bind(paths.clipG, ctx.clip_g_path, "Using CLIP-G path: ");

        if (!paths.vae.empty()) {
            // Check if VAE file exists before setting it
            if (std::filesystem::exists(paths.vae)) {
                ctx.vae_path = paths.vae.c_str();
                if (logPaths) {
                    LOG_DEBUG("Using VAE path: " + std::filesystem::absolute(paths.vae).string());
                }
            } else {
                if (logPaths) {
                    LOG_DEBUG("VAE file not found: " + std::filesystem::absolute(paths.vae).string() +
                              " - continuing without VAE");
                }
                ctx.vae_path = nullptr;
            }
        }

        bind(paths.taesd, ctx.taesd_path, "Using TAESD path: ");
        bind(paths.controlNet, ctx.control_net_path, "Using ControlNet path: ");
        bind(paths.loraModelDir, ctx.lora_model_dir, "Using LoRA model directory: ");
        bind(paths.embeddingDir, ctx.embedding_dir, "Using embedding directory: ");
    }

    // Copies the performance switches and the weight type into `ctx`.
    static void applyRuntimeOptions(sd_ctx_params_t& ctx, const StableDiffusionWrapper::GenerationParams& params) {
        ctx.n_threads = params.nThreads;
        ctx.offload_params_to_cpu = params.offloadParamsToCpu;
        ctx.keep_clip_on_cpu = params.clipOnCpu;
        ctx.keep_vae_on_cpu = params.vaeOnCpu;
        ctx.diffusion_flash_attn = params.diffusionFlashAttn;
        ctx.diffusion_conv_direct = params.diffusionConvDirect;
        ctx.vae_conv_direct = params.vaeConvDirect;
        ctx.wtype = StableDiffusionWrapper::stringToModelType(params.modelType);
    }

    // Last-resort attempt for GGUF/GGML files: model_path, thread count and
    // weight type only. Sets sdContext on success; on failure sets lastError
    // and returns false.
    bool loadWithMinimalParams(const std::string& modelPath,
                               const StableDiffusionWrapper::GenerationParams& params,
                               const std::string& absModelPath) {
        if (params.verbose) {
            LOG_DEBUG("Detected GGUF/GGML model, attempting minimal parameter fallback...");
        }

        sd_ctx_params_t minimal;
        initContextParams(minimal, modelPath.c_str(), false);
        // Set only essential parameters for GGUF
        minimal.n_threads = params.nThreads;
        minimal.wtype = StableDiffusionWrapper::stringToModelType(params.modelType);

        if (params.verbose) {
            LOG_DEBUG("Attempting to create context with minimal GGUF parameters...");
        }
        sdContext = new_sd_ctx(&minimal);
        if (!sdContext) {
            lastError = "Failed to create stable-diffusion context even with minimal GGUF parameters";
            LOG_ERROR("Error: " + lastError);
            return false;
        }

        if (params.verbose) {
            LOG_DEBUG("Successfully loaded GGUF model with minimal parameters: " + absModelPath);
        }
        return true;
    }

    // Fills the generation parameters shared by every mode. The returned
    // struct borrows the prompt strings from `params`, which must therefore
    // outlive it.
    static sd_img_gen_params_t makeBaseGenParams(const StableDiffusionWrapper::GenerationParams& params) {
        sd_img_gen_params_t gen;
        sd_img_gen_params_init(&gen);

        gen.prompt = params.prompt.c_str();
        gen.negative_prompt = params.negativePrompt.c_str();
        gen.width = params.width;
        gen.height = params.height;
        gen.seed = params.seed;
        gen.batch_count = params.batchCount;
        gen.clip_skip = params.clipSkip;

        gen.sample_params.sample_steps = params.steps;
        gen.sample_params.sample_method = StableDiffusionWrapper::stringToSamplingMethod(params.samplingMethod);
        gen.sample_params.scheduler = StableDiffusionWrapper::stringToScheduler(params.scheduler);
        gen.sample_params.guidance.txt_cfg = params.cfgScale;
        return gen;
    }

    // Shared tail of every generate* entry point: installs the progress
    // callback, runs generate_image(), removes the callback, then converts and
    // frees the library's output. `startTime` is when the entry point began,
    // so the reported generation time covers parameter setup as well.
    std::vector<StableDiffusionWrapper::GeneratedImage> runGeneration(
        sd_img_gen_params_t& genParams,
        const StableDiffusionWrapper::GenerationParams& params,
        const StableDiffusionWrapper::ProgressCallback& progressCallback,
        void* userData,
        Clock::time_point startTime,
        bool logCallTiming) {
        std::vector<StableDiffusionWrapper::GeneratedImage> results;

        // Lives on this frame until the function returns, i.e. strictly longer
        // than the window in which the library can invoke the callback.
        CallbackBinding binding(progressCallback, userData);
        if (progressCallback) {
            sd_set_progress_callback(
                [](int step, int steps, float time, void* data) {
                    auto* bound = static_cast<CallbackBinding*>(data);
                    if (bound) {
                        bound->first(step, steps, time, bound->second);
                    }
                },
                &binding);
        }

        sd_image_t* rawImages = nullptr;
        if (logCallTiming) {
            LOG_DEBUG("[TIMING_ANALYSIS] Starting generate_image() call");
            const auto callStart = Clock::now();
            rawImages = generate_image(sdContext, &genParams);
            const auto callEnd = Clock::now();
            const auto callMs = std::chrono::duration_cast<std::chrono::milliseconds>(callEnd - callStart).count();
            LOG_DEBUG("[TIMING_ANALYSIS] generate_image() call completed in " + std::to_string(callMs) + "ms");
        } else {
            rawImages = generate_image(sdContext, &genParams);
        }

        const int batchCount = static_cast<int>(params.batchCount);
        // Takes ownership immediately so the images are freed even if the
        // conversion below throws (e.g. std::bad_alloc).
        const std::unique_ptr<sd_image_t, SdImageArrayDeleter> images(rawImages, SdImageArrayDeleter{batchCount});

        // Detach the callback before `binding` can go out of scope.
        sd_set_progress_callback(nullptr, nullptr);
        // NOTE(review): grace period kept from the previous implementation for
        // callbacks that might still be in flight; confirm whether the library
        // can invoke the callback after generate_image() has returned.
        std::this_thread::sleep_for(std::chrono::milliseconds(10));

        const auto endTime = Clock::now();
        // NOTE(review): unit reconstructed as milliseconds; confirm against
        // the documentation of GeneratedImage::generationTime.
        const auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(endTime - startTime);

        if (!images) {
            lastError = "Failed to generate image";
            return results;
        }

        if (batchCount > 0) {
            results.reserve(static_cast<std::size_t>(batchCount));
        }
        for (int i = 0; i < batchCount; ++i) {
            const sd_image_t& source = images.get()[i];

            StableDiffusionWrapper::GeneratedImage image;
            image.width = source.width;
            image.height = source.height;
            image.channels = source.channel;
            image.seed = params.seed;
            image.generationTime = duration.count();

            if (source.data && source.width > 0 && source.height > 0 && source.channel > 0) {
                const std::size_t dataSize = pixelByteCount(source);
                image.data.resize(dataSize);
                std::memcpy(image.data.data(), source.data, dataSize);
            }
            results.push_back(std::move(image));  // move: avoids copying the pixel buffer
        }

        return results;
    }
};

// Static helper functions

// Case-insensitive sampler name -> enum; unknown names yield SAMPLE_METHOD_DEFAULT.
sample_method_t StableDiffusionWrapper::stringToSamplingMethod(const std::string& method) {
    static constexpr NamedValue<sample_method_t> kSamplingMethods[] = {
        {"euler", EULER},
        {"euler_a", EULER_A},
        {"heun", HEUN},
        {"dpm2", DPM2},
        {"dpmpp2s_a", DPMPP2S_A},
        {"dpmpp2m", DPMPP2M},
        {"dpmpp2mv2", DPMPP2Mv2},
        {"ipndm", IPNDM},
        {"ipndm_v", IPNDM_V},
        {"lcm", LCM},
        {"ddim_trailing", DDIM_TRAILING},
        {"tcd", TCD},
    };
    return lookupOr(kSamplingMethods, toLowerAscii(method), SAMPLE_METHOD_DEFAULT);
}

// Case-insensitive scheduler name -> enum; unknown names yield DEFAULT.
scheduler_t StableDiffusionWrapper::stringToScheduler(const std::string& scheduler) {
    static constexpr NamedValue<scheduler_t> kSchedulers[] = {
        {"discrete", DISCRETE},
        {"karras", KARRAS},
        {"exponential", EXPONENTIAL},
        {"ays", AYS},
        {"gits", GITS},
        {"smoothstep", SMOOTHSTEP},
        {"sgm_uniform", SGM_UNIFORM},
        {"simple", SIMPLE},
    };
    return lookupOr(kSchedulers, toLowerAscii(scheduler), DEFAULT);
}

// Case-insensitive weight-type name -> enum; unknown names yield SD_TYPE_F16.
sd_type_t StableDiffusionWrapper::stringToModelType(const std::string& type) {
    static constexpr NamedValue<sd_type_t> kModelTypes[] = {
        {"f32", SD_TYPE_F32},
        {"f16", SD_TYPE_F16},
        {"q4_0", SD_TYPE_Q4_0},
        {"q4_1", SD_TYPE_Q4_1},
        {"q5_0", SD_TYPE_Q5_0},
        {"q5_1", SD_TYPE_Q5_1},
        {"q8_0", SD_TYPE_Q8_0},
        {"q8_1", SD_TYPE_Q8_1},
        {"q2_k", SD_TYPE_Q2_K},
        {"q3_k", SD_TYPE_Q3_K},
        {"q4_k", SD_TYPE_Q4_K},
        {"q5_k", SD_TYPE_Q5_K},
        {"q6_k", SD_TYPE_Q6_K},
        {"q8_k", SD_TYPE_Q8_K},
    };
    return lookupOr(kModelTypes, toLowerAscii(type), SD_TYPE_F16);  // Default to F16
}

// Public interface implementation. Every method takes wrapperMutex and
// forwards to the implementation object.

StableDiffusionWrapper::StableDiffusionWrapper() : pImpl(std::make_unique<Impl>()) {
    // wrapperMutex is automatically initialized by std::mutex default constructor
}

StableDiffusionWrapper::~StableDiffusionWrapper() = default;

bool StableDiffusionWrapper::loadModel(const std::string& modelPath, const GenerationParams& params) {
    std::lock_guard<std::mutex> lock(wrapperMutex);
    return pImpl->loadModel(modelPath, params);
}

void StableDiffusionWrapper::unloadModel() {
    std::lock_guard<std::mutex> lock(wrapperMutex);
    pImpl->unloadModel();
}

bool StableDiffusionWrapper::isModelLoaded() const {
    std::lock_guard<std::mutex> lock(wrapperMutex);
    return pImpl->isModelLoaded();
}

std::vector<StableDiffusionWrapper::GeneratedImage> StableDiffusionWrapper::generateImage(
    const GenerationParams& params,
    ProgressCallback progressCallback,
    void* userData) {
    std::lock_guard<std::mutex> lock(wrapperMutex);
    return pImpl->generateImage(params, progressCallback, userData);
}

std::vector<StableDiffusionWrapper::GeneratedImage> StableDiffusionWrapper::generateImageImg2Img(
    const GenerationParams& params,
    const std::vector<uint8_t>& inputData,
    int inputWidth,
    int inputHeight,
    ProgressCallback progressCallback,
    void* userData) {
    std::lock_guard<std::mutex> lock(wrapperMutex);
    return pImpl->generateImageImg2Img(params, inputData, inputWidth, inputHeight, progressCallback, userData);
}

std::vector<StableDiffusionWrapper::GeneratedImage> StableDiffusionWrapper::generateImageControlNet(
    const GenerationParams& params,
    const std::vector<uint8_t>& controlData,
    int controlWidth,
    int controlHeight,
    ProgressCallback progressCallback,
    void* userData) {
    std::lock_guard<std::mutex> lock(wrapperMutex);
    return pImpl->generateImageControlNet(params, controlData, controlWidth, controlHeight, progressCallback, userData);
}

std::vector<StableDiffusionWrapper::GeneratedImage> StableDiffusionWrapper::generateImageInpainting(
    const GenerationParams& params,
    const std::vector<uint8_t>& inputData,
    int inputWidth,
    int inputHeight,
    const std::vector<uint8_t>& maskData,
    int maskWidth,
    int maskHeight,
    ProgressCallback progressCallback,
    void* userData) {
    std::lock_guard<std::mutex> lock(wrapperMutex);
    return pImpl->generateImageInpainting(params, inputData, inputWidth, inputHeight, maskData, maskWidth, maskHeight,
                                          progressCallback, userData);
}

StableDiffusionWrapper::GeneratedImage StableDiffusionWrapper::upscaleImage(
    const std::string& esrganPath,
    const std::vector<uint8_t>& inputData,
    int inputWidth,
    int inputHeight,
    int inputChannels,
    uint32_t upscaleFactor,
    int nThreads,
    bool offloadParamsToCpu,
    bool direct) {
    std::lock_guard<std::mutex> lock(wrapperMutex);
    return pImpl->upscaleImage(esrganPath, inputData, inputWidth, inputHeight, inputChannels, upscaleFactor, nThreads,
                               offloadParamsToCpu, direct);
}

std::string StableDiffusionWrapper::getLastError() const {
    std::lock_guard<std::mutex> lock(wrapperMutex);
    return pImpl->getLastError();
}