#include "model_detector.h"

#include <algorithm>
#include <cctype>
#include <cstdint>
#include <cstring>
#include <fstream>
#include <map>
#include <string>
#include <vector>

#include <nlohmann/json.hpp>

// Helper function for C++17 compatibility (ends_with is C++20)
static bool endsWith(const std::string& str, const std::string& suffix) {
    if (suffix.size() > str.size()) return false;
    return str.compare(str.size() - suffix.size(), suffix.size(), suffix) == 0;
}

// True when tensor names or metadata mark the checkpoint as an SDXL refiner.
// Used both on the conditioner-based SDXL path and on the 1280-dim fallback path.
static bool looksLikeSdxlRefiner(
    const std::map<std::string, std::vector<int64_t>>& tensorInfo,
    const std::map<std::string, std::string>& metadata,
    int maxUNetChannels) {
    for (const auto& [name, shape] : tensorInfo) {
        (void)shape;
        if (name.find("refiner") != std::string::npos) {
            return true;
        }
    }
    auto refinerIt = metadata.find("refiner");
    if (refinerIt != metadata.end() && refinerIt->second == "true") {
        return true;
    }
    // Refiner UNets have smaller channel counts than the base model.
    return maxUNetChannels > 0 && maxUNetChannels < 2400;
}

// Resolve which Flux variant a checkpoint with Flux block structure is:
// "chroma" in the filename wins, then a "4"-step hint selects Schnell, else Dev.
static ModelArchitecture resolveFluxVariant(
    const std::map<std::string, std::string>& metadata,
    const std::string& lowerFilename) {
    if (lowerFilename.find("chroma") != std::string::npos) {
        return ModelArchitecture::FLUX_CHROMA;
    }
    auto stepsIt = metadata.find("diffusion_steps");
    if (stepsIt != metadata.end() &&
        stepsIt->second.find("4") != std::string::npos) {
        return ModelArchitecture::FLUX_SCHNELL;
    }
    return ModelArchitecture::FLUX_DEV;
}

