| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660 |
#include "stable_diffusion_wrapper.h"

#include <algorithm>
#include <cctype>
#include <chrono>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <memory>
#include <mutex>
#include <random>
#include <string>
#include <utility>
#include <vector>

extern "C" {
#include "stable-diffusion.h"
}
- class StableDiffusionWrapper::Impl {
- public:
- sd_ctx_t* sdContext = nullptr;
- std::string lastError;
- std::mutex contextMutex;
- Impl() {
- // Initialize any required resources
- }
- ~Impl() {
- unloadModel();
- }
- bool loadModel(const std::string& modelPath, const StableDiffusionWrapper::GenerationParams& params) {
- std::lock_guard<std::mutex> lock(contextMutex);
- // Unload any existing model
- if (sdContext) {
- free_sd_ctx(sdContext);
- sdContext = nullptr;
- }
- // Initialize context parameters
- sd_ctx_params_t ctxParams;
- sd_ctx_params_init(&ctxParams);
- // Set model path
- ctxParams.model_path = modelPath.c_str();
- // Set optional model paths if provided
- if (!params.clipLPath.empty()) {
- ctxParams.clip_l_path = params.clipLPath.c_str();
- }
- if (!params.clipGPath.empty()) {
- ctxParams.clip_g_path = params.clipGPath.c_str();
- }
- if (!params.vaePath.empty()) {
- ctxParams.vae_path = params.vaePath.c_str();
- }
- if (!params.taesdPath.empty()) {
- ctxParams.taesd_path = params.taesdPath.c_str();
- }
- if (!params.controlNetPath.empty()) {
- ctxParams.control_net_path = params.controlNetPath.c_str();
- }
- if (!params.loraModelDir.empty()) {
- ctxParams.lora_model_dir = params.loraModelDir.c_str();
- }
- if (!params.embeddingDir.empty()) {
- ctxParams.embedding_dir = params.embeddingDir.c_str();
- }
- // Set performance parameters
- ctxParams.n_threads = params.nThreads;
- ctxParams.offload_params_to_cpu = params.offloadParamsToCpu;
- ctxParams.keep_clip_on_cpu = params.clipOnCpu;
- ctxParams.keep_vae_on_cpu = params.vaeOnCpu;
- ctxParams.diffusion_flash_attn = params.diffusionFlashAttn;
- ctxParams.diffusion_conv_direct = params.diffusionConvDirect;
- ctxParams.vae_conv_direct = params.vaeConvDirect;
- // Set model type
- ctxParams.wtype = StableDiffusionWrapper::stringToModelType(params.modelType);
- // Create the stable-diffusion context
- sdContext = new_sd_ctx(&ctxParams);
- if (!sdContext) {
- lastError = "Failed to create stable-diffusion context";
- return false;
- }
- std::cout << "Successfully loaded model: " << modelPath << std::endl;
- return true;
- }
- void unloadModel() {
- std::lock_guard<std::mutex> lock(contextMutex);
- if (sdContext) {
- free_sd_ctx(sdContext);
- sdContext = nullptr;
- std::cout << "Unloaded stable-diffusion model" << std::endl;
- }
- }
- bool isModelLoaded() const {
- return sdContext != nullptr;
- }
- std::vector<StableDiffusionWrapper::GeneratedImage> generateImage(
- const StableDiffusionWrapper::GenerationParams& params,
- StableDiffusionWrapper::ProgressCallback progressCallback,
- void* userData) {
- std::vector<StableDiffusionWrapper::GeneratedImage> results;
- if (!sdContext) {
- lastError = "No model loaded";
- return results;
- }
- auto startTime = std::chrono::high_resolution_clock::now();
- // Initialize generation parameters
- sd_img_gen_params_t genParams;
- sd_img_gen_params_init(&genParams);
- // Set basic parameters
- genParams.prompt = params.prompt.c_str();
- genParams.negative_prompt = params.negativePrompt.c_str();
- genParams.width = params.width;
- genParams.height = params.height;
- genParams.sample_params.sample_steps = params.steps;
- genParams.seed = params.seed;
- genParams.batch_count = params.batchCount;
- // Set sampling parameters
- genParams.sample_params.sample_method = StableDiffusionWrapper::stringToSamplingMethod(params.samplingMethod);
- genParams.sample_params.scheduler = StableDiffusionWrapper::stringToScheduler(params.scheduler);
- genParams.sample_params.guidance.txt_cfg = params.cfgScale;
- // Set advanced parameters
- genParams.clip_skip = params.clipSkip;
- genParams.strength = params.strength;
- // Set progress callback if provided
- // Track callback data to ensure proper cleanup
- std::pair<StableDiffusionWrapper::ProgressCallback, void*>* callbackData = nullptr;
- if (progressCallback) {
- callbackData = new std::pair<StableDiffusionWrapper::ProgressCallback, void*>(progressCallback, userData);
- sd_set_progress_callback([](int step, int steps, float time, void* data) {
- auto* callbackData = static_cast<std::pair<StableDiffusionWrapper::ProgressCallback, void*>*>(data);
- if (callbackData) {
- callbackData->first(step, steps, time, callbackData->second);
- }
- }, callbackData);
- }
- // Generate the image
- sd_image_t* sdImages = generate_image(sdContext, &genParams);
- // Clear and clean up progress callback
- sd_set_progress_callback(nullptr, nullptr);
- if (callbackData) {
- delete callbackData;
- callbackData = nullptr;
- }
- auto endTime = std::chrono::high_resolution_clock::now();
- auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(endTime - startTime);
- if (!sdImages) {
- lastError = "Failed to generate image";
- return results;
- }
- // Convert stable-diffusion images to our format
- for (int i = 0; i < params.batchCount; i++) {
- StableDiffusionWrapper::GeneratedImage image;
- image.width = sdImages[i].width;
- image.height = sdImages[i].height;
- image.channels = sdImages[i].channel;
- image.seed = params.seed;
- image.generationTime = duration.count();
- // Copy image data
- if (sdImages[i].data && sdImages[i].width > 0 && sdImages[i].height > 0 && sdImages[i].channel > 0) {
- size_t dataSize = sdImages[i].width * sdImages[i].height * sdImages[i].channel;
- image.data.resize(dataSize);
- std::memcpy(image.data.data(), sdImages[i].data, dataSize);
- }
- results.push_back(image);
- }
- // Free the generated images
- // Clean up each image's data array
- for (int i = 0; i < params.batchCount; i++) {
- if (sdImages[i].data) {
- free(sdImages[i].data);
- sdImages[i].data = nullptr;
- }
- }
- // Free the image array itself
- free(sdImages);
- return results;
- }
- std::vector<StableDiffusionWrapper::GeneratedImage> generateImageImg2Img(
- const StableDiffusionWrapper::GenerationParams& params,
- const std::vector<uint8_t>& inputData,
- int inputWidth,
- int inputHeight,
- StableDiffusionWrapper::ProgressCallback progressCallback,
- void* userData) {
- std::vector<StableDiffusionWrapper::GeneratedImage> results;
- if (!sdContext) {
- lastError = "No model loaded";
- return results;
- }
- auto startTime = std::chrono::high_resolution_clock::now();
- // Initialize generation parameters
- sd_img_gen_params_t genParams;
- sd_img_gen_params_init(&genParams);
- // Set basic parameters
- genParams.prompt = params.prompt.c_str();
- genParams.negative_prompt = params.negativePrompt.c_str();
- genParams.width = params.width;
- genParams.height = params.height;
- genParams.sample_params.sample_steps = params.steps;
- genParams.seed = params.seed;
- genParams.batch_count = params.batchCount;
- genParams.strength = params.strength;
- // Set sampling parameters
- genParams.sample_params.sample_method = StableDiffusionWrapper::stringToSamplingMethod(params.samplingMethod);
- genParams.sample_params.scheduler = StableDiffusionWrapper::stringToScheduler(params.scheduler);
- genParams.sample_params.guidance.txt_cfg = params.cfgScale;
- // Set advanced parameters
- genParams.clip_skip = params.clipSkip;
- // Set input image
- sd_image_t initImage;
- initImage.width = inputWidth;
- initImage.height = inputHeight;
- initImage.channel = 3; // RGB
- initImage.data = const_cast<uint8_t*>(inputData.data());
- genParams.init_image = initImage;
- // Set progress callback if provided
- // Track callback data to ensure proper cleanup
- std::pair<StableDiffusionWrapper::ProgressCallback, void*>* callbackData = nullptr;
- if (progressCallback) {
- callbackData = new std::pair<StableDiffusionWrapper::ProgressCallback, void*>(progressCallback, userData);
- sd_set_progress_callback([](int step, int steps, float time, void* data) {
- auto* callbackData = static_cast<std::pair<StableDiffusionWrapper::ProgressCallback, void*>*>(data);
- if (callbackData) {
- callbackData->first(step, steps, time, callbackData->second);
- }
- }, callbackData);
- }
- // Generate the image
- sd_image_t* sdImages = generate_image(sdContext, &genParams);
- // Clear and clean up progress callback
- sd_set_progress_callback(nullptr, nullptr);
- if (callbackData) {
- delete callbackData;
- callbackData = nullptr;
- }
- auto endTime = std::chrono::high_resolution_clock::now();
- auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(endTime - startTime);
- if (!sdImages) {
- lastError = "Failed to generate image";
- return results;
- }
- // Convert stable-diffusion images to our format
- for (int i = 0; i < params.batchCount; i++) {
- StableDiffusionWrapper::GeneratedImage image;
- image.width = sdImages[i].width;
- image.height = sdImages[i].height;
- image.channels = sdImages[i].channel;
- image.seed = params.seed;
- image.generationTime = duration.count();
- // Copy image data
- if (sdImages[i].data && sdImages[i].width > 0 && sdImages[i].height > 0 && sdImages[i].channel > 0) {
- size_t dataSize = sdImages[i].width * sdImages[i].height * sdImages[i].channel;
- image.data.resize(dataSize);
- std::memcpy(image.data.data(), sdImages[i].data, dataSize);
- }
- results.push_back(image);
- }
- // Free the generated images
- // Clean up each image's data array
- for (int i = 0; i < params.batchCount; i++) {
- if (sdImages[i].data) {
- free(sdImages[i].data);
- sdImages[i].data = nullptr;
- }
- }
- // Free the image array itself
- free(sdImages);
- return results;
- }
- std::vector<StableDiffusionWrapper::GeneratedImage> generateImageControlNet(
- const StableDiffusionWrapper::GenerationParams& params,
- const std::vector<uint8_t>& controlData,
- int controlWidth,
- int controlHeight,
- StableDiffusionWrapper::ProgressCallback progressCallback,
- void* userData) {
- std::vector<StableDiffusionWrapper::GeneratedImage> results;
- if (!sdContext) {
- lastError = "No model loaded";
- return results;
- }
- auto startTime = std::chrono::high_resolution_clock::now();
- // Initialize generation parameters
- sd_img_gen_params_t genParams;
- sd_img_gen_params_init(&genParams);
- // Set basic parameters
- genParams.prompt = params.prompt.c_str();
- genParams.negative_prompt = params.negativePrompt.c_str();
- genParams.width = params.width;
- genParams.height = params.height;
- genParams.sample_params.sample_steps = params.steps;
- genParams.seed = params.seed;
- genParams.batch_count = params.batchCount;
- genParams.control_strength = params.controlStrength;
- // Set sampling parameters
- genParams.sample_params.sample_method = StableDiffusionWrapper::stringToSamplingMethod(params.samplingMethod);
- genParams.sample_params.scheduler = StableDiffusionWrapper::stringToScheduler(params.scheduler);
- genParams.sample_params.guidance.txt_cfg = params.cfgScale;
- // Set advanced parameters
- genParams.clip_skip = params.clipSkip;
- // Set control image
- sd_image_t controlImage;
- controlImage.width = controlWidth;
- controlImage.height = controlHeight;
- controlImage.channel = 3; // RGB
- controlImage.data = const_cast<uint8_t*>(controlData.data());
- genParams.control_image = controlImage;
- // Set progress callback if provided
- // Track callback data to ensure proper cleanup
- std::pair<StableDiffusionWrapper::ProgressCallback, void*>* callbackData = nullptr;
- if (progressCallback) {
- callbackData = new std::pair<StableDiffusionWrapper::ProgressCallback, void*>(progressCallback, userData);
- sd_set_progress_callback([](int step, int steps, float time, void* data) {
- auto* callbackData = static_cast<std::pair<StableDiffusionWrapper::ProgressCallback, void*>*>(data);
- if (callbackData) {
- callbackData->first(step, steps, time, callbackData->second);
- }
- }, callbackData);
- }
- // Generate the image
- sd_image_t* sdImages = generate_image(sdContext, &genParams);
- // Clear and clean up progress callback
- sd_set_progress_callback(nullptr, nullptr);
- if (callbackData) {
- delete callbackData;
- callbackData = nullptr;
- }
- auto endTime = std::chrono::high_resolution_clock::now();
- auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(endTime - startTime);
- if (!sdImages) {
- lastError = "Failed to generate image";
- return results;
- }
- // Convert stable-diffusion images to our format
- for (int i = 0; i < params.batchCount; i++) {
- StableDiffusionWrapper::GeneratedImage image;
- image.width = sdImages[i].width;
- image.height = sdImages[i].height;
- image.channels = sdImages[i].channel;
- image.seed = params.seed;
- image.generationTime = duration.count();
- // Copy image data
- if (sdImages[i].data && sdImages[i].width > 0 && sdImages[i].height > 0 && sdImages[i].channel > 0) {
- size_t dataSize = sdImages[i].width * sdImages[i].height * sdImages[i].channel;
- image.data.resize(dataSize);
- std::memcpy(image.data.data(), sdImages[i].data, dataSize);
- }
- results.push_back(image);
- }
- // Free the generated images
- // Clean up each image's data array
- for (int i = 0; i < params.batchCount; i++) {
- if (sdImages[i].data) {
- free(sdImages[i].data);
- sdImages[i].data = nullptr;
- }
- }
- // Free the image array itself
- free(sdImages);
- return results;
- }
- StableDiffusionWrapper::GeneratedImage upscaleImage(
- const std::string& esrganPath,
- const std::vector<uint8_t>& inputData,
- int inputWidth,
- int inputHeight,
- int inputChannels,
- uint32_t upscaleFactor,
- int nThreads,
- bool offloadParamsToCpu,
- bool direct) {
- StableDiffusionWrapper::GeneratedImage result;
- auto startTime = std::chrono::high_resolution_clock::now();
- // Create upscaler context
- upscaler_ctx_t* upscalerCtx = new_upscaler_ctx(
- esrganPath.c_str(),
- offloadParamsToCpu,
- direct,
- nThreads
- );
- if (!upscalerCtx) {
- lastError = "Failed to create upscaler context";
- return result;
- }
- // Prepare input image
- sd_image_t inputImage;
- inputImage.width = inputWidth;
- inputImage.height = inputHeight;
- inputImage.channel = inputChannels;
- inputImage.data = const_cast<uint8_t*>(inputData.data());
- // Perform upscaling
- sd_image_t upscaled = upscale(upscalerCtx, inputImage, upscaleFactor);
- auto endTime = std::chrono::high_resolution_clock::now();
- auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(endTime - startTime);
- if (!upscaled.data) {
- lastError = "Failed to upscale image";
- free_upscaler_ctx(upscalerCtx);
- return result;
- }
- // Convert to our format
- result.width = upscaled.width;
- result.height = upscaled.height;
- result.channels = upscaled.channel;
- result.seed = 0; // No seed for upscaling
- result.generationTime = duration.count();
- // Copy image data
- if (upscaled.data && upscaled.width > 0 && upscaled.height > 0 && upscaled.channel > 0) {
- size_t dataSize = upscaled.width * upscaled.height * upscaled.channel;
- result.data.resize(dataSize);
- std::memcpy(result.data.data(), upscaled.data, dataSize);
- }
- // Clean up
- free_upscaler_ctx(upscalerCtx);
- return result;
- }
- std::string getLastError() const {
- return lastError;
- }
- };
- // Static helper functions
- sample_method_t StableDiffusionWrapper::stringToSamplingMethod(const std::string& method) {
- std::string lowerMethod = method;
- std::transform(lowerMethod.begin(), lowerMethod.end(), lowerMethod.begin(), ::tolower);
- if (lowerMethod == "euler") {
- return EULER;
- } else if (lowerMethod == "euler_a") {
- return EULER_A;
- } else if (lowerMethod == "heun") {
- return HEUN;
- } else if (lowerMethod == "dpm2") {
- return DPM2;
- } else if (lowerMethod == "dpmpp2s_a") {
- return DPMPP2S_A;
- } else if (lowerMethod == "dpmpp2m") {
- return DPMPP2M;
- } else if (lowerMethod == "dpmpp2mv2") {
- return DPMPP2Mv2;
- } else if (lowerMethod == "ipndm") {
- return IPNDM;
- } else if (lowerMethod == "ipndm_v") {
- return IPNDM_V;
- } else if (lowerMethod == "lcm") {
- return LCM;
- } else if (lowerMethod == "ddim_trailing") {
- return DDIM_TRAILING;
- } else if (lowerMethod == "tcd") {
- return TCD;
- } else {
- return SAMPLE_METHOD_DEFAULT;
- }
- }
- scheduler_t StableDiffusionWrapper::stringToScheduler(const std::string& scheduler) {
- std::string lowerScheduler = scheduler;
- std::transform(lowerScheduler.begin(), lowerScheduler.end(), lowerScheduler.begin(), ::tolower);
- if (lowerScheduler == "discrete") {
- return DISCRETE;
- } else if (lowerScheduler == "karras") {
- return KARRAS;
- } else if (lowerScheduler == "exponential") {
- return EXPONENTIAL;
- } else if (lowerScheduler == "ays") {
- return AYS;
- } else if (lowerScheduler == "gits") {
- return GITS;
- } else if (lowerScheduler == "smoothstep") {
- return SMOOTHSTEP;
- } else if (lowerScheduler == "sgm_uniform") {
- return SGM_UNIFORM;
- } else if (lowerScheduler == "simple") {
- return SIMPLE;
- } else {
- return DEFAULT;
- }
- }
- sd_type_t StableDiffusionWrapper::stringToModelType(const std::string& type) {
- std::string lowerType = type;
- std::transform(lowerType.begin(), lowerType.end(), lowerType.begin(), ::tolower);
- if (lowerType == "f32") {
- return SD_TYPE_F32;
- } else if (lowerType == "f16") {
- return SD_TYPE_F16;
- } else if (lowerType == "q4_0") {
- return SD_TYPE_Q4_0;
- } else if (lowerType == "q4_1") {
- return SD_TYPE_Q4_1;
- } else if (lowerType == "q5_0") {
- return SD_TYPE_Q5_0;
- } else if (lowerType == "q5_1") {
- return SD_TYPE_Q5_1;
- } else if (lowerType == "q8_0") {
- return SD_TYPE_Q8_0;
- } else if (lowerType == "q8_1") {
- return SD_TYPE_Q8_1;
- } else if (lowerType == "q2_k") {
- return SD_TYPE_Q2_K;
- } else if (lowerType == "q3_k") {
- return SD_TYPE_Q3_K;
- } else if (lowerType == "q4_k") {
- return SD_TYPE_Q4_K;
- } else if (lowerType == "q5_k") {
- return SD_TYPE_Q5_K;
- } else if (lowerType == "q6_k") {
- return SD_TYPE_Q6_K;
- } else if (lowerType == "q8_k") {
- return SD_TYPE_Q8_K;
- } else {
- return SD_TYPE_F16; // Default to F16
- }
- }
- // Public interface implementation
- StableDiffusionWrapper::StableDiffusionWrapper() : pImpl(std::make_unique<Impl>()) {
- // wrapperMutex is automatically initialized by std::mutex default constructor
- }
- StableDiffusionWrapper::~StableDiffusionWrapper() = default;
- bool StableDiffusionWrapper::loadModel(const std::string& modelPath, const GenerationParams& params) {
- std::lock_guard<std::mutex> lock(wrapperMutex);
- return pImpl->loadModel(modelPath, params);
- }
- void StableDiffusionWrapper::unloadModel() {
- std::lock_guard<std::mutex> lock(wrapperMutex);
- pImpl->unloadModel();
- }
- bool StableDiffusionWrapper::isModelLoaded() const {
- std::lock_guard<std::mutex> lock(wrapperMutex);
- return pImpl->isModelLoaded();
- }
- std::vector<StableDiffusionWrapper::GeneratedImage> StableDiffusionWrapper::generateImage(
- const GenerationParams& params,
- ProgressCallback progressCallback,
- void* userData) {
- std::lock_guard<std::mutex> lock(wrapperMutex);
- return pImpl->generateImage(params, progressCallback, userData);
- }
- std::vector<StableDiffusionWrapper::GeneratedImage> StableDiffusionWrapper::generateImageImg2Img(
- const GenerationParams& params,
- const std::vector<uint8_t>& inputData,
- int inputWidth,
- int inputHeight,
- ProgressCallback progressCallback,
- void* userData) {
- std::lock_guard<std::mutex> lock(wrapperMutex);
- return pImpl->generateImageImg2Img(params, inputData, inputWidth, inputHeight, progressCallback, userData);
- }
- std::vector<StableDiffusionWrapper::GeneratedImage> StableDiffusionWrapper::generateImageControlNet(
- const GenerationParams& params,
- const std::vector<uint8_t>& controlData,
- int controlWidth,
- int controlHeight,
- ProgressCallback progressCallback,
- void* userData) {
- std::lock_guard<std::mutex> lock(wrapperMutex);
- return pImpl->generateImageControlNet(params, controlData, controlWidth, controlHeight, progressCallback, userData);
- }
- StableDiffusionWrapper::GeneratedImage StableDiffusionWrapper::upscaleImage(
- const std::string& esrganPath,
- const std::vector<uint8_t>& inputData,
- int inputWidth,
- int inputHeight,
- int inputChannels,
- uint32_t upscaleFactor,
- int nThreads,
- bool offloadParamsToCpu,
- bool direct) {
- std::lock_guard<std::mutex> lock(wrapperMutex);
- return pImpl->upscaleImage(esrganPath, inputData, inputWidth, inputHeight, inputChannels, upscaleFactor, nThreads, offloadParamsToCpu, direct);
- }
- std::string StableDiffusionWrapper::getLastError() const {
- std::lock_guard<std::mutex> lock(wrapperMutex);
- return pImpl->getLastError();
- }
|