#include "stable_diffusion_wrapper.h"

#include <algorithm>
#include <cctype>
#include <chrono>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <mutex>
#include <string>
#include <utility>
#include <vector>

extern "C" {
#include "stable-diffusion.h"
}

namespace {

/// Returns a copy of @p text with every character passed through std::tolower.
std::string toLowerCopy(std::string text) {
    // Route through unsigned char: std::tolower is undefined for negative
    // values other than EOF, which plain char can produce.
    std::transform(text.begin(), text.end(), text.begin(),
                   [](unsigned char c) { return static_cast<char>(std::tolower(c)); });
    return text;
}

/// Returns true when all dimensions are positive and @p buffer holds at least
/// width * height * channels bytes.
bool imageBufferFits(const std::vector<uint8_t>& buffer, int width, int height, int channels) {
    if (width <= 0 || height <= 0 || channels <= 0) {
        return false;
    }
    const std::size_t needed = static_cast<std::size_t>(width) *
                               static_cast<std::size_t>(height) *
                               static_cast<std::size_t>(channels);
    return buffer.size() >= needed;
}

}  // namespace

/// Private implementation of StableDiffusionWrapper: owns the sd_ctx_t handle
/// and translates between the wrapper's types and the stable-diffusion C API.
class StableDiffusionWrapper::Impl {
public:
    sd_ctx_t* sdContext = nullptr;
    std::string lastError;
    std::mutex contextMutex;

    Impl() = default;

    /// Releases the loaded model, if any.
    ~Impl() {
        unloadModel();
    }

    /// Frees any existing context, then creates a new one from @p modelPath and
    /// the model/performance settings in @p params.
    /// @return true on success; on failure sets lastError and returns false.
    bool loadModel(const std::string& modelPath,
                   const StableDiffusionWrapper::GenerationParams& params) {
        std::lock_guard<std::mutex> lock(contextMutex);

        // Unload any existing model
        if (sdContext) {
            free_sd_ctx(sdContext);
            sdContext = nullptr;
        }

        // Initialize context parameters
        sd_ctx_params_t ctxParams;
        sd_ctx_params_init(&ctxParams);

        // The c_str() pointers below stay valid because modelPath and params
        // outlive the new_sd_ctx() call.
        ctxParams.model_path = modelPath.c_str();

        // Optional model paths: only override the initialized values when provided
        if (!params.clipLPath.empty()) {
            ctxParams.clip_l_path = params.clipLPath.c_str();
        }
        if (!params.clipGPath.empty()) {
            ctxParams.clip_g_path = params.clipGPath.c_str();
        }
        if (!params.vaePath.empty()) {
            ctxParams.vae_path = params.vaePath.c_str();
        }
        if (!params.taesdPath.empty()) {
            ctxParams.taesd_path = params.taesdPath.c_str();
        }
        if (!params.controlNetPath.empty()) {
            ctxParams.control_net_path = params.controlNetPath.c_str();
        }
        if (!params.loraModelDir.empty()) {
            ctxParams.lora_model_dir = params.loraModelDir.c_str();
        }
        if (!params.embeddingDir.empty()) {
            ctxParams.embedding_dir = params.embeddingDir.c_str();
        }

        // Performance parameters
        ctxParams.n_threads = params.nThreads;
        ctxParams.offload_params_to_cpu = params.offloadParamsToCpu;
        ctxParams.keep_clip_on_cpu = params.clipOnCpu;
        ctxParams.keep_vae_on_cpu = params.vaeOnCpu;
        ctxParams.diffusion_flash_attn = params.diffusionFlashAttn;
        ctxParams.diffusion_conv_direct = params.diffusionConvDirect;
        ctxParams.vae_conv_direct = params.vaeConvDirect;

        // Model weight type
        ctxParams.wtype = StableDiffusionWrapper::stringToModelType(params.modelType);

        // Create the stable-diffusion context
        sdContext = new_sd_ctx(&ctxParams);
        if (!sdContext) {
            lastError = "Failed to create stable-diffusion context";
            return false;
        }

        std::cout << "Successfully loaded model: " << modelPath << std::endl;
        return true;
    }

