Răsfoiți Sursa

Fix Qwen detection

Fszontagh 3 luni în urmă
părinte
comite
ffae92d32e
1 a modificat fișierele cu 65 adăugiri și 0 ștergeri
  1. 65 0
      src/model_detector.cpp

+ 65 - 0
src/model_detector.cpp

@@ -295,6 +295,71 @@ ModelArchitecture ModelDetector::analyzeArchitecture(
         return hasRefinerMarkers ? ModelArchitecture::SDXL_REFINER : ModelArchitecture::SDXL_BASE;
     }
 
+    // Check for Qwen2-VL specific patterns before falling back to dimension-based detection
+    bool hasQwenPatterns = false;
+    
+    // Check metadata for Qwen pipeline class
+    auto pipelineIt = metadata.find("_model_name");
+    if (pipelineIt != metadata.end() && pipelineIt->second.find("QwenImagePipeline") != std::string::npos) {
+        hasQwenPatterns = true;
+    }
+    
+    // Check for Qwen-specific tensor patterns
+    bool hasTransformerBlocks = false;
+    bool hasImgMod = false;
+    bool hasTxtMod = false;
+    bool hasImgIn = false;
+    bool hasTxtIn = false;
+    bool hasProjOut = false;
+    bool hasVisualBlocks = false;
+    
+    for (const auto& [name, shape] : tensorInfo) {
+        // Check for transformer blocks
+        if (name.find("transformer_blocks") != std::string::npos) {
+            hasTransformerBlocks = true;
+        }
+        
+        // Check for modulation patterns
+        if (name.find("img_mod") != std::string::npos) {
+            hasImgMod = true;
+        }
+        if (name.find("txt_mod") != std::string::npos) {
+            hasTxtMod = true;
+        }
+        
+        // Check for input patterns
+        if (name.find("img_in") != std::string::npos) {
+            hasImgIn = true;
+        }
+        if (name.find("txt_in") != std::string::npos) {
+            hasTxtIn = true;
+        }
+        
+        // Check for output projection
+        if (name.find("proj_out") != std::string::npos) {
+            hasProjOut = true;
+        }
+        
+        // Check for visual blocks (Qwen2-VL structure)
+        if (name.find("visual.blocks") != std::string::npos) {
+            hasVisualBlocks = true;
+        }
+    }
+    
+    // Determine if this is a Qwen model based on multiple patterns
+    if (hasTransformerBlocks && (hasImgMod || hasTxtMod) && (hasImgIn || hasTxtIn) && hasProjOut) {
+        hasQwenPatterns = true;
+    }
+    
+    // Additional check for visual blocks pattern
+    if (hasVisualBlocks && (hasImgMod || hasTxtMod)) {
+        hasQwenPatterns = true;
+    }
+    
+    if (hasQwenPatterns) {
+        return ModelArchitecture::QWEN2VL;
+    }
+
     if (maxUNetChannels >= 2048) {
         return ModelArchitecture::SDXL_BASE;
     }