#include "model_detector.h"

#include <algorithm>
#include <cctype>
#include <cstdint>
#include <cstring>
#include <fstream>

#include <nlohmann/json.hpp>

// Helper function for C++17 compatibility (std::string::ends_with is C++20).
static bool endsWith(const std::string& str, const std::string& suffix) {
    if (suffix.size() > str.size()) return false;
    return str.compare(str.size() - suffix.size(), suffix.size(), suffix) == 0;
}

// Detect the architecture of a model file from its header alone (no weights
// are loaded).  Supports .safetensors and .gguf; .ckpt/.pt pickle files are
// flagged as unsupported.  Returns a result populated with the detected
// architecture, tensor names, and recommended/required companion models.
ModelDetectionResult ModelDetector::detectModel(const std::string& modelPath) {
    ModelDetectionResult result;
    std::map<std::string, std::vector<int64_t>> tensorInfo;

    // Determine file type and parse accordingly.
    bool parsed = false;
    if (endsWith(modelPath, ".safetensors")) {
        parsed = parseSafetensorsHeader(modelPath, result.metadata, tensorInfo);
    } else if (endsWith(modelPath, ".gguf")) {
        parsed = parseGGUFHeader(modelPath, result.metadata, tensorInfo);
    } else if (endsWith(modelPath, ".ckpt") || endsWith(modelPath, ".pt")) {
        // PyTorch pickle files - these require the full PyTorch library to parse safely.
        // For now, we cannot detect their architecture without loading the model.
        // Return unknown architecture with a note in metadata.
        result.metadata["format"] = "pytorch_pickle";
        result.metadata["note"] = "Architecture detection not supported for .ckpt/.pt files";
        return result;
    }

    if (!parsed) {
        return result; // Unknown if we can't parse.
    }

    // Store tensor names for reference.
    for (const auto& [name, _] : tensorInfo) {
        result.tensorNames.push_back(name);
    }

    // Analyze architecture (pass filename for special detection).
    std::string filename = modelPath.substr(modelPath.find_last_of("/\\") + 1);
    result.architecture = analyzeArchitecture(tensorInfo, result.metadata, filename);
    result.architectureName = getArchitectureName(result.architecture);

    // Set architecture-specific properties and required companion models.
    switch (result.architecture) {
        case ModelArchitecture::SD_1_5:
            result.textEncoderDim = 768;
            result.unetChannels = 1280;
            result.needsVAE = true;
            result.recommendedVAE = "vae-ft-mse-840000-ema-pruned.safetensors";
            result.needsTAESD = true;
            result.suggestedParams["vae_flag"] = "--vae";
            break;
        case ModelArchitecture::SD_2_1:
            result.textEncoderDim = 1024;
            result.unetChannels = 1280;
            result.needsVAE = true;
            result.recommendedVAE = "vae-ft-ema-560000.safetensors";
            result.needsTAESD = true;
            result.suggestedParams["vae_flag"] = "--vae";
            break;
        case ModelArchitecture::SDXL_BASE:
        case ModelArchitecture::SDXL_REFINER:
            result.textEncoderDim = 1280;
            result.unetChannels = 2560;
            result.hasConditioner = true;
            result.needsVAE = true;
            result.recommendedVAE = "sdxl_vae.safetensors";
            result.needsTAESD = true;
            result.suggestedParams["vae_flag"] = "--vae";
            break;
        case ModelArchitecture::FLUX_SCHNELL:
        case ModelArchitecture::FLUX_DEV:
            result.textEncoderDim = 4096;
            result.needsVAE = true;
            result.recommendedVAE = "ae.safetensors";
            // Flux requires CLIP-L and T5XXL.
            result.suggestedParams["vae_flag"] = "--vae";
            result.suggestedParams["clip_l_required"] = "clip_l.safetensors";
            result.suggestedParams["t5xxl_required"] = "t5xxl_fp16.safetensors";
            result.suggestedParams["clip_l_flag"] = "--clip-l";
            result.suggestedParams["t5xxl_flag"] = "--t5xxl";
            break;
        case ModelArchitecture::FLUX_CHROMA:
            result.textEncoderDim = 4096;
            result.needsVAE = true;
            result.recommendedVAE = "ae.safetensors";
            // Chroma (Flux Unlocked) requires VAE and T5XXL.
            result.suggestedParams["vae_flag"] = "--vae";
            result.suggestedParams["t5xxl_required"] = "t5xxl_fp16.safetensors";
            result.suggestedParams["t5xxl_flag"] = "--t5xxl";
            break;
        case ModelArchitecture::SD_3:
            result.textEncoderDim = 4096;
            result.needsVAE = true;
            result.recommendedVAE = "sd3_vae.safetensors";
            // SD3 requires CLIP-L, CLIP-G, and T5XXL.
            result.suggestedParams["vae_flag"] = "--vae";
            result.suggestedParams["clip_l_required"] = "clip_l.safetensors";
            result.suggestedParams["clip_g_required"] = "clip_g.safetensors";
            result.suggestedParams["t5xxl_required"] = "t5xxl_fp16.safetensors";
            result.suggestedParams["clip_l_flag"] = "--clip-l";
            result.suggestedParams["clip_g_flag"] = "--clip-g";
            result.suggestedParams["t5xxl_flag"] = "--t5xxl";
            break;
        case ModelArchitecture::QWEN2VL:
            // Qwen2-VL requires vision and language model components.
            result.suggestedParams["qwen2vl_required"] = "qwen2vl.safetensors";
            result.suggestedParams["qwen2vl_vision_required"] = "qwen2vl_vision.safetensors";
            result.suggestedParams["qwen2vl_flag"] = "--qwen2vl";
            result.suggestedParams["qwen2vl_vision_flag"] = "--qwen2vl-vision";
            break;
        default:
            break;
    }

    // Merge with general recommended parameters (width, height, steps, etc.).
    auto generalParams = getRecommendedParams(result.architecture);
    for (const auto& [key, value] : generalParams) {
        // Only add if not already set (preserve architecture-specific flags).
        if (result.suggestedParams.find(key) == result.suggestedParams.end()) {
            result.suggestedParams[key] = value;
        }
    }

    return result;
}