    /// Frees the current context, if one is loaded.
    void unloadModel() {
        std::lock_guard<std::mutex> lock(contextMutex);
        if (sdContext) {
            free_sd_ctx(sdContext);
            sdContext = nullptr;
            std::cout << "Unloaded stable-diffusion model" << std::endl;
        }
    }

    /// @return true when a context is currently held. Does not take contextMutex.
    bool isModelLoaded() const {
        return sdContext != nullptr;
    }

    /// Text-to-image generation.
    /// @return one GeneratedImage per params.batchCount, or an empty vector
    ///         (with lastError set) when no model is loaded or generation fails.
    std::vector<StableDiffusionWrapper::GeneratedImage> generateImage(
        const StableDiffusionWrapper::GenerationParams& params,
        StableDiffusionWrapper::ProgressCallback progressCallback,
        void* userData) {
        if (!sdContext) {
            lastError = "No model loaded";
            return {};
        }
        const auto startTime = Clock::now();

        sd_img_gen_params_t genParams;
        fillCommonParams(genParams, params);
        genParams.strength = params.strength;

        return runGeneration(genParams, params, progressCallback, userData, startTime);
    }

    /// Image-to-image generation starting from @p inputData, which is passed to
    /// the library as a 3-channel image of inputWidth x inputHeight.
    /// @return generated images, or an empty vector (with lastError set) when no
    ///         model is loaded, the input buffer is too small, or generation fails.
    std::vector<StableDiffusionWrapper::GeneratedImage> generateImageImg2Img(
        const StableDiffusionWrapper::GenerationParams& params,
        const std::vector<uint8_t>& inputData,
        int inputWidth,
        int inputHeight,
        StableDiffusionWrapper::ProgressCallback progressCallback,
        void* userData) {
        if (!sdContext) {
            lastError = "No model loaded";
            return {};
        }
        // Guard against the library reading past the end of the caller's buffer.
        if (!imageBufferFits(inputData, inputWidth, inputHeight, kRgbChannels)) {
            lastError = "Input image buffer is smaller than width * height * 3 bytes";
            return {};
        }
        const auto startTime = Clock::now();

        sd_img_gen_params_t genParams;
        fillCommonParams(genParams, params);
        genParams.strength = params.strength;
        genParams.init_image = borrowImage(inputData, inputWidth, inputHeight, kRgbChannels);

        return runGeneration(genParams, params, progressCallback, userData, startTime);
    }

    /// Generation guided by a control image, which is passed to the library as a
    /// 3-channel image of controlWidth x controlHeight.
    /// @return generated images, or an empty vector (with lastError set) when no
    ///         model is loaded, the control buffer is too small, or generation fails.
    std::vector<StableDiffusionWrapper::GeneratedImage> generateImageControlNet(
        const StableDiffusionWrapper::GenerationParams& params,
        const std::vector<uint8_t>& controlData,
        int controlWidth,
        int controlHeight,
        StableDiffusionWrapper::ProgressCallback progressCallback,
        void* userData) {
        if (!sdContext) {
            lastError = "No model loaded";
            return {};
        }
        // Guard against the library reading past the end of the caller's buffer.
        if (!imageBufferFits(controlData, controlWidth, controlHeight, kRgbChannels)) {
            lastError = "Control image buffer is smaller than width * height * 3 bytes";
            return {};
        }
        const auto startTime = Clock::now();

        sd_img_gen_params_t genParams;
        fillCommonParams(genParams, params);
        genParams.control_strength = params.controlStrength;
        genParams.control_image = borrowImage(controlData, controlWidth, controlHeight, kRgbChannels);

        return runGeneration(genParams, params, progressCallback, userData, startTime);
    }

