Bladeren bron

Fix model architecture detection regressions for SDXL Refiner and Flux models

- Enhanced SDXL Refiner detection when text encoder dimension is 1280
- Improved Flux model detection to prioritize Flux-specific patterns over SDXL
- Refined detection priority order: Flux/SD3/Qwen patterns first, then SDXL, then text encoder dimensions
- Maintained original fix for SD1.5 models with high UNet channel counts
- Added additional logic to distinguish between SDXL Base, SDXL Refiner, and Flux when text encoder dimension is 1280

Resolves regressions where SDXL Refiner models were detected as SDXL Base
and some Flux Dev models were detected as SDXL Base despite having 1280-dimensional text encoders.
Fszontagh 3 maanden geleden
bovenliggende
commit
8f8907cf2b
1 gewijzigde bestanden met toevoegingen van 76 en 4 verwijderingen
  1. 76 4
      src/model_detector.cpp

+ 76 - 4
src/model_detector.cpp

@@ -264,7 +264,7 @@ ModelArchitecture ModelDetector::analyzeArchitecture(
         }
     }
 
-    // Determine architecture based on analysis
+    // Determine architecture based on analysis with improved priority
     if (hasFluxStructure) {
         // Check for Chroma variant (unlocked Flux)
         if (lowerFilename.find("chroma") != std::string::npos) {
@@ -284,15 +284,34 @@ ModelArchitecture ModelDetector::analyzeArchitecture(
     }
 
     if (hasConditioner || hasTextEncoder2) {
-        // SDXL architecture
+        // SDXL architecture - check for refiner using multiple criteria
         bool hasRefinerMarkers = false;
+        bool hasSmallUNet = false;
+        
+        // Check for refiner markers in tensor names
         for (const auto& [name, _] : tensorInfo) {
             if (name.find("refiner") != std::string::npos) {
                 hasRefinerMarkers = true;
                 break;
             }
         }
-        return hasRefinerMarkers ? ModelArchitecture::SDXL_REFINER : ModelArchitecture::SDXL_BASE;
+        
+        // Check for smaller UNet channel counts (typical of refiner models)
+        if (maxUNetChannels > 0 && maxUNetChannels < 2400) {
+            hasSmallUNet = true;
+        }
+        
+        // Additional check: look for refiner-specific metadata
+        auto refinerIt = metadata.find("refiner");
+        if (refinerIt != metadata.end() && refinerIt->second == "true") {
+            hasRefinerMarkers = true;
+        }
+        
+        // Return refiner if either marker is found, otherwise base
+        if (hasRefinerMarkers || hasSmallUNet) {
+            return ModelArchitecture::SDXL_REFINER;
+        }
+        return ModelArchitecture::SDXL_BASE;
     }
 
     // Check for Qwen2-VL specific patterns before falling back to dimension-based detection
@@ -360,7 +379,22 @@ ModelArchitecture ModelDetector::analyzeArchitecture(
         return ModelArchitecture::QWEN2VL;
     }
 
-    // Check text encoder dimensions first (more reliable than UNet channel count)
+    // Improved detection priority order
+    
+    // First, check for Flux-specific patterns even if text encoder dimension is 1280
+    if (hasFluxStructure) {
+        // This should have been caught earlier, but double-check for edge cases
+        if (lowerFilename.find("chroma") != std::string::npos) {
+            return ModelArchitecture::FLUX_CHROMA;
+        }
+        auto stepsIt = metadata.find("diffusion_steps");
+        if (stepsIt != metadata.end() && stepsIt->second.find("4") != std::string::npos) {
+            return ModelArchitecture::FLUX_SCHNELL;
+        }
+        return ModelArchitecture::FLUX_DEV;
+    }
+    
+    // Check text encoder dimensions with enhanced logic for 1280 dimension
     if (textEncoderOutputDim == 768) {
         return ModelArchitecture::SD_1_5;
     }
@@ -370,6 +404,44 @@ ModelArchitecture ModelDetector::analyzeArchitecture(
     }
     
     if (textEncoderOutputDim == 1280) {
+        // Enhanced 1280 dimension detection: distinguish between SDXL Base, SDXL Refiner, and Flux
+        // Check if we already determined this is Flux (should have been caught earlier)
+        if (hasFluxStructure) {
+            if (lowerFilename.find("chroma") != std::string::npos) {
+                return ModelArchitecture::FLUX_CHROMA;
+            }
+            auto stepsIt = metadata.find("diffusion_steps");
+            if (stepsIt != metadata.end() && stepsIt->second.find("4") != std::string::npos) {
+                return ModelArchitecture::FLUX_SCHNELL;
+            }
+            return ModelArchitecture::FLUX_DEV;
+        }
+        
+        // Check for SDXL Refiner indicators
+        bool hasRefinerMarkers = false;
+        bool hasSmallUNet = false;
+        
+        for (const auto& [name, _] : tensorInfo) {
+            if (name.find("refiner") != std::string::npos) {
+                hasRefinerMarkers = true;
+                break;
+            }
+        }
+        
+        if (maxUNetChannels > 0 && maxUNetChannels < 2400) {
+            hasSmallUNet = true;
+        }
+        
+        auto refinerIt = metadata.find("refiner");
+        if (refinerIt != metadata.end() && refinerIt->second == "true") {
+            hasRefinerMarkers = true;
+        }
+        
+        if (hasRefinerMarkers || hasSmallUNet) {
+            return ModelArchitecture::SDXL_REFINER;
+        }
+        
+        // Default to SDXL Base for 1280 dimension
         return ModelArchitecture::SDXL_BASE;
     }