// Parse the JSON header of a .safetensors file.
// Layout: 8-byte little-endian header length, followed by that many bytes of
// JSON mapping tensor names to {dtype, shape, data_offsets}, with an optional
// "__metadata__" object of string key/value pairs.
// Returns false on I/O error, implausible header size, or malformed JSON.
bool ModelDetector::parseSafetensorsHeader(
    const std::string& filePath,
    std::map<std::string, std::string>& metadata,
    std::map<std::string, std::vector<int64_t>>& tensorInfo) {
    std::ifstream file(filePath, std::ios::binary);
    if (!file.is_open()) {
        return false;
    }

    // Read header length (first 8 bytes, little-endian uint64).
    // NOTE(review): assumes a little-endian host; byte-swap would be needed
    // on big-endian platforms.
    uint64_t headerLength = 0;
    file.read(reinterpret_cast<char*>(&headerLength), 8);
    if (file.gcount() != 8) {
        return false;
    }

    // Sanity check: header should be reasonable size (< 100MB).
    if (headerLength == 0 || headerLength > 100 * 1024 * 1024) {
        return false;
    }

    // Read header JSON.
    std::vector<char> headerBuffer(headerLength);
    file.read(headerBuffer.data(), static_cast<std::streamsize>(headerLength));
    if (file.gcount() != static_cast<std::streamsize>(headerLength)) {
        return false;
    }

    // Parse JSON.
    try {
        nlohmann::json headerJson =
            nlohmann::json::parse(headerBuffer.begin(), headerBuffer.end());

        // Extract metadata if present.  Values are specified to be strings,
        // but guard anyway so one odd value doesn't reject the whole file.
        if (headerJson.contains("__metadata__")) {
            const auto& metadataJson = headerJson["__metadata__"];
            for (auto it = metadataJson.begin(); it != metadataJson.end(); ++it) {
                if (it.value().is_string()) {
                    metadata[it.key()] = it.value().get<std::string>();
                }
            }
        }

        // Extract tensor information (every key except "__metadata__").
        for (auto it = headerJson.begin(); it != headerJson.end(); ++it) {
            if (it.key() == "__metadata__") continue;

            if (it.value().contains("shape")) {
                std::vector<int64_t> shape;
                for (const auto& dim : it.value()["shape"]) {
                    shape.push_back(dim.get<int64_t>());
                }
                tensorInfo[it.key()] = shape;
            }
        }

        return true;
    } catch (const std::exception&) {
        return false;
    }
}