    /// Upscales @p inputData with a temporary upscaler context created from
    /// @p esrganPath. Does not use the loaded diffusion model.
    /// @return the upscaled image; on failure a default-constructed
    ///         GeneratedImage is returned and lastError is set.
    StableDiffusionWrapper::GeneratedImage upscaleImage(
        const std::string& esrganPath,
        const std::vector<uint8_t>& inputData,
        int inputWidth,
        int inputHeight,
        int inputChannels,
        uint32_t upscaleFactor,
        int nThreads,
        bool offloadParamsToCpu,
        bool direct) {
        StableDiffusionWrapper::GeneratedImage result;
        const auto startTime = Clock::now();

        // Guard against the library reading past the end of the caller's buffer.
        if (!imageBufferFits(inputData, inputWidth, inputHeight, inputChannels)) {
            lastError = "Input image buffer is smaller than width * height * channels bytes";
            return result;
        }

        // Create upscaler context
        upscaler_ctx_t* upscalerCtx = new_upscaler_ctx(
            esrganPath.c_str(),
            offloadParamsToCpu,
            direct,
            nThreads);
        if (!upscalerCtx) {
            lastError = "Failed to create upscaler context";
            return result;
        }

        // Perform upscaling
        const sd_image_t inputImage = borrowImage(inputData, inputWidth, inputHeight, inputChannels);
        sd_image_t upscaled = upscale(upscalerCtx, inputImage, upscaleFactor);

        // NOTE(review): unit assumed to be milliseconds - confirm against the
        // declaration of GeneratedImage::generationTime.
        const auto duration =
            std::chrono::duration_cast<std::chrono::milliseconds>(Clock::now() - startTime);

        if (!upscaled.data) {
            lastError = "Failed to upscale image";
            free_upscaler_ctx(upscalerCtx);
            return result;
        }

        // Convert to our format
        result.width = upscaled.width;
        result.height = upscaled.height;
        result.channels = upscaled.channel;
        result.seed = 0;  // No seed for upscaling
        result.generationTime = duration.count();

        if (upscaled.width > 0 && upscaled.height > 0 && upscaled.channel > 0) {
            // Widen before multiplying so large images cannot overflow 32 bits.
            const std::size_t dataSize = static_cast<std::size_t>(upscaled.width) *
                                         upscaled.height * upscaled.channel;
            result.data.resize(dataSize);
            std::memcpy(result.data.data(), upscaled.data, dataSize);
        }

        // The pixel buffer returned by upscale() belongs to the caller; release
        // it now that it has been copied (it was previously leaked).
        std::free(upscaled.data);
        free_upscaler_ctx(upscalerCtx);
        return result;
    }

    /// @return a copy of the most recent error message.
    std::string getLastError() const {
        return lastError;
    }

private:
    using Clock = std::chrono::high_resolution_clock;
    using CallbackState = std::pair<StableDiffusionWrapper::ProgressCallback, void*>;

    static constexpr int kRgbChannels = 3;

    /// Wraps caller-owned pixels in an sd_image_t without copying. The C struct
    /// has a non-const data pointer, hence the const_cast.
    static sd_image_t borrowImage(const std::vector<uint8_t>& pixels,
                                  int width, int height, int channels) {
        sd_image_t image{};
        image.width = static_cast<uint32_t>(width);
        image.height = static_cast<uint32_t>(height);
        image.channel = static_cast<uint32_t>(channels);
        image.data = const_cast<uint8_t*>(pixels.data());
        return image;
    }

    /// Initializes @p genParams and copies over the settings shared by every
    /// generation mode. The string pointers borrow from @p params.
    static void fillCommonParams(sd_img_gen_params_t& genParams,
                                 const StableDiffusionWrapper::GenerationParams& params) {
        sd_img_gen_params_init(&genParams);

        // Basic parameters
        genParams.prompt = params.prompt.c_str();
        genParams.negative_prompt = params.negativePrompt.c_str();
        genParams.width = params.width;
        genParams.height = params.height;
        genParams.sample_params.sample_steps = params.steps;
        genParams.seed = params.seed;
        genParams.batch_count = params.batchCount;

        // Sampling parameters
        genParams.sample_params.sample_method =
            StableDiffusionWrapper::stringToSamplingMethod(params.samplingMethod);
        genParams.sample_params.scheduler =
            StableDiffusionWrapper::stringToScheduler(params.scheduler);
        genParams.sample_params.guidance.txt_cfg = params.cfgScale;

        // Advanced parameters
        genParams.clip_skip = params.clipSkip;
    }