// Detect the architecture of a model file and fill in recommended
// companion models (VAE, text encoders) and generation parameters.
ModelDetectionResult ModelDetector::detectModel(const std::string& modelPath) {
    ModelDetectionResult result;
    std::map<std::string, std::vector<int64_t>> tensorInfo;

    // Determine file type and parse accordingly
    bool parsed = false;
    if (endsWith(modelPath, ".safetensors")) {
        parsed = parseSafetensorsHeader(modelPath, result.metadata, tensorInfo);
    } else if (endsWith(modelPath, ".gguf")) {
        parsed = parseGGUFHeader(modelPath, result.metadata, tensorInfo);
    } else if (endsWith(modelPath, ".ckpt") || endsWith(modelPath, ".pt")) {
        // PyTorch pickle files - these require the full PyTorch library to parse safely
        // For now, we cannot detect their architecture without loading the model
        // Return unknown architecture with a note in metadata
        result.metadata["format"] = "pytorch_pickle";
        result.metadata["note"] = "Architecture detection not supported for .ckpt/.pt files";
        return result;
    }

    if (!parsed) {
        return result;  // Unknown if we can't parse
    }

    // Store tensor names for reference
    for (const auto& [name, _] : tensorInfo) {
        result.tensorNames.push_back(name);
    }

    // Analyze architecture (pass filename for special detection)
    std::string filename = modelPath.substr(modelPath.find_last_of("/\\") + 1);
    result.architecture = analyzeArchitecture(tensorInfo, result.metadata, filename);
    result.architectureName = getArchitectureName(result.architecture);

    // Set architecture-specific properties and required models
    switch (result.architecture) {
        case ModelArchitecture::SD_1_5:
            result.textEncoderDim = 512;
            result.unetChannels = 512;
            result.needsVAE = false;
            result.recommendedVAE = "vae-ft-mse-840000-ema-pruned.safetensors";
            result.needsTAESD = true;
            result.suggestedParams["vae_flag"] = "--vae";
            break;
        case ModelArchitecture::SD_2_1:
            result.textEncoderDim = 1024;
            result.unetChannels = 1024;
            result.needsVAE = false;
            result.recommendedVAE = "vae-ft-ema-560000.safetensors";
            result.needsTAESD = true;
            result.suggestedParams["vae_flag"] = "--vae";
            break;
        case ModelArchitecture::SDXL_BASE:
        case ModelArchitecture::SDXL_REFINER:
            result.textEncoderDim = 1024;
            result.unetChannels = 1024;
            result.hasConditioner = true;
            result.needsVAE = false;
            result.recommendedVAE = "sdxl_vae.safetensors";
            result.needsTAESD = true;
            result.suggestedParams["vae_flag"] = "--vae";
            break;
        case ModelArchitecture::FLUX_SCHNELL:
        case ModelArchitecture::FLUX_DEV:
            result.textEncoderDim = 4096;
            result.needsVAE = true;
            result.recommendedVAE = "ae.safetensors";
            // Flux requires CLIP-L and T5XXL
            result.suggestedParams["vae_flag"] = "--vae";
            result.suggestedParams["clip_l_required"] = "clip_l.safetensors";
            result.suggestedParams["t5xxl_required"] = "t5xxl_fp16.safetensors";
            result.suggestedParams["clip_l_flag"] = "--clip-l";
            result.suggestedParams["t5xxl_flag"] = "--t5xxl";
            break;
        case ModelArchitecture::FLUX_CHROMA:
            result.textEncoderDim = 4096;
            result.needsVAE = true;
            result.recommendedVAE = "ae.safetensors";
            // Chroma (Flux Unlocked) requires VAE and T5XXL
            result.suggestedParams["vae_flag"] = "--vae";
            result.suggestedParams["t5xxl_required"] = "t5xxl_fp16.safetensors";
            result.suggestedParams["t5xxl_flag"] = "--t5xxl";
            break;
        case ModelArchitecture::SD_3:
            result.textEncoderDim = 4096;
            result.needsVAE = true;
            result.recommendedVAE = "sd3_vae.safetensors";
            // SD3 requires CLIP-L, CLIP-G, and T5XXL
            result.suggestedParams["vae_flag"] = "--vae";
            result.suggestedParams["clip_l_required"] = "clip_l.safetensors";
            result.suggestedParams["clip_g_required"] = "clip_g.safetensors";
            result.suggestedParams["t5xxl_required"] = "t5xxl_fp16.safetensors";
            result.suggestedParams["clip_l_flag"] = "--clip-l";
            result.suggestedParams["clip_g_flag"] = "--clip-g";
            result.suggestedParams["t5xxl_flag"] = "--t5xxl";
            break;
        case ModelArchitecture::QWEN2VL:
            // Qwen2-VL requires vision and language model components
            result.suggestedParams["qwen2vl_required"] = "qwen2vl.safetensors";
            result.suggestedParams["qwen2vl_vision_required"] = "qwen2vl_vision.safetensors";
            result.suggestedParams["qwen2vl_flag"] = "--qwen2vl";
            result.suggestedParams["qwen2vl_vision_flag"] = "--qwen2vl-vision";
            break;
        default:
            break;
    }

    // Merge with general recommended parameters (width, height, steps, etc.)
    auto generalParams = getRecommendedParams(result.architecture);
    for (const auto& [key, value] : generalParams) {
        // Only add if not already set (preserve architecture-specific flags)
        if (result.suggestedParams.find(key) == result.suggestedParams.end()) {
            result.suggestedParams[key] = value;
        }
    }

    return result;
}