// Heuristically classify the model architecture from tensor names/shapes,
// embedded metadata, and (for special variants like Chroma) the filename.
ModelArchitecture ModelDetector::analyzeArchitecture(
    const std::map<std::string, std::vector<int64_t>>& tensorInfo,
    const std::map<std::string, std::string>& metadata,
    const std::string& filename) {
    // Check metadata first for explicit architecture hints.
    auto modelTypeIt = metadata.find("modelspec.architecture");
    if (modelTypeIt != metadata.end()) {
        const std::string& archName = modelTypeIt->second;
        if (archName.find("stable-diffusion-xl") != std::string::npos) {
            return ModelArchitecture::SDXL_BASE;
        } else if (archName.find("stable-diffusion-v2") != std::string::npos) {
            return ModelArchitecture::SD_2_1;
        } else if (archName.find("stable-diffusion-v1") != std::string::npos) {
            return ModelArchitecture::SD_1_5;
        }
    }

    // Check filename for special variants.
    std::string lowerFilename = filename;
    std::transform(lowerFilename.begin(), lowerFilename.end(),
                   lowerFilename.begin(), ::tolower);

    // Analyze tensor structure for architecture detection.
    bool hasConditioner = false;
    bool hasTextEncoder2 = false;
    bool hasFluxStructure = false;
    bool hasSD3Structure = false;
    int maxUNetChannels = 0;
    int textEncoderOutputDim = 0;

    for (const auto& [name, shape] : tensorInfo) {
        // Check for SDXL-specific components.
        if (name.find("conditioner") != std::string::npos) {
            hasConditioner = true;
        }
        if (name.find("text_encoder_2") != std::string::npos ||
            name.find("cond_stage_model.1") != std::string::npos) {
            hasTextEncoder2 = true;
        }

        // Check for Flux-specific patterns.
        if (name.find("double_blocks") != std::string::npos ||
            name.find("single_blocks") != std::string::npos) {
            hasFluxStructure = true;
        }

        // Check for SD3-specific patterns.
        if (name.find("joint_blocks") != std::string::npos) {
            hasSD3Structure = true;
        }

        // Analyze UNet structure.
        if (name.find("model.diffusion_model") != std::string::npos ||
            name.find("unet") != std::string::npos) {
            if (shape.size() >= 2) {
                maxUNetChannels = std::max(maxUNetChannels,
                                           static_cast<int>(shape[0]));
            }
        }

        // Check text encoder dimensions.
        if (name.find("cond_stage_model") != std::string::npos ||
            name.find("text_encoder") != std::string::npos) {
            if (name.find("proj") != std::string::npos && shape.size() >= 2) {
                textEncoderOutputDim = std::max(textEncoderOutputDim,
                                                static_cast<int>(shape[1]));
            }
        }
    }

    // Determine architecture based on analysis.
    if (hasFluxStructure) {
        // Check for Chroma variant (unlocked Flux).
        if (lowerFilename.find("chroma") != std::string::npos) {
            return ModelArchitecture::FLUX_CHROMA;
        }
        // Check if it's Schnell or Dev based on step count hints.
        // NOTE(review): substring match on "4" also matches "40", "14", etc.
        auto stepsIt = metadata.find("diffusion_steps");
        if (stepsIt != metadata.end() &&
            stepsIt->second.find("4") != std::string::npos) {
            return ModelArchitecture::FLUX_SCHNELL;
        }
        return ModelArchitecture::FLUX_DEV;
    }

    if (hasSD3Structure) {
        return ModelArchitecture::SD_3;
    }

    if (hasConditioner || hasTextEncoder2) {
        // SDXL architecture; distinguish base vs refiner by tensor names.
        bool hasRefinerMarkers = false;
        for (const auto& [name, _] : tensorInfo) {
            if (name.find("refiner") != std::string::npos) {
                hasRefinerMarkers = true;
                break;
            }
        }
        return hasRefinerMarkers ? ModelArchitecture::SDXL_REFINER
                                 : ModelArchitecture::SDXL_BASE;
    }

    if (maxUNetChannels >= 2048) {
        return ModelArchitecture::SDXL_BASE;
    }

    // Distinguish between SD1.x and SD2.x by text encoder dimension.
    if (textEncoderOutputDim >= 1024 || maxUNetChannels == 1280) {
        return ModelArchitecture::SD_2_1;
    }

    if (textEncoderOutputDim == 768 || maxUNetChannels <= 1280) {
        return ModelArchitecture::SD_1_5;
    }

    return ModelArchitecture::UNKNOWN;
}

