|
|
@@ -1,14 +1,14 @@
|
|
|
#include "stable_diffusion_wrapper.h"
|
|
|
-#include "model_detector.h"
|
|
|
-#include <iostream>
|
|
|
+#include <algorithm>
|
|
|
#include <chrono>
|
|
|
#include <cstring>
|
|
|
-#include <algorithm>
|
|
|
#include <filesystem>
|
|
|
+#include <iostream>
|
|
|
#include <thread>
|
|
|
+#include "model_detector.h"
|
|
|
|
|
|
extern "C" {
|
|
|
- #include "stable-diffusion.h"
|
|
|
+#include "stable-diffusion.h"
|
|
|
}
|
|
|
|
|
|
class StableDiffusionWrapper::Impl {
|
|
|
@@ -28,7 +28,7 @@ public:
|
|
|
|
|
|
bool loadModel(const std::string& modelPath, const StableDiffusionWrapper::GenerationParams& params) {
|
|
|
std::lock_guard<std::mutex> lock(contextMutex);
|
|
|
-
|
|
|
+
|
|
|
// Store verbose flag for use in other functions
|
|
|
verbose = params.verbose;
|
|
|
|
|
|
@@ -41,43 +41,44 @@ public:
|
|
|
// Initialize context parameters
|
|
|
sd_ctx_params_t ctxParams;
|
|
|
sd_ctx_params_init(&ctxParams);
|
|
|
-
|
|
|
+ ctxParams.free_params_immediately = false; // avoid segfault when reusing
|
|
|
+
|
|
|
// Get absolute path for logging
|
|
|
std::filesystem::path absModelPath = std::filesystem::absolute(modelPath);
|
|
|
if (params.verbose) {
|
|
|
std::cout << "Loading model from absolute path: " << absModelPath << std::endl;
|
|
|
}
|
|
|
-
|
|
|
+
|
|
|
// Create persistent string copies to fix lifetime issues
|
|
|
// These strings will remain valid for the entire lifetime of the context
|
|
|
- std::string persistentModelPath = modelPath;
|
|
|
- std::string persistentClipLPath = params.clipLPath;
|
|
|
- std::string persistentClipGPath = params.clipGPath;
|
|
|
- std::string persistentVaePath = params.vaePath;
|
|
|
- std::string persistentTaesdPath = params.taesdPath;
|
|
|
+ std::string persistentModelPath = modelPath;
|
|
|
+ std::string persistentClipLPath = params.clipLPath;
|
|
|
+ std::string persistentClipGPath = params.clipGPath;
|
|
|
+ std::string persistentVaePath = params.vaePath;
|
|
|
+ std::string persistentTaesdPath = params.taesdPath;
|
|
|
std::string persistentControlNetPath = params.controlNetPath;
|
|
|
- std::string persistentLoraModelDir = params.loraModelDir;
|
|
|
- std::string persistentEmbeddingDir = params.embeddingDir;
|
|
|
-
|
|
|
+ std::string persistentLoraModelDir = params.loraModelDir;
|
|
|
+ std::string persistentEmbeddingDir = params.embeddingDir;
|
|
|
+
|
|
|
// Use folder-based path selection with enhanced logic
|
|
|
- bool useDiffusionModelPath = false;
|
|
|
+ bool useDiffusionModelPath = false;
|
|
|
std::string detectionSource = "folder";
|
|
|
-
|
|
|
+
|
|
|
// Check if model is in diffusion_models directory by examining the path
|
|
|
std::filesystem::path modelFilePath(modelPath);
|
|
|
std::filesystem::path parentDir = modelFilePath.parent_path();
|
|
|
- std::string parentDirName = parentDir.filename().string();
|
|
|
- std::string modelFileName = modelFilePath.filename().string();
|
|
|
-
|
|
|
+ std::string parentDirName = parentDir.filename().string();
|
|
|
+ std::string modelFileName = modelFilePath.filename().string();
|
|
|
+
|
|
|
// Convert to lowercase for comparison
|
|
|
std::transform(parentDirName.begin(), parentDirName.end(), parentDirName.begin(), ::tolower);
|
|
|
std::transform(modelFileName.begin(), modelFileName.end(), modelFileName.begin(), ::tolower);
|
|
|
-
|
|
|
+
|
|
|
// Variables for fallback detection
|
|
|
ModelDetectionResult detectionResult;
|
|
|
bool detectionSuccessful = false;
|
|
|
- bool isQwenModel = false;
|
|
|
-
|
|
|
+ bool isQwenModel = false;
|
|
|
+
|
|
|
// Check if this is a Qwen model based on filename
|
|
|
if (modelFileName.find("qwen") != std::string::npos) {
|
|
|
isQwenModel = true;
|
|
|
@@ -85,7 +86,7 @@ public:
|
|
|
std::cout << "Detected Qwen model from filename: " << modelFileName << std::endl;
|
|
|
}
|
|
|
}
|
|
|
-
|
|
|
+
|
|
|
// Enhanced path selection logic
|
|
|
if (parentDirName == "diffusion_models" || parentDirName == "diffusion") {
|
|
|
useDiffusionModelPath = true;
|
|
|
@@ -102,7 +103,7 @@ public:
|
|
|
if (isQwenModel) {
|
|
|
// Qwen models should use diffusion_model_path regardless of directory
|
|
|
useDiffusionModelPath = true;
|
|
|
- detectionSource = "qwen_root_detection";
|
|
|
+ detectionSource = "qwen_root_detection";
|
|
|
if (params.verbose) {
|
|
|
std::cout << "Qwen model in root directory, preferring diffusion_model_path" << std::endl;
|
|
|
}
|
|
|
@@ -112,16 +113,16 @@ public:
|
|
|
std::cout << "Model is in root directory '" << parentDirName << "', attempting architecture detection" << std::endl;
|
|
|
}
|
|
|
detectionSource = "architecture_fallback";
|
|
|
-
|
|
|
+
|
|
|
try {
|
|
|
- detectionResult = ModelDetector::detectModel(modelPath);
|
|
|
+ detectionResult = ModelDetector::detectModel(modelPath);
|
|
|
detectionSuccessful = true;
|
|
|
if (params.verbose) {
|
|
|
std::cout << "Architecture detection found: " << detectionResult.architectureName << std::endl;
|
|
|
}
|
|
|
} catch (const std::exception& e) {
|
|
|
std::cerr << "Warning: Architecture detection failed: " << e.what() << ". Using default loading method." << std::endl;
|
|
|
- detectionResult.architecture = ModelArchitecture::UNKNOWN;
|
|
|
+ detectionResult.architecture = ModelArchitecture::UNKNOWN;
|
|
|
detectionResult.architectureName = "Unknown";
|
|
|
}
|
|
|
|
|
|
@@ -152,8 +153,8 @@ public:
|
|
|
break;
|
|
|
}
|
|
|
} else {
|
|
|
- useDiffusionModelPath = false; // Default fallback
|
|
|
- detectionSource = "default_fallback";
|
|
|
+ useDiffusionModelPath = false; // Default fallback
|
|
|
+ detectionSource = "default_fallback";
|
|
|
}
|
|
|
}
|
|
|
} else {
|
|
|
@@ -162,16 +163,16 @@ public:
|
|
|
std::cout << "Model is in unknown directory '" << parentDirName << "', attempting architecture detection as fallback" << std::endl;
|
|
|
}
|
|
|
detectionSource = "architecture_fallback";
|
|
|
-
|
|
|
+
|
|
|
try {
|
|
|
- detectionResult = ModelDetector::detectModel(modelPath);
|
|
|
+ detectionResult = ModelDetector::detectModel(modelPath);
|
|
|
detectionSuccessful = true;
|
|
|
if (params.verbose) {
|
|
|
std::cout << "Fallback detection found architecture: " << detectionResult.architectureName << std::endl;
|
|
|
}
|
|
|
} catch (const std::exception& e) {
|
|
|
std::cerr << "Warning: Fallback model detection failed: " << e.what() << ". Using default loading method." << std::endl;
|
|
|
- detectionResult.architecture = ModelArchitecture::UNKNOWN;
|
|
|
+ detectionResult.architecture = ModelArchitecture::UNKNOWN;
|
|
|
detectionResult.architectureName = "Unknown";
|
|
|
}
|
|
|
|
|
|
@@ -202,21 +203,21 @@ public:
|
|
|
break;
|
|
|
}
|
|
|
} else {
|
|
|
- useDiffusionModelPath = false; // Default fallback
|
|
|
- detectionSource = "default_fallback";
|
|
|
+ useDiffusionModelPath = false; // Default fallback
|
|
|
+ detectionSource = "default_fallback";
|
|
|
}
|
|
|
}
|
|
|
|
|
|
// Set the appropriate model path based on folder location or fallback detection
|
|
|
if (useDiffusionModelPath) {
|
|
|
ctxParams.diffusion_model_path = persistentModelPath.c_str();
|
|
|
- ctxParams.model_path = nullptr; // Clear the traditional path
|
|
|
+ ctxParams.model_path = nullptr; // Clear the traditional path
|
|
|
if (params.verbose) {
|
|
|
std::cout << "Using diffusion_model_path (source: " << detectionSource << ")" << std::endl;
|
|
|
}
|
|
|
} else {
|
|
|
- ctxParams.model_path = persistentModelPath.c_str();
|
|
|
- ctxParams.diffusion_model_path = nullptr; // Clear the modern path
|
|
|
+ ctxParams.model_path = persistentModelPath.c_str();
|
|
|
+ ctxParams.diffusion_model_path = nullptr; // Clear the modern path
|
|
|
if (params.verbose) {
|
|
|
std::cout << "Using model_path (source: " << detectionSource << ")" << std::endl;
|
|
|
}
|
|
|
@@ -245,7 +246,7 @@ public:
|
|
|
} else {
|
|
|
if (params.verbose) {
|
|
|
std::cout << "VAE file not found: " << std::filesystem::absolute(persistentVaePath)
|
|
|
- << " - continuing without VAE" << std::endl;
|
|
|
+ << " - continuing without VAE" << std::endl;
|
|
|
}
|
|
|
ctxParams.vae_path = nullptr;
|
|
|
}
|
|
|
@@ -276,13 +277,13 @@ public:
|
|
|
}
|
|
|
|
|
|
// Set performance parameters
|
|
|
- ctxParams.n_threads = params.nThreads;
|
|
|
+ ctxParams.n_threads = params.nThreads;
|
|
|
ctxParams.offload_params_to_cpu = params.offloadParamsToCpu;
|
|
|
- ctxParams.keep_clip_on_cpu = params.clipOnCpu;
|
|
|
- ctxParams.keep_vae_on_cpu = params.vaeOnCpu;
|
|
|
- ctxParams.diffusion_flash_attn = params.diffusionFlashAttn;
|
|
|
+ ctxParams.keep_clip_on_cpu = params.clipOnCpu;
|
|
|
+ ctxParams.keep_vae_on_cpu = params.vaeOnCpu;
|
|
|
+ ctxParams.diffusion_flash_attn = params.diffusionFlashAttn;
|
|
|
ctxParams.diffusion_conv_direct = params.diffusionConvDirect;
|
|
|
- ctxParams.vae_conv_direct = params.vaeConvDirect;
|
|
|
+ ctxParams.vae_conv_direct = params.vaeConvDirect;
|
|
|
|
|
|
// Set model type
|
|
|
ctxParams.wtype = StableDiffusionWrapper::stringToModelType(params.modelType);
|
|
|
@@ -306,7 +307,7 @@ public:
|
|
|
sd_ctx_params_init(&ctxParams);
|
|
|
|
|
|
// Set fallback model path using persistent string
|
|
|
- ctxParams.model_path = persistentModelPath.c_str();
|
|
|
+ ctxParams.model_path = persistentModelPath.c_str();
|
|
|
ctxParams.diffusion_model_path = nullptr;
|
|
|
|
|
|
// Re-apply other parameters using persistent strings
|
|
|
@@ -338,13 +339,13 @@ public:
|
|
|
}
|
|
|
|
|
|
// Re-apply performance parameters
|
|
|
- ctxParams.n_threads = params.nThreads;
|
|
|
+ ctxParams.n_threads = params.nThreads;
|
|
|
ctxParams.offload_params_to_cpu = params.offloadParamsToCpu;
|
|
|
- ctxParams.keep_clip_on_cpu = params.clipOnCpu;
|
|
|
- ctxParams.keep_vae_on_cpu = params.vaeOnCpu;
|
|
|
- ctxParams.diffusion_flash_attn = params.diffusionFlashAttn;
|
|
|
+ ctxParams.keep_clip_on_cpu = params.clipOnCpu;
|
|
|
+ ctxParams.keep_vae_on_cpu = params.vaeOnCpu;
|
|
|
+ ctxParams.diffusion_flash_attn = params.diffusionFlashAttn;
|
|
|
ctxParams.diffusion_conv_direct = params.diffusionConvDirect;
|
|
|
- ctxParams.vae_conv_direct = params.vaeConvDirect;
|
|
|
+ ctxParams.vae_conv_direct = params.vaeConvDirect;
|
|
|
|
|
|
// Re-apply model type
|
|
|
ctxParams.wtype = StableDiffusionWrapper::stringToModelType(params.modelType);
|
|
|
@@ -357,33 +358,33 @@ public:
|
|
|
if (!sdContext) {
|
|
|
lastError = "Failed to create stable-diffusion context with both diffusion_model_path and model_path fallback";
|
|
|
std::cerr << "Error: " << lastError << std::endl;
|
|
|
-
|
|
|
+
|
|
|
// Additional fallback: try with minimal parameters for GGUF models
|
|
|
if (modelFileName.find(".gguf") != std::string::npos || modelFileName.find(".ggml") != std::string::npos) {
|
|
|
if (params.verbose) {
|
|
|
std::cout << "Detected GGUF/GGML model, attempting minimal parameter fallback..." << std::endl;
|
|
|
}
|
|
|
-
|
|
|
+
|
|
|
// Re-initialize with minimal parameters
|
|
|
sd_ctx_params_init(&ctxParams);
|
|
|
- ctxParams.model_path = persistentModelPath.c_str();
|
|
|
+ ctxParams.model_path = persistentModelPath.c_str();
|
|
|
ctxParams.diffusion_model_path = nullptr;
|
|
|
-
|
|
|
+
|
|
|
// Set only essential parameters for GGUF
|
|
|
ctxParams.n_threads = params.nThreads;
|
|
|
- ctxParams.wtype = StableDiffusionWrapper::stringToModelType(params.modelType);
|
|
|
-
|
|
|
+ ctxParams.wtype = StableDiffusionWrapper::stringToModelType(params.modelType);
|
|
|
+
|
|
|
if (params.verbose) {
|
|
|
std::cout << "Attempting to create context with minimal GGUF parameters..." << std::endl;
|
|
|
}
|
|
|
sdContext = new_sd_ctx(&ctxParams);
|
|
|
-
|
|
|
+
|
|
|
if (!sdContext) {
|
|
|
lastError = "Failed to create stable-diffusion context even with minimal GGUF parameters";
|
|
|
std::cerr << "Error: " << lastError << std::endl;
|
|
|
return false;
|
|
|
}
|
|
|
-
|
|
|
+
|
|
|
if (params.verbose) {
|
|
|
std::cout << "Successfully loaded GGUF model with minimal parameters: " << absModelPath << std::endl;
|
|
|
}
|
|
|
@@ -401,26 +402,26 @@ public:
|
|
|
if (params.verbose) {
|
|
|
std::cout << "Detected GGUF/GGML model, attempting minimal parameter fallback..." << std::endl;
|
|
|
}
|
|
|
-
|
|
|
+
|
|
|
// Re-initialize with minimal parameters
|
|
|
sd_ctx_params_init(&ctxParams);
|
|
|
ctxParams.model_path = persistentModelPath.c_str();
|
|
|
-
|
|
|
+
|
|
|
// Set only essential parameters for GGUF
|
|
|
ctxParams.n_threads = params.nThreads;
|
|
|
- ctxParams.wtype = StableDiffusionWrapper::stringToModelType(params.modelType);
|
|
|
-
|
|
|
+ ctxParams.wtype = StableDiffusionWrapper::stringToModelType(params.modelType);
|
|
|
+
|
|
|
if (params.verbose) {
|
|
|
std::cout << "Attempting to create context with minimal GGUF parameters..." << std::endl;
|
|
|
}
|
|
|
sdContext = new_sd_ctx(&ctxParams);
|
|
|
-
|
|
|
+
|
|
|
if (!sdContext) {
|
|
|
lastError = "Failed to create stable-diffusion context even with minimal GGUF parameters";
|
|
|
std::cerr << "Error: " << lastError << std::endl;
|
|
|
return false;
|
|
|
}
|
|
|
-
|
|
|
+
|
|
|
if (params.verbose) {
|
|
|
std::cout << "Successfully loaded GGUF model with minimal parameters: " << absModelPath << std::endl;
|
|
|
}
|
|
|
@@ -473,7 +474,6 @@ public:
|
|
|
const StableDiffusionWrapper::GenerationParams& params,
|
|
|
StableDiffusionWrapper::ProgressCallback progressCallback,
|
|
|
void* userData) {
|
|
|
-
|
|
|
std::vector<StableDiffusionWrapper::GeneratedImage> results;
|
|
|
|
|
|
if (!sdContext) {
|
|
|
@@ -488,22 +488,22 @@ public:
|
|
|
sd_img_gen_params_init(&genParams);
|
|
|
|
|
|
// Set basic parameters
|
|
|
- genParams.prompt = params.prompt.c_str();
|
|
|
- genParams.negative_prompt = params.negativePrompt.c_str();
|
|
|
- genParams.width = params.width;
|
|
|
- genParams.height = params.height;
|
|
|
+ genParams.prompt = params.prompt.c_str();
|
|
|
+ genParams.negative_prompt = params.negativePrompt.c_str();
|
|
|
+ genParams.width = params.width;
|
|
|
+ genParams.height = params.height;
|
|
|
genParams.sample_params.sample_steps = params.steps;
|
|
|
- genParams.seed = params.seed;
|
|
|
- genParams.batch_count = params.batchCount;
|
|
|
+ genParams.seed = params.seed;
|
|
|
+ genParams.batch_count = params.batchCount;
|
|
|
|
|
|
// Set sampling parameters
|
|
|
- genParams.sample_params.sample_method = StableDiffusionWrapper::stringToSamplingMethod(params.samplingMethod);
|
|
|
- genParams.sample_params.scheduler = StableDiffusionWrapper::stringToScheduler(params.scheduler);
|
|
|
+ genParams.sample_params.sample_method = StableDiffusionWrapper::stringToSamplingMethod(params.samplingMethod);
|
|
|
+ genParams.sample_params.scheduler = StableDiffusionWrapper::stringToScheduler(params.scheduler);
|
|
|
genParams.sample_params.guidance.txt_cfg = params.cfgScale;
|
|
|
|
|
|
// Set advanced parameters
|
|
|
genParams.clip_skip = params.clipSkip;
|
|
|
- genParams.strength = params.strength;
|
|
|
+ genParams.strength = params.strength;
|
|
|
|
|
|
// Set progress callback if provided
|
|
|
// Track callback data to ensure proper cleanup
|
|
|
@@ -515,7 +515,8 @@ public:
|
|
|
if (callbackData) {
|
|
|
callbackData->first(step, steps, time, callbackData->second);
|
|
|
}
|
|
|
- }, callbackData);
|
|
|
+ },
|
|
|
+ callbackData);
|
|
|
}
|
|
|
|
|
|
// Generate the image
|
|
|
@@ -523,16 +524,16 @@ public:
|
|
|
|
|
|
// Clear and clean up progress callback - FIX: Wait for any pending callbacks
|
|
|
sd_set_progress_callback(nullptr, nullptr);
|
|
|
-
|
|
|
+
|
|
|
// Add a small delay to ensure any in-flight callbacks complete before cleanup
|
|
|
std::this_thread::sleep_for(std::chrono::milliseconds(10));
|
|
|
-
|
|
|
+
|
|
|
if (callbackData) {
|
|
|
delete callbackData;
|
|
|
callbackData = nullptr;
|
|
|
}
|
|
|
|
|
|
- auto endTime = std::chrono::high_resolution_clock::now();
|
|
|
+ auto endTime = std::chrono::high_resolution_clock::now();
|
|
|
auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(endTime - startTime);
|
|
|
|
|
|
if (!sdImages) {
|
|
|
@@ -543,10 +544,10 @@ public:
|
|
|
// Convert stable-diffusion images to our format
|
|
|
for (int i = 0; i < params.batchCount; i++) {
|
|
|
StableDiffusionWrapper::GeneratedImage image;
|
|
|
- image.width = sdImages[i].width;
|
|
|
- image.height = sdImages[i].height;
|
|
|
- image.channels = sdImages[i].channel;
|
|
|
- image.seed = params.seed;
|
|
|
+ image.width = sdImages[i].width;
|
|
|
+ image.height = sdImages[i].height;
|
|
|
+ image.channels = sdImages[i].channel;
|
|
|
+ image.seed = params.seed;
|
|
|
image.generationTime = duration.count();
|
|
|
|
|
|
// Copy image data
|
|
|
@@ -580,7 +581,6 @@ public:
|
|
|
int inputHeight,
|
|
|
StableDiffusionWrapper::ProgressCallback progressCallback,
|
|
|
void* userData) {
|
|
|
-
|
|
|
std::vector<StableDiffusionWrapper::GeneratedImage> results;
|
|
|
|
|
|
if (!sdContext) {
|
|
|
@@ -595,18 +595,18 @@ public:
|
|
|
sd_img_gen_params_init(&genParams);
|
|
|
|
|
|
// Set basic parameters
|
|
|
- genParams.prompt = params.prompt.c_str();
|
|
|
- genParams.negative_prompt = params.negativePrompt.c_str();
|
|
|
- genParams.width = params.width;
|
|
|
- genParams.height = params.height;
|
|
|
+ genParams.prompt = params.prompt.c_str();
|
|
|
+ genParams.negative_prompt = params.negativePrompt.c_str();
|
|
|
+ genParams.width = params.width;
|
|
|
+ genParams.height = params.height;
|
|
|
genParams.sample_params.sample_steps = params.steps;
|
|
|
- genParams.seed = params.seed;
|
|
|
- genParams.batch_count = params.batchCount;
|
|
|
- genParams.strength = params.strength;
|
|
|
+ genParams.seed = params.seed;
|
|
|
+ genParams.batch_count = params.batchCount;
|
|
|
+ genParams.strength = params.strength;
|
|
|
|
|
|
// Set sampling parameters
|
|
|
- genParams.sample_params.sample_method = StableDiffusionWrapper::stringToSamplingMethod(params.samplingMethod);
|
|
|
- genParams.sample_params.scheduler = StableDiffusionWrapper::stringToScheduler(params.scheduler);
|
|
|
+ genParams.sample_params.sample_method = StableDiffusionWrapper::stringToSamplingMethod(params.samplingMethod);
|
|
|
+ genParams.sample_params.scheduler = StableDiffusionWrapper::stringToScheduler(params.scheduler);
|
|
|
genParams.sample_params.guidance.txt_cfg = params.cfgScale;
|
|
|
|
|
|
// Set advanced parameters
|
|
|
@@ -614,10 +614,10 @@ public:
|
|
|
|
|
|
// Set input image
|
|
|
sd_image_t initImage;
|
|
|
- initImage.width = inputWidth;
|
|
|
- initImage.height = inputHeight;
|
|
|
- initImage.channel = 3; // RGB
|
|
|
- initImage.data = const_cast<uint8_t*>(inputData.data());
|
|
|
+ initImage.width = inputWidth;
|
|
|
+ initImage.height = inputHeight;
|
|
|
+ initImage.channel = 3; // RGB
|
|
|
+ initImage.data = const_cast<uint8_t*>(inputData.data());
|
|
|
genParams.init_image = initImage;
|
|
|
|
|
|
// Set progress callback if provided
|
|
|
@@ -630,7 +630,8 @@ public:
|
|
|
if (callbackData) {
|
|
|
callbackData->first(step, steps, time, callbackData->second);
|
|
|
}
|
|
|
- }, callbackData);
|
|
|
+ },
|
|
|
+ callbackData);
|
|
|
}
|
|
|
|
|
|
// Generate the image
|
|
|
@@ -638,16 +639,16 @@ public:
|
|
|
|
|
|
// Clear and clean up progress callback - FIX: Wait for any pending callbacks
|
|
|
sd_set_progress_callback(nullptr, nullptr);
|
|
|
-
|
|
|
+
|
|
|
// Add a small delay to ensure any in-flight callbacks complete before cleanup
|
|
|
std::this_thread::sleep_for(std::chrono::milliseconds(10));
|
|
|
-
|
|
|
+
|
|
|
if (callbackData) {
|
|
|
delete callbackData;
|
|
|
callbackData = nullptr;
|
|
|
}
|
|
|
|
|
|
- auto endTime = std::chrono::high_resolution_clock::now();
|
|
|
+ auto endTime = std::chrono::high_resolution_clock::now();
|
|
|
auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(endTime - startTime);
|
|
|
|
|
|
if (!sdImages) {
|
|
|
@@ -658,10 +659,10 @@ public:
|
|
|
// Convert stable-diffusion images to our format
|
|
|
for (int i = 0; i < params.batchCount; i++) {
|
|
|
StableDiffusionWrapper::GeneratedImage image;
|
|
|
- image.width = sdImages[i].width;
|
|
|
- image.height = sdImages[i].height;
|
|
|
- image.channels = sdImages[i].channel;
|
|
|
- image.seed = params.seed;
|
|
|
+ image.width = sdImages[i].width;
|
|
|
+ image.height = sdImages[i].height;
|
|
|
+ image.channels = sdImages[i].channel;
|
|
|
+ image.seed = params.seed;
|
|
|
image.generationTime = duration.count();
|
|
|
|
|
|
// Copy image data
|
|
|
@@ -695,7 +696,6 @@ public:
|
|
|
int controlHeight,
|
|
|
StableDiffusionWrapper::ProgressCallback progressCallback,
|
|
|
void* userData) {
|
|
|
-
|
|
|
std::vector<StableDiffusionWrapper::GeneratedImage> results;
|
|
|
|
|
|
if (!sdContext) {
|
|
|
@@ -710,18 +710,18 @@ public:
|
|
|
sd_img_gen_params_init(&genParams);
|
|
|
|
|
|
// Set basic parameters
|
|
|
- genParams.prompt = params.prompt.c_str();
|
|
|
- genParams.negative_prompt = params.negativePrompt.c_str();
|
|
|
- genParams.width = params.width;
|
|
|
- genParams.height = params.height;
|
|
|
+ genParams.prompt = params.prompt.c_str();
|
|
|
+ genParams.negative_prompt = params.negativePrompt.c_str();
|
|
|
+ genParams.width = params.width;
|
|
|
+ genParams.height = params.height;
|
|
|
genParams.sample_params.sample_steps = params.steps;
|
|
|
- genParams.seed = params.seed;
|
|
|
- genParams.batch_count = params.batchCount;
|
|
|
- genParams.control_strength = params.controlStrength;
|
|
|
+ genParams.seed = params.seed;
|
|
|
+ genParams.batch_count = params.batchCount;
|
|
|
+ genParams.control_strength = params.controlStrength;
|
|
|
|
|
|
// Set sampling parameters
|
|
|
- genParams.sample_params.sample_method = StableDiffusionWrapper::stringToSamplingMethod(params.samplingMethod);
|
|
|
- genParams.sample_params.scheduler = StableDiffusionWrapper::stringToScheduler(params.scheduler);
|
|
|
+ genParams.sample_params.sample_method = StableDiffusionWrapper::stringToSamplingMethod(params.samplingMethod);
|
|
|
+ genParams.sample_params.scheduler = StableDiffusionWrapper::stringToScheduler(params.scheduler);
|
|
|
genParams.sample_params.guidance.txt_cfg = params.cfgScale;
|
|
|
|
|
|
// Set advanced parameters
|
|
|
@@ -729,10 +729,10 @@ public:
|
|
|
|
|
|
// Set control image
|
|
|
sd_image_t controlImage;
|
|
|
- controlImage.width = controlWidth;
|
|
|
- controlImage.height = controlHeight;
|
|
|
- controlImage.channel = 3; // RGB
|
|
|
- controlImage.data = const_cast<uint8_t*>(controlData.data());
|
|
|
+ controlImage.width = controlWidth;
|
|
|
+ controlImage.height = controlHeight;
|
|
|
+ controlImage.channel = 3; // RGB
|
|
|
+ controlImage.data = const_cast<uint8_t*>(controlData.data());
|
|
|
genParams.control_image = controlImage;
|
|
|
|
|
|
// Set progress callback if provided
|
|
|
@@ -745,7 +745,8 @@ public:
|
|
|
if (callbackData) {
|
|
|
callbackData->first(step, steps, time, callbackData->second);
|
|
|
}
|
|
|
- }, callbackData);
|
|
|
+ },
|
|
|
+ callbackData);
|
|
|
}
|
|
|
|
|
|
// Generate the image
|
|
|
@@ -753,16 +754,16 @@ public:
|
|
|
|
|
|
// Clear and clean up progress callback - FIX: Wait for any pending callbacks
|
|
|
sd_set_progress_callback(nullptr, nullptr);
|
|
|
-
|
|
|
+
|
|
|
// Add a small delay to ensure any in-flight callbacks complete before cleanup
|
|
|
std::this_thread::sleep_for(std::chrono::milliseconds(10));
|
|
|
-
|
|
|
+
|
|
|
if (callbackData) {
|
|
|
delete callbackData;
|
|
|
callbackData = nullptr;
|
|
|
}
|
|
|
|
|
|
- auto endTime = std::chrono::high_resolution_clock::now();
|
|
|
+ auto endTime = std::chrono::high_resolution_clock::now();
|
|
|
auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(endTime - startTime);
|
|
|
|
|
|
if (!sdImages) {
|
|
|
@@ -773,10 +774,10 @@ public:
|
|
|
// Convert stable-diffusion images to our format
|
|
|
for (int i = 0; i < params.batchCount; i++) {
|
|
|
StableDiffusionWrapper::GeneratedImage image;
|
|
|
- image.width = sdImages[i].width;
|
|
|
- image.height = sdImages[i].height;
|
|
|
- image.channels = sdImages[i].channel;
|
|
|
- image.seed = params.seed;
|
|
|
+ image.width = sdImages[i].width;
|
|
|
+ image.height = sdImages[i].height;
|
|
|
+ image.channels = sdImages[i].channel;
|
|
|
+ image.seed = params.seed;
|
|
|
image.generationTime = duration.count();
|
|
|
|
|
|
// Copy image data
|
|
|
@@ -813,7 +814,6 @@ public:
|
|
|
int maskHeight,
|
|
|
StableDiffusionWrapper::ProgressCallback progressCallback,
|
|
|
void* userData) {
|
|
|
-
|
|
|
std::vector<StableDiffusionWrapper::GeneratedImage> results;
|
|
|
|
|
|
if (!sdContext) {
|
|
|
@@ -828,18 +828,18 @@ public:
|
|
|
sd_img_gen_params_init(&genParams);
|
|
|
|
|
|
// Set basic parameters
|
|
|
- genParams.prompt = params.prompt.c_str();
|
|
|
- genParams.negative_prompt = params.negativePrompt.c_str();
|
|
|
- genParams.width = params.width;
|
|
|
- genParams.height = params.height;
|
|
|
+ genParams.prompt = params.prompt.c_str();
|
|
|
+ genParams.negative_prompt = params.negativePrompt.c_str();
|
|
|
+ genParams.width = params.width;
|
|
|
+ genParams.height = params.height;
|
|
|
genParams.sample_params.sample_steps = params.steps;
|
|
|
- genParams.seed = params.seed;
|
|
|
- genParams.batch_count = params.batchCount;
|
|
|
- genParams.strength = params.strength;
|
|
|
+ genParams.seed = params.seed;
|
|
|
+ genParams.batch_count = params.batchCount;
|
|
|
+ genParams.strength = params.strength;
|
|
|
|
|
|
// Set sampling parameters
|
|
|
- genParams.sample_params.sample_method = StableDiffusionWrapper::stringToSamplingMethod(params.samplingMethod);
|
|
|
- genParams.sample_params.scheduler = StableDiffusionWrapper::stringToScheduler(params.scheduler);
|
|
|
+ genParams.sample_params.sample_method = StableDiffusionWrapper::stringToSamplingMethod(params.samplingMethod);
|
|
|
+ genParams.sample_params.scheduler = StableDiffusionWrapper::stringToScheduler(params.scheduler);
|
|
|
genParams.sample_params.guidance.txt_cfg = params.cfgScale;
|
|
|
|
|
|
// Set advanced parameters
|
|
|
@@ -847,18 +847,18 @@ public:
|
|
|
|
|
|
// Set input image
|
|
|
sd_image_t initImage;
|
|
|
- initImage.width = inputWidth;
|
|
|
- initImage.height = inputHeight;
|
|
|
- initImage.channel = 3; // RGB
|
|
|
- initImage.data = const_cast<uint8_t*>(inputData.data());
|
|
|
+ initImage.width = inputWidth;
|
|
|
+ initImage.height = inputHeight;
|
|
|
+ initImage.channel = 3; // RGB
|
|
|
+ initImage.data = const_cast<uint8_t*>(inputData.data());
|
|
|
genParams.init_image = initImage;
|
|
|
|
|
|
// Set mask image
|
|
|
sd_image_t maskImage;
|
|
|
- maskImage.width = maskWidth;
|
|
|
- maskImage.height = maskHeight;
|
|
|
- maskImage.channel = 1; // Grayscale mask
|
|
|
- maskImage.data = const_cast<uint8_t*>(maskData.data());
|
|
|
+ maskImage.width = maskWidth;
|
|
|
+ maskImage.height = maskHeight;
|
|
|
+ maskImage.channel = 1; // Grayscale mask
|
|
|
+ maskImage.data = const_cast<uint8_t*>(maskData.data());
|
|
|
genParams.mask_image = maskImage;
|
|
|
|
|
|
// Set progress callback if provided
|
|
|
@@ -871,7 +871,8 @@ public:
|
|
|
if (callbackData) {
|
|
|
callbackData->first(step, steps, time, callbackData->second);
|
|
|
}
|
|
|
- }, callbackData);
|
|
|
+ },
|
|
|
+ callbackData);
|
|
|
}
|
|
|
|
|
|
// Generate the image
|
|
|
@@ -879,16 +880,16 @@ public:
|
|
|
|
|
|
// Clear and clean up progress callback - FIX: Wait for any pending callbacks
|
|
|
sd_set_progress_callback(nullptr, nullptr);
|
|
|
-
|
|
|
+
|
|
|
// Add a small delay to ensure any in-flight callbacks complete before cleanup
|
|
|
std::this_thread::sleep_for(std::chrono::milliseconds(10));
|
|
|
-
|
|
|
+
|
|
|
if (callbackData) {
|
|
|
delete callbackData;
|
|
|
callbackData = nullptr;
|
|
|
}
|
|
|
|
|
|
- auto endTime = std::chrono::high_resolution_clock::now();
|
|
|
+ auto endTime = std::chrono::high_resolution_clock::now();
|
|
|
auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(endTime - startTime);
|
|
|
|
|
|
if (!sdImages) {
|
|
|
@@ -899,10 +900,10 @@ public:
|
|
|
// Convert stable-diffusion images to our format
|
|
|
for (int i = 0; i < params.batchCount; i++) {
|
|
|
StableDiffusionWrapper::GeneratedImage image;
|
|
|
- image.width = sdImages[i].width;
|
|
|
- image.height = sdImages[i].height;
|
|
|
- image.channels = sdImages[i].channel;
|
|
|
- image.seed = params.seed;
|
|
|
+ image.width = sdImages[i].width;
|
|
|
+ image.height = sdImages[i].height;
|
|
|
+ image.channels = sdImages[i].channel;
|
|
|
+ image.seed = params.seed;
|
|
|
image.generationTime = duration.count();
|
|
|
|
|
|
// Copy image data
|
|
|
@@ -939,7 +940,6 @@ public:
|
|
|
int nThreads,
|
|
|
bool offloadParamsToCpu,
|
|
|
bool direct) {
|
|
|
-
|
|
|
StableDiffusionWrapper::GeneratedImage result;
|
|
|
|
|
|
auto startTime = std::chrono::high_resolution_clock::now();
|
|
|
@@ -949,8 +949,7 @@ public:
|
|
|
esrganPath.c_str(),
|
|
|
offloadParamsToCpu,
|
|
|
direct,
|
|
|
- nThreads
|
|
|
- );
|
|
|
+ nThreads);
|
|
|
|
|
|
if (!upscalerCtx) {
|
|
|
lastError = "Failed to create upscaler context";
|
|
|
@@ -959,15 +958,15 @@ public:
|
|
|
|
|
|
// Prepare input image
|
|
|
sd_image_t inputImage;
|
|
|
- inputImage.width = inputWidth;
|
|
|
- inputImage.height = inputHeight;
|
|
|
+ inputImage.width = inputWidth;
|
|
|
+ inputImage.height = inputHeight;
|
|
|
inputImage.channel = inputChannels;
|
|
|
- inputImage.data = const_cast<uint8_t*>(inputData.data());
|
|
|
+ inputImage.data = const_cast<uint8_t*>(inputData.data());
|
|
|
|
|
|
// Perform upscaling
|
|
|
sd_image_t upscaled = upscale(upscalerCtx, inputImage, upscaleFactor);
|
|
|
|
|
|
- auto endTime = std::chrono::high_resolution_clock::now();
|
|
|
+ auto endTime = std::chrono::high_resolution_clock::now();
|
|
|
auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(endTime - startTime);
|
|
|
|
|
|
if (!upscaled.data) {
|
|
|
@@ -977,10 +976,10 @@ public:
|
|
|
}
|
|
|
|
|
|
// Convert to our format
|
|
|
- result.width = upscaled.width;
|
|
|
- result.height = upscaled.height;
|
|
|
- result.channels = upscaled.channel;
|
|
|
- result.seed = 0; // No seed for upscaling
|
|
|
+ result.width = upscaled.width;
|
|
|
+ result.height = upscaled.height;
|
|
|
+ result.channels = upscaled.channel;
|
|
|
+ result.seed = 0; // No seed for upscaling
|
|
|
result.generationTime = duration.count();
|
|
|
|
|
|
// Copy image data
|
|
|
@@ -1093,7 +1092,7 @@ sd_type_t StableDiffusionWrapper::stringToModelType(const std::string& type) {
|
|
|
} else if (lowerType == "q8_k") {
|
|
|
return SD_TYPE_Q8_K;
|
|
|
} else {
|
|
|
- return SD_TYPE_F16; // Default to F16
|
|
|
+ return SD_TYPE_F16; // Default to F16
|
|
|
}
|
|
|
}
|
|
|
|