// Parse a safetensors file header: an 8-byte little-endian length followed
// by a JSON object mapping tensor names to {dtype, shape, data_offsets},
// plus an optional "__metadata__" string map.
bool ModelDetector::parseSafetensorsHeader(
    const std::string& filePath,
    std::map<std::string, std::string>& metadata,
    std::map<std::string, std::vector<int64_t>>& tensorInfo) {
    std::ifstream file(filePath, std::ios::binary);
    if (!file.is_open()) {
        return false;
    }

    // Read header length (first 8 bytes, little-endian uint64)
    // NOTE(review): assumes a little-endian host - confirm if big-endian
    // targets are supported.
    uint64_t headerLength = 0;
    file.read(reinterpret_cast<char*>(&headerLength), 8);
    if (file.gcount() != 8) {
        return false;
    }

    // Sanity check: header should be reasonable size (< 100MB)
    if (headerLength == 0 || headerLength > 100 * 1024 * 1024) {
        return false;
    }

    // Read header JSON
    std::vector<char> headerBuffer(headerLength);
    file.read(headerBuffer.data(), headerLength);
    if (file.gcount() != static_cast<std::streamsize>(headerLength)) {
        return false;
    }

    // Parse JSON
    try {
        nlohmann::json headerJson =
            nlohmann::json::parse(headerBuffer.begin(), headerBuffer.end());

        // Extract metadata if present
        if (headerJson.contains("__metadata__")) {
            auto metadataJson = headerJson["__metadata__"];
            for (auto it = metadataJson.begin(); it != metadataJson.end(); ++it) {
                // Spec says values are strings, but be lenient: dump anything
                // else as JSON text instead of throwing and losing the header.
                if (it.value().is_string()) {
                    metadata[it.key()] = it.value().get<std::string>();
                } else {
                    metadata[it.key()] = it.value().dump();
                }
            }
        }

        // Extract tensor information
        for (auto it = headerJson.begin(); it != headerJson.end(); ++it) {
            if (it.key() == "__metadata__") continue;
            if (it.value().contains("shape")) {
                std::vector<int64_t> shape;
                for (const auto& dim : it.value()["shape"]) {
                    shape.push_back(dim.get<int64_t>());
                }
                tensorInfo[it.key()] = shape;
            }
        }
        return true;
    } catch (const std::exception& e) {
        return false;
    }
}

