|
|
@@ -1,4 +1,5 @@
|
|
|
#include "stable_diffusion_wrapper.h"
|
|
|
+#include "model_detector.h"
|
|
|
#include <iostream>
|
|
|
#include <chrono>
|
|
|
#include <cstring>
|
|
|
@@ -32,12 +33,62 @@ public:
|
|
|
sdContext = nullptr;
|
|
|
}
|
|
|
|
|
|
+ // Detect model architecture
|
|
|
+ ModelDetectionResult detectionResult;
|
|
|
+ bool detectionSuccessful = false;
|
|
|
+ try {
|
|
|
+ detectionResult = ModelDetector::detectModel(modelPath);
|
|
|
+ detectionSuccessful = true;
|
|
|
+ std::cout << "Detected model architecture: " << detectionResult.architectureName
|
|
|
+ << " (" << ModelDetector::getArchitectureName(detectionResult.architecture) << ")" << std::endl;
|
|
|
+ } catch (const std::exception& e) {
|
|
|
+ std::cerr << "Warning: Model detection failed: " << e.what() << ". Falling back to default loading method." << std::endl;
|
|
|
+ detectionResult.architecture = ModelArchitecture::UNKNOWN;
|
|
|
+ detectionResult.architectureName = "Unknown";
|
|
|
+ }
|
|
|
+
|
|
|
// Initialize context parameters
|
|
|
sd_ctx_params_t ctxParams;
|
|
|
sd_ctx_params_init(&ctxParams);
|
|
|
|
|
|
- // Set model path
|
|
|
- ctxParams.model_path = modelPath.c_str();
|
|
|
+ // Determine which model path to use based on detected architecture
|
|
|
+ bool useDiffusionModelPath = false;
|
|
|
+ if (detectionSuccessful) {
|
|
|
+ switch (detectionResult.architecture) {
|
|
|
+ case ModelArchitecture::FLUX_SCHNELL:
|
|
|
+ case ModelArchitecture::FLUX_DEV:
|
|
|
+ case ModelArchitecture::FLUX_CHROMA:
|
|
|
+ case ModelArchitecture::SD_3:
|
|
|
+ case ModelArchitecture::QWEN2VL:
|
|
|
+ // Modern architectures use diffusion_model_path
|
|
|
+ useDiffusionModelPath = true;
|
|
|
+ break;
|
|
|
+ case ModelArchitecture::SD_1_5:
|
|
|
+ case ModelArchitecture::SD_2_1:
|
|
|
+ case ModelArchitecture::SDXL_BASE:
|
|
|
+ case ModelArchitecture::SDXL_REFINER:
|
|
|
+ // Traditional SD models use model_path
|
|
|
+ useDiffusionModelPath = false;
|
|
|
+ break;
|
|
|
+ case ModelArchitecture::UNKNOWN:
|
|
|
+ default:
|
|
|
+ // Unknown architectures fall back to model_path for backward compatibility
|
|
|
+ useDiffusionModelPath = false;
|
|
|
+ std::cout << "Warning: Unknown model architecture detected, using default model_path for backward compatibility" << std::endl;
|
|
|
+ break;
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ // Set the appropriate model path based on architecture
|
|
|
+ if (useDiffusionModelPath) {
|
|
|
+ ctxParams.diffusion_model_path = modelPath.c_str();
|
|
|
+ ctxParams.model_path = nullptr; // Clear the traditional path
|
|
|
+ std::cout << "Using diffusion_model_path for modern architecture model" << std::endl;
|
|
|
+ } else {
|
|
|
+ ctxParams.model_path = modelPath.c_str();
|
|
|
+ ctxParams.diffusion_model_path = nullptr; // Clear the modern path
|
|
|
+ std::cout << "Using model_path for traditional architecture model" << std::endl;
|
|
|
+ }
|
|
|
|
|
|
// Set optional model paths if provided
|
|
|
if (!params.clipLPath.empty()) {
|
|
|
@@ -78,10 +129,83 @@ public:
|
|
|
sdContext = new_sd_ctx(&ctxParams);
|
|
|
if (!sdContext) {
|
|
|
lastError = "Failed to create stable-diffusion context";
|
|
|
- return false;
|
|
|
+
|
|
|
+ // If we used diffusion_model_path and it failed, try fallback to model_path
|
|
|
+ if (useDiffusionModelPath && detectionSuccessful) {
|
|
|
+ std::cout << "Warning: Failed to load with diffusion_model_path. Attempting fallback to model_path..." << std::endl;
|
|
|
+
|
|
|
+ // Re-initialize context parameters
|
|
|
+ sd_ctx_params_init(&ctxParams);
|
|
|
+
|
|
|
+ // Set fallback model path
|
|
|
+ ctxParams.model_path = modelPath.c_str();
|
|
|
+ ctxParams.diffusion_model_path = nullptr;
|
|
|
+
|
|
|
+ // Re-apply other parameters
|
|
|
+ if (!params.clipLPath.empty()) {
|
|
|
+ ctxParams.clip_l_path = params.clipLPath.c_str();
|
|
|
+ }
|
|
|
+ if (!params.clipGPath.empty()) {
|
|
|
+ ctxParams.clip_g_path = params.clipGPath.c_str();
|
|
|
+ }
|
|
|
+ if (!params.vaePath.empty()) {
|
|
|
+ ctxParams.vae_path = params.vaePath.c_str();
|
|
|
+ }
|
|
|
+ if (!params.taesdPath.empty()) {
|
|
|
+ ctxParams.taesd_path = params.taesdPath.c_str();
|
|
|
+ }
|
|
|
+ if (!params.controlNetPath.empty()) {
|
|
|
+ ctxParams.control_net_path = params.controlNetPath.c_str();
|
|
|
+ }
|
|
|
+ if (!params.loraModelDir.empty()) {
|
|
|
+ ctxParams.lora_model_dir = params.loraModelDir.c_str();
|
|
|
+ }
|
|
|
+ if (!params.embeddingDir.empty()) {
|
|
|
+ ctxParams.embedding_dir = params.embeddingDir.c_str();
|
|
|
+ }
|
|
|
+
|
|
|
+ // Re-apply performance parameters
|
|
|
+ ctxParams.n_threads = params.nThreads;
|
|
|
+ ctxParams.offload_params_to_cpu = params.offloadParamsToCpu;
|
|
|
+ ctxParams.keep_clip_on_cpu = params.clipOnCpu;
|
|
|
+ ctxParams.keep_vae_on_cpu = params.vaeOnCpu;
|
|
|
+ ctxParams.diffusion_flash_attn = params.diffusionFlashAttn;
|
|
|
+ ctxParams.diffusion_conv_direct = params.diffusionConvDirect;
|
|
|
+ ctxParams.vae_conv_direct = params.vaeConvDirect;
|
|
|
+
|
|
|
+ // Re-apply model type
|
|
|
+ ctxParams.wtype = StableDiffusionWrapper::stringToModelType(params.modelType);
|
|
|
+
|
|
|
+ // Try creating context again with fallback
|
|
|
+ sdContext = new_sd_ctx(&ctxParams);
|
|
|
+ if (!sdContext) {
|
|
|
+ lastError = "Failed to create stable-diffusion context with both diffusion_model_path and model_path fallback";
|
|
|
+ std::cerr << "Error: " << lastError << std::endl;
|
|
|
+ return false;
|
|
|
+ }
|
|
|
+
|
|
|
+ std::cout << "Successfully loaded model with fallback to model_path: " << modelPath << std::endl;
|
|
|
+ } else {
|
|
|
+ std::cerr << "Error: " << lastError << std::endl;
|
|
|
+ return false;
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
+ // Log successful loading with architecture information
|
|
|
std::cout << "Successfully loaded model: " << modelPath << std::endl;
|
|
|
+ if (detectionSuccessful) {
|
|
|
+ std::cout << " Architecture: " << detectionResult.architectureName << std::endl;
|
|
|
+ std::cout << " Loading method: " << (useDiffusionModelPath ? "diffusion_model_path" : "model_path") << std::endl;
|
|
|
+
|
|
|
+ // Log additional model properties if available
|
|
|
+ if (detectionResult.textEncoderDim > 0) {
|
|
|
+ std::cout << " Text encoder dimension: " << detectionResult.textEncoderDim << std::endl;
|
|
|
+ }
|
|
|
+ if (detectionResult.needsVAE) {
|
|
|
+ std::cout << " Requires VAE: " << (detectionResult.recommendedVAE.empty() ? "Yes" : detectionResult.recommendedVAE) << std::endl;
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
return true;
|
|
|
}
|
|
|
|