    /// Runs generate_image() with the optional progress callback installed, then
    /// copies the returned images into GeneratedImage values and frees the
    /// library-allocated buffers.
    std::vector<StableDiffusionWrapper::GeneratedImage> runGeneration(
        sd_img_gen_params_t& genParams,
        const StableDiffusionWrapper::GenerationParams& params,
        StableDiffusionWrapper::ProgressCallback progressCallback,
        void* userData,
        Clock::time_point startTime) {
        std::vector<StableDiffusionWrapper::GeneratedImage> results;

        // Callback state lives on the stack: it only has to outlive the
        // generate_image() call below, so no heap allocation is needed.
        CallbackState callbackState(progressCallback, userData);
        if (callbackState.first) {
            sd_set_progress_callback(
                [](int step, int steps, float time, void* data) {
                    auto* state = static_cast<CallbackState*>(data);
                    if (state) {
                        state->first(step, steps, time, state->second);
                    }
                },
                &callbackState);
        }

        // Generate the image
        sd_image_t* sdImages = generate_image(sdContext, &genParams);

        // Always clear the callback so the library never holds a dangling pointer
        sd_set_progress_callback(nullptr, nullptr);

        // NOTE(review): unit assumed to be milliseconds - confirm against the
        // declaration of GeneratedImage::generationTime.
        const auto duration =
            std::chrono::duration_cast<std::chrono::milliseconds>(Clock::now() - startTime);

        if (!sdImages) {
            lastError = "Failed to generate image";
            return results;
        }

        if (params.batchCount > 0) {
            results.reserve(static_cast<std::size_t>(params.batchCount));
        }

        // Convert stable-diffusion images to our format
        for (int i = 0; i < params.batchCount; i++) {
            sd_image_t& source = sdImages[i];

            StableDiffusionWrapper::GeneratedImage image;
            image.width = source.width;
            image.height = source.height;
            image.channels = source.channel;
            image.seed = params.seed;
            image.generationTime = duration.count();

            if (source.data && source.width > 0 && source.height > 0 && source.channel > 0) {
                // Widen before multiplying so large images cannot overflow 32 bits.
                const std::size_t dataSize = static_cast<std::size_t>(source.width) *
                                             source.height * source.channel;
                image.data.resize(dataSize);
                std::memcpy(image.data.data(), source.data, dataSize);
            }
            // Move to avoid a second copy of the pixel buffer.
            results.push_back(std::move(image));

            // Release this image's pixel buffer (free(nullptr) is a no-op)
            std::free(source.data);
            source.data = nullptr;
        }

        // Free the image array itself
        std::free(sdImages);
        return results;
    }
};

// Static helper functions

/// Maps a case-insensitive sampler name to sample_method_t; unknown names map
/// to SAMPLE_METHOD_DEFAULT.
sample_method_t StableDiffusionWrapper::stringToSamplingMethod(const std::string& method) {
    const std::string lowerMethod = toLowerCopy(method);

    if (lowerMethod == "euler") return EULER;
    if (lowerMethod == "euler_a") return EULER_A;
    if (lowerMethod == "heun") return HEUN;
    if (lowerMethod == "dpm2") return DPM2;
    if (lowerMethod == "dpmpp2s_a") return DPMPP2S_A;
    if (lowerMethod == "dpmpp2m") return DPMPP2M;
    if (lowerMethod == "dpmpp2mv2") return DPMPP2Mv2;
    if (lowerMethod == "ipndm") return IPNDM;
    if (lowerMethod == "ipndm_v") return IPNDM_V;
    if (lowerMethod == "lcm") return LCM;
    if (lowerMethod == "ddim_trailing") return DDIM_TRAILING;
    if (lowerMethod == "tcd") return TCD;
    return SAMPLE_METHOD_DEFAULT;
}