// Heuristically determine the model architecture from tensor names/shapes,
// header metadata, and the filename. Checks are ordered from most to least
// specific: explicit metadata, Flux/SD3 block structure, SDXL conditioner,
// Qwen patterns, then text-encoder dimensions and UNet channel counts.
ModelArchitecture ModelDetector::analyzeArchitecture(
    const std::map<std::string, std::vector<int64_t>>& tensorInfo,
    const std::map<std::string, std::string>& metadata,
    const std::string& filename) {
    // Check metadata first for explicit architecture hints
    auto modelTypeIt = metadata.find("modelspec.architecture");
    if (modelTypeIt != metadata.end()) {
        const std::string& archName = modelTypeIt->second;
        if (archName.find("stable-diffusion-xl") != std::string::npos) {
            return ModelArchitecture::SDXL_BASE;
        } else if (archName.find("stable-diffusion-v2") != std::string::npos) {
            return ModelArchitecture::SD_2_1;
        } else if (archName.find("stable-diffusion-v1") != std::string::npos) {
            return ModelArchitecture::SD_1_5;
        }
    }

    // Check filename for special variants (unsigned char cast avoids UB on
    // negative char values)
    std::string lowerFilename = filename;
    std::transform(lowerFilename.begin(), lowerFilename.end(), lowerFilename.begin(),
                   [](unsigned char c) { return static_cast<char>(std::tolower(c)); });

    // Analyze tensor structure for architecture detection
    bool hasConditioner = false;
    bool hasTextEncoder2 = false;
    bool hasFluxStructure = false;
    bool hasSD3Structure = false;
    int maxUNetChannels = 0;
    int textEncoderOutputDim = 0;

    for (const auto& [name, shape] : tensorInfo) {
        // Check for SDXL-specific components
        if (name.find("conditioner") != std::string::npos) {
            hasConditioner = true;
        }
        if (name.find("text_encoder_2") != std::string::npos ||
            name.find("cond_stage_model.1") != std::string::npos) {
            hasTextEncoder2 = true;
        }
        // Check for Flux-specific patterns
        if (name.find("double_blocks") != std::string::npos ||
            name.find("single_blocks") != std::string::npos) {
            hasFluxStructure = true;
        }
        // Check for SD3-specific patterns
        if (name.find("joint_blocks") != std::string::npos) {
            hasSD3Structure = true;
        }
        // Analyze UNet structure
        if (name.find("model.diffusion_model") != std::string::npos ||
            name.find("unet") != std::string::npos) {
            if (shape.size() >= 2) {
                maxUNetChannels = std::max(maxUNetChannels, static_cast<int>(shape[0]));
            }
        }
        // Check text encoder dimensions
        if (name.find("cond_stage_model") != std::string::npos ||
            name.find("text_encoder") != std::string::npos) {
            if (name.find("proj") != std::string::npos && shape.size() >= 2) {
                textEncoderOutputDim =
                    std::max(textEncoderOutputDim, static_cast<int>(shape[1]));
            }
        }
    }

    // Determine architecture based on analysis with improved priority
    if (hasFluxStructure) {
        return resolveFluxVariant(metadata, lowerFilename);
    }

    if (hasSD3Structure) {
        return ModelArchitecture::SD_3;
    }

    if (hasConditioner || hasTextEncoder2) {
        // SDXL architecture - distinguish refiner from base
        if (looksLikeSdxlRefiner(tensorInfo, metadata, maxUNetChannels)) {
            return ModelArchitecture::SDXL_REFINER;
        }
        return ModelArchitecture::SDXL_BASE;
    }

    // Check for Qwen2-VL specific patterns before falling back to
    // dimension-based detection
    bool hasQwenPatterns = false;

    // Check metadata for Qwen pipeline class
    auto pipelineIt = metadata.find("_model_name");
    if (pipelineIt != metadata.end() &&
        pipelineIt->second.find("QwenImagePipeline") != std::string::npos) {
        hasQwenPatterns = true;
    }

    // Check for Qwen-specific tensor patterns
    bool hasTransformerBlocks = false;
    bool hasImgMod = false;
    bool hasTxtMod = false;
    bool hasImgIn = false;
    bool hasTxtIn = false;
    bool hasProjOut = false;
    bool hasVisualBlocks = false;

    for (const auto& [name, shape] : tensorInfo) {
        (void)shape;
        if (name.find("transformer_blocks") != std::string::npos) {
            hasTransformerBlocks = true;
        }
        // Modulation patterns
        if (name.find("img_mod") != std::string::npos) {
            hasImgMod = true;
        }
        if (name.find("txt_mod") != std::string::npos) {
            hasTxtMod = true;
        }
        // Input patterns
        if (name.find("img_in") != std::string::npos) {
            hasImgIn = true;
        }
        if (name.find("txt_in") != std::string::npos) {
            hasTxtIn = true;
        }
        // Output projection
        if (name.find("proj_out") != std::string::npos) {
            hasProjOut = true;
        }
        // Visual blocks (Qwen2-VL structure)
        if (name.find("visual.blocks") != std::string::npos) {
            hasVisualBlocks = true;
        }
    }

    // Determine if this is a Qwen model based on multiple patterns
    if (hasTransformerBlocks && (hasImgMod || hasTxtMod) &&
        (hasImgIn || hasTxtIn) && hasProjOut) {
        hasQwenPatterns = true;
    }
    // Additional check for visual blocks pattern
    if (hasVisualBlocks && (hasImgMod || hasTxtMod)) {
        hasQwenPatterns = true;
    }

    if (hasQwenPatterns) {
        return ModelArchitecture::QWEN2VL;
    }

    // Check text encoder dimensions with enhanced logic for 1280 dimension.
    // (Flux was already resolved above, so no re-check is needed here.)
    if (textEncoderOutputDim == 768) {
        return ModelArchitecture::SD_1_5;
    }
    if (textEncoderOutputDim >= 1024 && textEncoderOutputDim < 1280) {
        return ModelArchitecture::SD_2_1;
    }
    if (textEncoderOutputDim == 1280) {
        // 1280 can be SDXL Base or SDXL Refiner - use the refiner heuristics
        if (looksLikeSdxlRefiner(tensorInfo, metadata, maxUNetChannels)) {
            return ModelArchitecture::SDXL_REFINER;
        }
        // Default to SDXL Base for 1280 dimension
        return ModelArchitecture::SDXL_BASE;
    }

    // Only use UNet channel count as a last resort when text encoder
    // dimensions are unclear
    if (maxUNetChannels >= 2048) {
        return ModelArchitecture::SDXL_BASE;
    }
    // Fallback detection based on UNet channels when text encoder info is
    // unavailable (maxUNetChannels == 0 also lands on SD 1.5 here)
    if (maxUNetChannels == 1280) {
        return ModelArchitecture::SD_2_1;
    }
    if (maxUNetChannels <= 1280) {
        return ModelArchitecture::SD_1_5;
    }

    return ModelArchitecture::UNKNOWN;
}