// Human-readable display name for an architecture.
std::string ModelDetector::getArchitectureName(ModelArchitecture arch) {
    switch (arch) {
        case ModelArchitecture::SD_1_5:       return "Stable Diffusion 1.5";
        case ModelArchitecture::SD_2_1:       return "Stable Diffusion 2.1";
        case ModelArchitecture::SDXL_BASE:    return "Stable Diffusion XL Base";
        case ModelArchitecture::SDXL_REFINER: return "Stable Diffusion XL Refiner";
        case ModelArchitecture::FLUX_SCHNELL: return "Flux Schnell";
        case ModelArchitecture::FLUX_DEV:     return "Flux Dev";
        case ModelArchitecture::FLUX_CHROMA:  return "Flux Chroma (Unlocked)";
        case ModelArchitecture::SD_3:         return "Stable Diffusion 3";
        case ModelArchitecture::QWEN2VL:      return "Qwen2-VL";
        default:                              return "Unknown";
    }
}

// Default generation parameters (resolution, CFG, steps, sampler) per
// architecture, as strings ready to be passed on a command line.
std::map<std::string, std::string> ModelDetector::getRecommendedParams(ModelArchitecture arch) {
    std::map<std::string, std::string> params;

    switch (arch) {
        case ModelArchitecture::SD_1_5:
            params["width"] = "512";
            params["height"] = "512";
            params["cfg_scale"] = "7.5";
            params["steps"] = "20";
            params["sampler"] = "euler_a";
            break;
        case ModelArchitecture::SD_2_1:
            params["width"] = "768";
            params["height"] = "768";
            params["cfg_scale"] = "7.0";
            params["steps"] = "25";
            params["sampler"] = "euler_a";
            break;
        case ModelArchitecture::SDXL_BASE:
        case ModelArchitecture::SDXL_REFINER:
            params["width"] = "1024";
            params["height"] = "1024";
            params["cfg_scale"] = "7.0";
            params["steps"] = "30";
            params["sampler"] = "dpm++2m";
            break;
        case ModelArchitecture::FLUX_SCHNELL:
            params["width"] = "1024";
            params["height"] = "1024";
            params["cfg_scale"] = "1.0";
            params["steps"] = "4";
            params["sampler"] = "euler";
            break;
        case ModelArchitecture::FLUX_DEV:
            params["width"] = "1024";
            params["height"] = "1024";
            params["cfg_scale"] = "1.0";
            params["steps"] = "20";
            params["sampler"] = "euler";
            break;
        case ModelArchitecture::FLUX_CHROMA:
            params["width"] = "1024";
            params["height"] = "1024";
            params["cfg_scale"] = "1.0";
            params["steps"] = "20";
            params["sampler"] = "euler";
            break;
        case ModelArchitecture::SD_3:
            params["width"] = "1024";
            params["height"] = "1024";
            params["cfg_scale"] = "5.0";
            params["steps"] = "28";
            params["sampler"] = "dpm++2m";
            break;
        default:
            break;
    }

    return params;
}

