|
@@ -264,7 +264,7 @@ ModelArchitecture ModelDetector::analyzeArchitecture(
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
- // Determine architecture based on analysis
|
|
|
|
|
|
|
+ // Determine architecture based on analysis with improved priority
|
|
|
if (hasFluxStructure) {
|
|
if (hasFluxStructure) {
|
|
|
// Check for Chroma variant (unlocked Flux)
|
|
// Check for Chroma variant (unlocked Flux)
|
|
|
if (lowerFilename.find("chroma") != std::string::npos) {
|
|
if (lowerFilename.find("chroma") != std::string::npos) {
|
|
@@ -284,15 +284,34 @@ ModelArchitecture ModelDetector::analyzeArchitecture(
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
if (hasConditioner || hasTextEncoder2) {
|
|
if (hasConditioner || hasTextEncoder2) {
|
|
|
- // SDXL architecture
|
|
|
|
|
|
|
+ // SDXL architecture - check for refiner using multiple criteria
|
|
|
bool hasRefinerMarkers = false;
|
|
bool hasRefinerMarkers = false;
|
|
|
|
|
+ bool hasSmallUNet = false;
|
|
|
|
|
+
|
|
|
|
|
+ // Check for refiner markers in tensor names
|
|
|
for (const auto& [name, _] : tensorInfo) {
|
|
for (const auto& [name, _] : tensorInfo) {
|
|
|
if (name.find("refiner") != std::string::npos) {
|
|
if (name.find("refiner") != std::string::npos) {
|
|
|
hasRefinerMarkers = true;
|
|
hasRefinerMarkers = true;
|
|
|
break;
|
|
break;
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
- return hasRefinerMarkers ? ModelArchitecture::SDXL_REFINER : ModelArchitecture::SDXL_BASE;
|
|
|
|
|
|
|
+
|
|
|
|
|
+ // Check for smaller UNet channel counts (typical of refiner models)
|
|
|
|
|
+ if (maxUNetChannels > 0 && maxUNetChannels < 2400) {
|
|
|
|
|
+ hasSmallUNet = true;
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ // Additional check: look for refiner-specific metadata
|
|
|
|
|
+ auto refinerIt = metadata.find("refiner");
|
|
|
|
|
+ if (refinerIt != metadata.end() && refinerIt->second == "true") {
|
|
|
|
|
+ hasRefinerMarkers = true;
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ // Return refiner if either marker is found, otherwise base
|
|
|
|
|
+ if (hasRefinerMarkers || hasSmallUNet) {
|
|
|
|
|
+ return ModelArchitecture::SDXL_REFINER;
|
|
|
|
|
+ }
|
|
|
|
|
+ return ModelArchitecture::SDXL_BASE;
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
// Check for Qwen2-VL specific patterns before falling back to dimension-based detection
|
|
// Check for Qwen2-VL specific patterns before falling back to dimension-based detection
|
|
@@ -360,7 +379,22 @@ ModelArchitecture ModelDetector::analyzeArchitecture(
|
|
|
return ModelArchitecture::QWEN2VL;
|
|
return ModelArchitecture::QWEN2VL;
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
- // Check text encoder dimensions first (more reliable than UNet channel count)
|
|
|
|
|
|
|
+ // Improved detection priority order
|
|
|
|
|
+
|
|
|
|
|
+ // First, check for Flux-specific patterns even if text encoder dimension is 1280
|
|
|
|
|
+ if (hasFluxStructure) {
|
|
|
|
|
+ // This should have been caught earlier, but double-check for edge cases
|
|
|
|
|
+ if (lowerFilename.find("chroma") != std::string::npos) {
|
|
|
|
|
+ return ModelArchitecture::FLUX_CHROMA;
|
|
|
|
|
+ }
|
|
|
|
|
+ auto stepsIt = metadata.find("diffusion_steps");
|
|
|
|
|
+ if (stepsIt != metadata.end() && stepsIt->second.find("4") != std::string::npos) {
|
|
|
|
|
+ return ModelArchitecture::FLUX_SCHNELL;
|
|
|
|
|
+ }
|
|
|
|
|
+ return ModelArchitecture::FLUX_DEV;
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ // Check text encoder dimensions with enhanced logic for 1280 dimension
|
|
|
if (textEncoderOutputDim == 768) {
|
|
if (textEncoderOutputDim == 768) {
|
|
|
return ModelArchitecture::SD_1_5;
|
|
return ModelArchitecture::SD_1_5;
|
|
|
}
|
|
}
|
|
@@ -370,6 +404,44 @@ ModelArchitecture ModelDetector::analyzeArchitecture(
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
if (textEncoderOutputDim == 1280) {
|
|
if (textEncoderOutputDim == 1280) {
|
|
|
|
|
+ // Enhanced 1280 dimension detection: distinguish between SDXL Base, SDXL Refiner, and Flux
|
|
|
|
|
+ // Check if we already determined this is Flux (should have been caught earlier)
|
|
|
|
|
+ if (hasFluxStructure) {
|
|
|
|
|
+ if (lowerFilename.find("chroma") != std::string::npos) {
|
|
|
|
|
+ return ModelArchitecture::FLUX_CHROMA;
|
|
|
|
|
+ }
|
|
|
|
|
+ auto stepsIt = metadata.find("diffusion_steps");
|
|
|
|
|
+ if (stepsIt != metadata.end() && stepsIt->second.find("4") != std::string::npos) {
|
|
|
|
|
+ return ModelArchitecture::FLUX_SCHNELL;
|
|
|
|
|
+ }
|
|
|
|
|
+ return ModelArchitecture::FLUX_DEV;
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ // Check for SDXL Refiner indicators
|
|
|
|
|
+ bool hasRefinerMarkers = false;
|
|
|
|
|
+ bool hasSmallUNet = false;
|
|
|
|
|
+
|
|
|
|
|
+ for (const auto& [name, _] : tensorInfo) {
|
|
|
|
|
+ if (name.find("refiner") != std::string::npos) {
|
|
|
|
|
+ hasRefinerMarkers = true;
|
|
|
|
|
+ break;
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ if (maxUNetChannels > 0 && maxUNetChannels < 2400) {
|
|
|
|
|
+ hasSmallUNet = true;
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ auto refinerIt = metadata.find("refiner");
|
|
|
|
|
+ if (refinerIt != metadata.end() && refinerIt->second == "true") {
|
|
|
|
|
+ hasRefinerMarkers = true;
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ if (hasRefinerMarkers || hasSmallUNet) {
|
|
|
|
|
+ return ModelArchitecture::SDXL_REFINER;
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ // Default to SDXL Base for 1280 dimension
|
|
|
return ModelArchitecture::SDXL_BASE;
|
|
return ModelArchitecture::SDXL_BASE;
|
|
|
}
|
|
}
|
|
|
|
|
|