// Human-readable name for an architecture enum value.
std::string ModelDetector::getArchitectureName(ModelArchitecture arch) {
    switch (arch) {
        case ModelArchitecture::SD_1_5:       return "Stable Diffusion 1.5";
        case ModelArchitecture::SD_2_1:       return "Stable Diffusion 2.1";
        case ModelArchitecture::SDXL_BASE:    return "Stable Diffusion XL Base";
        case ModelArchitecture::SDXL_REFINER: return "Stable Diffusion XL Refiner";
        case ModelArchitecture::FLUX_SCHNELL: return "Flux Schnell";
        case ModelArchitecture::FLUX_DEV:     return "Flux Dev";
        case ModelArchitecture::FLUX_CHROMA:  return "Flux Chroma (Unlocked)";
        case ModelArchitecture::SD_3:         return "Stable Diffusion 3";
        case ModelArchitecture::QWEN2VL:      return "Qwen2-VL";
        default:                              return "Unknown";
    }
}

// Default generation parameters (resolution, CFG, steps, sampler) per
// architecture; merged into suggestedParams by detectModel.
std::map<std::string, std::string> ModelDetector::getRecommendedParams(ModelArchitecture arch) {
    std::map<std::string, std::string> params;
    switch (arch) {
        case ModelArchitecture::SD_1_5:
            params["width"] = "512";
            params["height"] = "512";
            params["cfg_scale"] = "7.5";
            params["steps"] = "20";
            params["sampler"] = "euler_a";
            break;
        case ModelArchitecture::SD_2_1:
            params["width"] = "768";
            params["height"] = "768";
            params["cfg_scale"] = "7.0";
            params["steps"] = "25";
            params["sampler"] = "euler_a";
            break;
        case ModelArchitecture::SDXL_BASE:
        case ModelArchitecture::SDXL_REFINER:
            params["width"] = "1024";
            params["height"] = "1024";
            params["cfg_scale"] = "7.0";
            params["steps"] = "30";
            params["sampler"] = "dpm++2m";
            break;
        case ModelArchitecture::FLUX_SCHNELL:
            params["width"] = "1024";
            params["height"] = "1024";
            params["cfg_scale"] = "1.0";
            params["steps"] = "4";
            params["sampler"] = "euler";
            break;
        case ModelArchitecture::FLUX_DEV:
            params["width"] = "1024";
            params["height"] = "1024";
            params["cfg_scale"] = "1.0";
            params["steps"] = "20";
            params["sampler"] = "euler";
            break;
        case ModelArchitecture::FLUX_CHROMA:
            params["width"] = "1024";
            params["height"] = "1024";
            params["cfg_scale"] = "1.0";
            params["steps"] = "20";
            params["sampler"] = "euler";
            break;
        case ModelArchitecture::SD_3:
            params["width"] = "1024";
            params["height"] = "1024";
            params["cfg_scale"] = "5.0";
            params["steps"] = "28";
            params["sampler"] = "dpm++2m";
            break;
        default:
            break;
    }
    return params;
}