// Parse the header of a GGUF file: magic, version, tensor count, metadata
// KV pairs, then tensor descriptors (name, dims, type, offset).
// Scalar metadata values are stored stringified; arrays are skipped with
// their correct byte width so the stream stays in sync (the previous
// fixed 8-byte skip desynchronized parsing on bool/array/etc. values).
bool ModelDetector::parseGGUFHeader(
    const std::string& filePath,
    std::map<std::string, std::string>& metadata,
    std::map<std::string, std::vector<int64_t>>& tensorInfo) {
    std::ifstream file(filePath, std::ios::binary);
    if (!file.is_open()) {
        return false;
    }

    // Read and verify magic number "GGUF".
    char magic[4];
    file.read(magic, 4);
    if (file.gcount() != 4 || std::memcmp(magic, "GGUF", 4) != 0) {
        return false;
    }

    // Read version (uint32).
    uint32_t version = 0;
    file.read(reinterpret_cast<char*>(&version), 4);
    if (file.gcount() != 4) {
        return false;
    }

    // Read tensor count (uint64).
    uint64_t tensorCount = 0;
    file.read(reinterpret_cast<char*>(&tensorCount), 8);
    if (file.gcount() != 8) {
        return false;
    }

    // Read metadata KV count (uint64).
    uint64_t metadataCount = 0;
    file.read(reinterpret_cast<char*>(&metadataCount), 8);
    if (file.gcount() != 8) {
        return false;
    }

    // Helper to read a GGUF string (uint64 length prefix + bytes).
    // Returns "" on error or implausible length, which callers treat as a
    // parse-abort signal.
    auto readString = [&file]() -> std::string {
        uint64_t length = 0;
        file.read(reinterpret_cast<char*>(&length), 8);
        if (file.gcount() != 8 || length == 0 || length > 10000) {
            return "";
        }
        std::vector<char> buffer(length);
        file.read(buffer.data(), static_cast<std::streamsize>(length));
        if (file.gcount() != static_cast<std::streamsize>(length)) {
            return "";
        }
        return std::string(buffer.begin(), buffer.end());
    };

    // Byte width of a fixed-size GGUF scalar value type; 0 for
    // variable-length (string=8, array=9) or unknown types.
    auto scalarSize = [](uint32_t type) -> int {
        switch (type) {
            case 0: case 1: case 7:     return 1; // uint8, int8, bool
            case 2: case 3:             return 2; // uint16, int16
            case 4: case 5: case 6:     return 4; // uint32, int32, float32
            case 10: case 11: case 12:  return 8; // uint64, int64, float64
            default:                    return 0;
        }
    };

    // Read metadata key-value pairs.
    for (uint64_t i = 0; i < metadataCount && file.good(); ++i) {
        std::string key = readString();
        if (key.empty()) break;

        // Read value type (uint32).
        uint32_t valueType = 0;
        file.read(reinterpret_cast<char*>(&valueType), 4);
        if (file.gcount() != 4) break;

        // Parse value based on type.
        std::string value;
        switch (valueType) {
            case 8: // String
                value = readString();
                break;
            case 4: { // Uint32
                uint32_t val = 0;
                file.read(reinterpret_cast<char*>(&val), 4);
                value = std::to_string(val);
                break;
            }
            case 5: { // Int32
                int32_t val = 0;
                file.read(reinterpret_cast<char*>(&val), 4);
                value = std::to_string(val);
                break;
            }
            case 6: { // Float32
                float val = 0.0f;
                file.read(reinterpret_cast<char*>(&val), 4);
                value = std::to_string(val);
                break;
            }
            case 0: { // Uint8
                uint8_t val = 0;
                file.read(reinterpret_cast<char*>(&val), 1);
                value = std::to_string(val);
                break;
            }
            case 1: { // Int8
                int8_t val = 0;
                file.read(reinterpret_cast<char*>(&val), 1);
                value = std::to_string(val);
                break;
            }
            case 9: { // Array: uint32 element type + uint64 count + elements
                uint32_t elemType = 0;
                uint64_t count = 0;
                file.read(reinterpret_cast<char*>(&elemType), 4);
                file.read(reinterpret_cast<char*>(&count), 8);
                if (!file.good()) return false;
                if (elemType == 8) {
                    // Array of strings: must read each to advance correctly.
                    for (uint64_t j = 0; j < count && file.good(); ++j) {
                        readString();
                    }
                } else {
                    int sz = scalarSize(elemType);
                    if (sz == 0) {
                        // Nested array / unknown element type: cannot skip
                        // reliably, and everything after would be garbage.
                        return false;
                    }
                    file.seekg(static_cast<std::streamoff>(count) * sz,
                               std::ios::cur);
                }
                continue;
            }
            default: {
                // Skip remaining fixed-size types (bool, int16, uint64, ...)
                // by their actual width; fall back to 8 for unknown types.
                int sz = scalarSize(valueType);
                file.seekg(sz != 0 ? sz : 8, std::ios::cur);
                continue;
            }
        }

        if (!value.empty()) {
            metadata[key] = value;
        }
    }

    // Read tensor information.
    for (uint64_t i = 0; i < tensorCount && file.good(); ++i) {
        std::string tensorName = readString();
        if (tensorName.empty()) break;

        // Read number of dimensions (uint32); reject implausible counts.
        uint32_t nDims = 0;
        file.read(reinterpret_cast<char*>(&nDims), 4);
        if (file.gcount() != 4 || nDims > 10) break;

        // Read dimensions (uint64 array).  Abort the whole tensor section on
        // a short read instead of storing a partial shape.
        std::vector<int64_t> shape(nDims);
        bool dimsOk = true;
        for (uint32_t d = 0; d < nDims; ++d) {
            uint64_t dim = 0;
            file.read(reinterpret_cast<char*>(&dim), 8);
            if (file.gcount() != 8) {
                dimsOk = false;
                break;
            }
            shape[d] = static_cast<int64_t>(dim);
        }
        if (!dimsOk) break;

        // Skip ggml type (uint32) and data offset (uint64).
        file.seekg(12, std::ios::cur);

        tensorInfo[tensorName] = shape;
    }

    return !tensorInfo.empty();
}