|
@@ -295,6 +295,71 @@ ModelArchitecture ModelDetector::analyzeArchitecture(
|
|
|
return hasRefinerMarkers ? ModelArchitecture::SDXL_REFINER : ModelArchitecture::SDXL_BASE;
|
|
return hasRefinerMarkers ? ModelArchitecture::SDXL_REFINER : ModelArchitecture::SDXL_BASE;
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
|
|
+ // Check for Qwen2-VL specific patterns before falling back to dimension-based detection
|
|
|
|
|
+ bool hasQwenPatterns = false;
|
|
|
|
|
+
|
|
|
|
|
+ // Check metadata for Qwen pipeline class
|
|
|
|
|
+ auto pipelineIt = metadata.find("_model_name");
|
|
|
|
|
+ if (pipelineIt != metadata.end() && pipelineIt->second.find("QwenImagePipeline") != std::string::npos) {
|
|
|
|
|
+ hasQwenPatterns = true;
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ // Check for Qwen-specific tensor patterns
|
|
|
|
|
+ bool hasTransformerBlocks = false;
|
|
|
|
|
+ bool hasImgMod = false;
|
|
|
|
|
+ bool hasTxtMod = false;
|
|
|
|
|
+ bool hasImgIn = false;
|
|
|
|
|
+ bool hasTxtIn = false;
|
|
|
|
|
+ bool hasProjOut = false;
|
|
|
|
|
+ bool hasVisualBlocks = false;
|
|
|
|
|
+
|
|
|
|
|
+ for (const auto& [name, shape] : tensorInfo) {
|
|
|
|
|
+ // Check for transformer blocks
|
|
|
|
|
+ if (name.find("transformer_blocks") != std::string::npos) {
|
|
|
|
|
+ hasTransformerBlocks = true;
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ // Check for modulation patterns
|
|
|
|
|
+ if (name.find("img_mod") != std::string::npos) {
|
|
|
|
|
+ hasImgMod = true;
|
|
|
|
|
+ }
|
|
|
|
|
+ if (name.find("txt_mod") != std::string::npos) {
|
|
|
|
|
+ hasTxtMod = true;
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ // Check for input patterns
|
|
|
|
|
+ if (name.find("img_in") != std::string::npos) {
|
|
|
|
|
+ hasImgIn = true;
|
|
|
|
|
+ }
|
|
|
|
|
+ if (name.find("txt_in") != std::string::npos) {
|
|
|
|
|
+ hasTxtIn = true;
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ // Check for output projection
|
|
|
|
|
+ if (name.find("proj_out") != std::string::npos) {
|
|
|
|
|
+ hasProjOut = true;
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ // Check for visual blocks (Qwen2-VL structure)
|
|
|
|
|
+ if (name.find("visual.blocks") != std::string::npos) {
|
|
|
|
|
+ hasVisualBlocks = true;
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ // Determine if this is a Qwen model based on multiple patterns
|
|
|
|
|
+ if (hasTransformerBlocks && (hasImgMod || hasTxtMod) && (hasImgIn || hasTxtIn) && hasProjOut) {
|
|
|
|
|
+ hasQwenPatterns = true;
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ // Additional check for visual blocks pattern
|
|
|
|
|
+ if (hasVisualBlocks && (hasImgMod || hasTxtMod)) {
|
|
|
|
|
+ hasQwenPatterns = true;
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ if (hasQwenPatterns) {
|
|
|
|
|
+ return ModelArchitecture::QWEN2VL;
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
if (maxUNetChannels >= 2048) {
|
|
if (maxUNetChannels >= 2048) {
|
|
|
return ModelArchitecture::SDXL_BASE;
|
|
return ModelArchitecture::SDXL_BASE;
|
|
|
}
|
|
}
|