// Parse a GGUF file header: "GGUF" magic, version, tensor count, metadata
// KV count, then the metadata pairs and tensor descriptors. Only tensor
// names/shapes and stringifiable metadata are kept; everything else is
// skipped by its exact on-disk width so the stream stays aligned.
bool ModelDetector::parseGGUFHeader(
    const std::string& filePath,
    std::map<std::string, std::string>& metadata,
    std::map<std::string, std::vector<int64_t>>& tensorInfo) {
    std::ifstream file(filePath, std::ios::binary);
    if (!file.is_open()) {
        return false;
    }

    // Read and verify magic number "GGUF"
    char magic[4];
    file.read(magic, 4);
    if (file.gcount() != 4 || std::memcmp(magic, "GGUF", 4) != 0) {
        return false;
    }

    // Read version (uint32)
    uint32_t version;
    file.read(reinterpret_cast<char*>(&version), 4);
    if (file.gcount() != 4) {
        return false;
    }

    // Read tensor count (uint64)
    uint64_t tensorCount;
    file.read(reinterpret_cast<char*>(&tensorCount), 8);
    if (file.gcount() != 8) {
        return false;
    }

    // Read metadata KV count (uint64)
    uint64_t metadataCount;
    file.read(reinterpret_cast<char*>(&metadataCount), 8);
    if (file.gcount() != 8) {
        return false;
    }

    // Helper to read a length-prefixed string (uint64 length + bytes).
    // Empty result doubles as the error signal for callers.
    auto readString = [&file]() -> std::string {
        uint64_t length;
        file.read(reinterpret_cast<char*>(&length), 8);
        if (file.gcount() != 8 || length == 0 || length > 10000) {
            return "";
        }
        std::vector<char> buffer(length);
        file.read(buffer.data(), length);
        if (file.gcount() != static_cast<std::streamsize>(length)) {
            return "";
        }
        return std::string(buffer.begin(), buffer.end());
    };

    // On-disk size of a fixed-width GGUF scalar value type; 0 means the
    // type is variable-length (string/array) or unknown.
    // GGUF value types: 0=u8 1=i8 2=u16 3=i16 4=u32 5=i32 6=f32 7=bool
    //                   8=string 9=array 10=u64 11=i64 12=f64
    auto scalarSize = [](uint32_t type) -> int {
        switch (type) {
            case 0: case 1: case 7:  return 1;
            case 2: case 3:          return 2;
            case 4: case 5: case 6:  return 4;
            case 10: case 11: case 12: return 8;
            default:                 return 0;
        }
    };

    // Read metadata key-value pairs
    for (uint64_t i = 0; i < metadataCount && file.good(); ++i) {
        std::string key = readString();
        if (key.empty()) break;

        // Read value type (uint32)
        uint32_t valueType;
        file.read(reinterpret_cast<char*>(&valueType), 4);
        if (file.gcount() != 4) break;

        // Parse value based on type
        std::string value;
        switch (valueType) {
            case 8:  // String
                value = readString();
                break;
            case 4: {  // Uint32
                uint32_t val;
                file.read(reinterpret_cast<char*>(&val), 4);
                value = std::to_string(val);
                break;
            }
            case 5: {  // Int32
                int32_t val;
                file.read(reinterpret_cast<char*>(&val), 4);
                value = std::to_string(val);
                break;
            }
            case 6: {  // Float32
                float val;
                file.read(reinterpret_cast<char*>(&val), 4);
                value = std::to_string(val);
                break;
            }
            case 0: {  // Uint8
                uint8_t val;
                file.read(reinterpret_cast<char*>(&val), 1);
                value = std::to_string(val);
                break;
            }
            case 1: {  // Int8
                int8_t val;
                file.read(reinterpret_cast<char*>(&val), 1);
                value = std::to_string(val);
                break;
            }
            case 9: {  // Array: elem type (uint32) + count (uint64) + elements
                uint32_t elemType;
                uint64_t count;
                file.read(reinterpret_cast<char*>(&elemType), 4);
                file.read(reinterpret_cast<char*>(&count), 8);
                if (!file.good()) return false;
                if (elemType == 8) {
                    // Array of strings: consume each one
                    for (uint64_t j = 0; j < count && file.good(); ++j) {
                        readString();
                    }
                } else {
                    int sz = scalarSize(elemType);
                    // Nested arrays / unknown element types cannot be skipped
                    // reliably; abort rather than mis-parse the rest.
                    if (sz == 0) return false;
                    file.seekg(static_cast<std::streamoff>(sz) *
                                   static_cast<std::streamoff>(count),
                               std::ios::cur);
                }
                continue;
            }
            default: {
                // Fixed-width scalar we don't stringify: skip exactly its
                // width (the old blanket 8-byte skip desynced the stream for
                // 1/2-byte types).
                int sz = scalarSize(valueType);
                if (sz == 0) return false;  // unknown type: offsets would be wrong
                file.seekg(sz, std::ios::cur);
                continue;
            }
        }
        if (!value.empty()) {
            metadata[key] = value;
        }
    }

    // Read tensor information
    for (uint64_t i = 0; i < tensorCount && file.good(); ++i) {
        std::string tensorName = readString();
        if (tensorName.empty()) break;

        // Read number of dimensions (uint32); > 10 dims means a corrupt file
        uint32_t nDims;
        file.read(reinterpret_cast<char*>(&nDims), 4);
        if (file.gcount() != 4 || nDims > 10) break;

        // Read dimensions (uint64 array)
        std::vector<int64_t> shape(nDims);
        bool shapeOk = true;
        for (uint32_t d = 0; d < nDims; ++d) {
            uint64_t dim;
            file.read(reinterpret_cast<char*>(&dim), 8);
            if (file.gcount() != 8) {
                shapeOk = false;
                break;
            }
            shape[d] = static_cast<int64_t>(dim);
        }
        // A short read used to fall through and record a bogus shape;
        // abort the tensor loop instead.
        if (!shapeOk) break;

        // Skip type (uint32) and offset (uint64)
        file.seekg(12, std::ios::cur);

        tensorInfo[tensorName] = shape;
    }

    return !tensorInfo.empty();
}