/// Maps a case-insensitive scheduler name to scheduler_t; unknown names map to
/// DEFAULT.
scheduler_t StableDiffusionWrapper::stringToScheduler(const std::string& scheduler) {
    const std::string lowerScheduler = toLowerCopy(scheduler);

    if (lowerScheduler == "discrete") return DISCRETE;
    if (lowerScheduler == "karras") return KARRAS;
    if (lowerScheduler == "exponential") return EXPONENTIAL;
    if (lowerScheduler == "ays") return AYS;
    if (lowerScheduler == "gits") return GITS;
    if (lowerScheduler == "smoothstep") return SMOOTHSTEP;
    if (lowerScheduler == "sgm_uniform") return SGM_UNIFORM;
    if (lowerScheduler == "simple") return SIMPLE;
    return DEFAULT;
}

/// Maps a case-insensitive weight-type name to sd_type_t; unknown names map to
/// SD_TYPE_F16.
sd_type_t StableDiffusionWrapper::stringToModelType(const std::string& type) {
    const std::string lowerType = toLowerCopy(type);

    if (lowerType == "f32") return SD_TYPE_F32;
    if (lowerType == "f16") return SD_TYPE_F16;
    if (lowerType == "q4_0") return SD_TYPE_Q4_0;
    if (lowerType == "q4_1") return SD_TYPE_Q4_1;
    if (lowerType == "q5_0") return SD_TYPE_Q5_0;
    if (lowerType == "q5_1") return SD_TYPE_Q5_1;
    if (lowerType == "q8_0") return SD_TYPE_Q8_0;
    if (lowerType == "q8_1") return SD_TYPE_Q8_1;
    if (lowerType == "q2_k") return SD_TYPE_Q2_K;
    if (lowerType == "q3_k") return SD_TYPE_Q3_K;
    if (lowerType == "q4_k") return SD_TYPE_Q4_K;
    if (lowerType == "q5_k") return SD_TYPE_Q5_K;
    if (lowerType == "q6_k") return SD_TYPE_Q6_K;
    if (lowerType == "q8_k") return SD_TYPE_Q8_K;
    return SD_TYPE_F16;  // Default to F16
}

// Public interface implementation
StableDiffusionWrapper::StableDiffusionWrapper() : pImpl(std::make_unique()) { // wrapperMutex is automatically initialized by std::mutex default constructor } StableDiffusionWrapper::~StableDiffusionWrapper() = default; bool StableDiffusionWrapper::loadModel(const std::string& modelPath, const GenerationParams& params) { std::lock_guard lock(wrapperMutex); return pImpl->loadModel(modelPath, params); } void StableDiffusionWrapper::unloadModel() { std::lock_guard lock(wrapperMutex); pImpl->unloadModel(); } bool StableDiffusionWrapper::isModelLoaded() const { std::lock_guard lock(wrapperMutex); return pImpl->isModelLoaded(); } std::vector StableDiffusionWrapper::generateImage( const GenerationParams& params, ProgressCallback progressCallback, void* userData) { std::lock_guard lock(wrapperMutex); return pImpl->generateImage(params, progressCallback, userData); } std::vector StableDiffusionWrapper::generateImageImg2Img( const GenerationParams& params, const std::vector& inputData, int inputWidth, int inputHeight, ProgressCallback progressCallback, void* userData) { std::lock_guard lock(wrapperMutex); return pImpl->generateImageImg2Img(params, inputData, inputWidth, inputHeight, progressCallback, userData); } std::vector StableDiffusionWrapper::generateImageControlNet( const GenerationParams& params, const std::vector& controlData, int controlWidth, int controlHeight, ProgressCallback progressCallback, void* userData) { std::lock_guard lock(wrapperMutex); return pImpl->generateImageControlNet(params, controlData, controlWidth, controlHeight, progressCallback, userData); } StableDiffusionWrapper::GeneratedImage StableDiffusionWrapper::upscaleImage( const std::string& esrganPath, const std::vector& inputData, int inputWidth, int inputHeight, int inputChannels, uint32_t upscaleFactor, int nThreads, bool offloadParamsToCpu, bool direct) { std::lock_guard lock(wrapperMutex); return pImpl->upscaleImage(esrganPath, inputData, inputWidth, inputHeight, inputChannels, 
upscaleFactor, nThreads, offloadParamsToCpu, direct); } std::string StableDiffusionWrapper::getLastError() const { std::lock_guard lock(wrapperMutex); return pImpl->getLastError(); }