ソースを参照

Implement model detection logic in stable_diffusion_wrapper.cpp

- Add ModelDetector header include
- Integrate model detection in loadModel method before setting ctxParams
- Implement logic to choose between model_path and diffusion_model_path based on detected architecture:
  * Traditional SD models (SD 1.5, 2.1, SDXL) use model_path
  * Modern architectures (Flux, SD3, Qwen2VL) use diffusion_model_path
  * Unknown architectures fall back to model_path for backward compatibility
- Add comprehensive error handling and fallback strategies
- Add detailed logging for model detection results and loading methods
- Maintain backward compatibility with existing functionality
- Follow project coding guidelines (no `using` aliases, Rule of 5, etc.)
Fszontagh 3 ヶ月 前
コミット
8496001fce
1 ファイル変更 127 行追加 3 行削除
  1. 127 3
      src/stable_diffusion_wrapper.cpp

+ 127 - 3
src/stable_diffusion_wrapper.cpp

@@ -1,4 +1,5 @@
 #include "stable_diffusion_wrapper.h"
+#include "model_detector.h"
 #include <iostream>
 #include <chrono>
 #include <cstring>
@@ -32,12 +33,62 @@ public:
             sdContext = nullptr;
         }
 
+        // Detect model architecture
+        ModelDetectionResult detectionResult;
+        bool detectionSuccessful = false;
+        try {
+            detectionResult = ModelDetector::detectModel(modelPath);
+            detectionSuccessful = true;
+            std::cout << "Detected model architecture: " << detectionResult.architectureName
+                      << " (" << ModelDetector::getArchitectureName(detectionResult.architecture) << ")" << std::endl;
+        } catch (const std::exception& e) {
+            std::cerr << "Warning: Model detection failed: " << e.what() << ". Falling back to default loading method." << std::endl;
+            detectionResult.architecture = ModelArchitecture::UNKNOWN;
+            detectionResult.architectureName = "Unknown";
+        }
+
         // Initialize context parameters
         sd_ctx_params_t ctxParams;
         sd_ctx_params_init(&ctxParams);
 
-        // Set model path
-        ctxParams.model_path = modelPath.c_str();
+        // Determine which model path to use based on detected architecture
+        bool useDiffusionModelPath = false;
+        if (detectionSuccessful) {
+            switch (detectionResult.architecture) {
+                case ModelArchitecture::FLUX_SCHNELL:
+                case ModelArchitecture::FLUX_DEV:
+                case ModelArchitecture::FLUX_CHROMA:
+                case ModelArchitecture::SD_3:
+                case ModelArchitecture::QWEN2VL:
+                    // Modern architectures use diffusion_model_path
+                    useDiffusionModelPath = true;
+                    break;
+                case ModelArchitecture::SD_1_5:
+                case ModelArchitecture::SD_2_1:
+                case ModelArchitecture::SDXL_BASE:
+                case ModelArchitecture::SDXL_REFINER:
+                    // Traditional SD models use model_path
+                    useDiffusionModelPath = false;
+                    break;
+                case ModelArchitecture::UNKNOWN:
+                default:
+                    // Unknown architectures fall back to model_path for backward compatibility
+                    useDiffusionModelPath = false;
+                    std::cout << "Warning: Unknown model architecture detected, using default model_path for backward compatibility" << std::endl;
+                    break;
+            }
+        }
+
+        // Set the appropriate model path based on architecture
+        if (useDiffusionModelPath) {
+            ctxParams.diffusion_model_path = modelPath.c_str();
+            ctxParams.model_path = nullptr; // Clear the traditional path
+            std::cout << "Using diffusion_model_path for modern architecture model" << std::endl;
+        } else {
+            ctxParams.model_path = modelPath.c_str();
+            ctxParams.diffusion_model_path = nullptr; // Clear the modern path
+            std::cout << "Using model_path for traditional architecture model" << std::endl;
+        }
 
         // Set optional model paths if provided
         if (!params.clipLPath.empty()) {
@@ -78,10 +129,83 @@ public:
         sdContext = new_sd_ctx(&ctxParams);
         if (!sdContext) {
             lastError = "Failed to create stable-diffusion context";
-            return false;
+
+            // If we used diffusion_model_path and it failed, try fallback to model_path
+            if (useDiffusionModelPath && detectionSuccessful) {
+                std::cout << "Warning: Failed to load with diffusion_model_path. Attempting fallback to model_path..." << std::endl;
+
+                // Re-initialize context parameters
+                sd_ctx_params_init(&ctxParams);
+
+                // Set fallback model path
+                ctxParams.model_path = modelPath.c_str();
+                ctxParams.diffusion_model_path = nullptr;
+
+                // Re-apply other parameters
+                if (!params.clipLPath.empty()) {
+                    ctxParams.clip_l_path = params.clipLPath.c_str();
+                }
+                if (!params.clipGPath.empty()) {
+                    ctxParams.clip_g_path = params.clipGPath.c_str();
+                }
+                if (!params.vaePath.empty()) {
+                    ctxParams.vae_path = params.vaePath.c_str();
+                }
+                if (!params.taesdPath.empty()) {
+                    ctxParams.taesd_path = params.taesdPath.c_str();
+                }
+                if (!params.controlNetPath.empty()) {
+                    ctxParams.control_net_path = params.controlNetPath.c_str();
+                }
+                if (!params.loraModelDir.empty()) {
+                    ctxParams.lora_model_dir = params.loraModelDir.c_str();
+                }
+                if (!params.embeddingDir.empty()) {
+                    ctxParams.embedding_dir = params.embeddingDir.c_str();
+                }
+
+                // Re-apply performance parameters
+                ctxParams.n_threads = params.nThreads;
+                ctxParams.offload_params_to_cpu = params.offloadParamsToCpu;
+                ctxParams.keep_clip_on_cpu = params.clipOnCpu;
+                ctxParams.keep_vae_on_cpu = params.vaeOnCpu;
+                ctxParams.diffusion_flash_attn = params.diffusionFlashAttn;
+                ctxParams.diffusion_conv_direct = params.diffusionConvDirect;
+                ctxParams.vae_conv_direct = params.vaeConvDirect;
+
+                // Re-apply model type
+                ctxParams.wtype = StableDiffusionWrapper::stringToModelType(params.modelType);
+
+                // Try creating context again with fallback
+                sdContext = new_sd_ctx(&ctxParams);
+                if (!sdContext) {
+                    lastError = "Failed to create stable-diffusion context with both diffusion_model_path and model_path fallback";
+                    std::cerr << "Error: " << lastError << std::endl;
+                    return false;
+                }
+
+                std::cout << "Successfully loaded model with fallback to model_path: " << modelPath << std::endl;
+            } else {
+                std::cerr << "Error: " << lastError << std::endl;
+                return false;
+            }
         }
 
+        // Log successful loading with architecture information
         std::cout << "Successfully loaded model: " << modelPath << std::endl;
+        if (detectionSuccessful) {
+            std::cout << "  Architecture: " << detectionResult.architectureName << std::endl;
+            std::cout << "  Loading method: " << (useDiffusionModelPath ? "diffusion_model_path" : "model_path") << std::endl;
+
+            // Log additional model properties if available
+            if (detectionResult.textEncoderDim > 0) {
+                std::cout << "  Text encoder dimension: " << detectionResult.textEncoderDim << std::endl;
+            }
+            if (detectionResult.needsVAE) {
+                std::cout << "  Requires VAE: " << (detectionResult.recommendedVAE.empty() ? "Yes" : detectionResult.recommendedVAE) << std::endl;
+            }
+        }
+
         return true;
     }