| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547 |
- #include "model_detector.h"
- #include <algorithm>
- #include <cstring>
- #include <fstream>
- #include <iostream>
- #include <nlohmann/json.hpp>
// C++17 lacks std::string::ends_with (added in C++20), so provide a
// file-local equivalent: true iff `str` ends with `suffix`.
static bool endsWith(const std::string& str, const std::string& suffix) {
    // Compare the two sequences back-to-front; guard first so the reverse
    // iteration never walks past the start of the shorter string.
    return str.size() >= suffix.size() &&
           std::equal(suffix.rbegin(), suffix.rend(), str.rbegin());
}
- ModelDetectionResult ModelDetector::detectModel(const std::string& modelPath) {
- ModelDetectionResult result;
- std::map<std::string, std::vector<int64_t>> tensorInfo;
- // Determine file type and parse accordingly
- bool parsed = false;
- if (endsWith(modelPath, ".safetensors")) {
- parsed = parseSafetensorsHeader(modelPath, result.metadata, tensorInfo);
- } else if (endsWith(modelPath, ".gguf")) {
- parsed = parseGGUFHeader(modelPath, result.metadata, tensorInfo);
- } else if (endsWith(modelPath, ".ckpt") || endsWith(modelPath, ".pt")) {
- // PyTorch pickle files - these require the full PyTorch library to parse safely
- // For now, we cannot detect their architecture without loading the model
- // Return unknown architecture with a note in metadata
- result.metadata["format"] = "pytorch_pickle";
- result.metadata["note"] = "Architecture detection not supported for .ckpt/.pt files";
- return result;
- }
- if (!parsed) {
- return result; // Unknown if we can't parse
- }
- // Store tensor names for reference
- for (const auto& [name, _] : tensorInfo) {
- result.tensorNames.push_back(name);
- }
- // Analyze architecture (pass filename for special detection)
- std::string filename = modelPath.substr(modelPath.find_last_of("/\\") + 1);
- result.architecture = analyzeArchitecture(tensorInfo, result.metadata, filename);
- result.architectureName = getArchitectureName(result.architecture);
- // Set architecture-specific properties and required models
- switch (result.architecture) {
- case ModelArchitecture::SD_1_5:
- result.textEncoderDim = 768;
- result.unetChannels = 1280;
- result.needsVAE = true;
- result.recommendedVAE = "vae-ft-mse-840000-ema-pruned.safetensors";
- result.needsTAESD = true;
- result.suggestedParams["vae_flag"] = "--vae";
- break;
- case ModelArchitecture::SD_2_1:
- result.textEncoderDim = 1024;
- result.unetChannels = 1280;
- result.needsVAE = true;
- result.recommendedVAE = "vae-ft-ema-560000.safetensors";
- result.needsTAESD = true;
- result.suggestedParams["vae_flag"] = "--vae";
- break;
- case ModelArchitecture::SDXL_BASE:
- case ModelArchitecture::SDXL_REFINER:
- result.textEncoderDim = 1280;
- result.unetChannels = 2560;
- result.hasConditioner = true;
- result.needsVAE = true;
- result.recommendedVAE = "sdxl_vae.safetensors";
- result.needsTAESD = true;
- result.suggestedParams["vae_flag"] = "--vae";
- break;
- case ModelArchitecture::FLUX_SCHNELL:
- case ModelArchitecture::FLUX_DEV:
- result.textEncoderDim = 4096;
- result.needsVAE = true;
- result.recommendedVAE = "ae.safetensors";
- // Flux requires CLIP-L and T5XXL
- result.suggestedParams["vae_flag"] = "--vae";
- result.suggestedParams["clip_l_required"] = "clip_l.safetensors";
- result.suggestedParams["t5xxl_required"] = "t5xxl_fp16.safetensors";
- result.suggestedParams["clip_l_flag"] = "--clip-l";
- result.suggestedParams["t5xxl_flag"] = "--t5xxl";
- break;
- case ModelArchitecture::FLUX_CHROMA:
- result.textEncoderDim = 4096;
- result.needsVAE = true;
- result.recommendedVAE = "ae.safetensors";
- // Chroma (Flux Unlocked) requires VAE and T5XXL
- result.suggestedParams["vae_flag"] = "--vae";
- result.suggestedParams["t5xxl_required"] = "t5xxl_fp16.safetensors";
- result.suggestedParams["t5xxl_flag"] = "--t5xxl";
- break;
- case ModelArchitecture::SD_3:
- result.textEncoderDim = 4096;
- result.needsVAE = true;
- result.recommendedVAE = "sd3_vae.safetensors";
- // SD3 requires CLIP-L, CLIP-G, and T5XXL
- result.suggestedParams["vae_flag"] = "--vae";
- result.suggestedParams["clip_l_required"] = "clip_l.safetensors";
- result.suggestedParams["clip_g_required"] = "clip_g.safetensors";
- result.suggestedParams["t5xxl_required"] = "t5xxl_fp16.safetensors";
- result.suggestedParams["clip_l_flag"] = "--clip-l";
- result.suggestedParams["clip_g_flag"] = "--clip-g";
- result.suggestedParams["t5xxl_flag"] = "--t5xxl";
- break;
- case ModelArchitecture::QWEN2VL:
- // Qwen2-VL requires vision and language model components
- result.suggestedParams["qwen2vl_required"] = "qwen2vl.safetensors";
- result.suggestedParams["qwen2vl_vision_required"] = "qwen2vl_vision.safetensors";
- result.suggestedParams["qwen2vl_flag"] = "--qwen2vl";
- result.suggestedParams["qwen2vl_vision_flag"] = "--qwen2vl-vision";
- break;
- default:
- break;
- }
- // Merge with general recommended parameters (width, height, steps, etc.)
- auto generalParams = getRecommendedParams(result.architecture);
- for (const auto& [key, value] : generalParams) {
- // Only add if not already set (preserve architecture-specific flags)
- if (result.suggestedParams.find(key) == result.suggestedParams.end()) {
- result.suggestedParams[key] = value;
- }
- }
- return result;
- }
- bool ModelDetector::parseSafetensorsHeader(
- const std::string& filePath,
- std::map<std::string, std::string>& metadata,
- std::map<std::string, std::vector<int64_t>>& tensorInfo) {
- std::ifstream file(filePath, std::ios::binary);
- if (!file.is_open()) {
- return false;
- }
- // Read header length (first 8 bytes, little-endian uint64)
- uint64_t headerLength = 0;
- file.read(reinterpret_cast<char*>(&headerLength), 8);
- if (file.gcount() != 8) {
- return false;
- }
- // Sanity check: header should be reasonable size (< 100MB)
- if (headerLength == 0 || headerLength > 100 * 1024 * 1024) {
- return false;
- }
- // Read header JSON
- std::vector<char> headerBuffer(headerLength);
- file.read(headerBuffer.data(), headerLength);
- if (file.gcount() != static_cast<std::streamsize>(headerLength)) {
- return false;
- }
- // Parse JSON
- try {
- nlohmann::json headerJson = nlohmann::json::parse(headerBuffer.begin(), headerBuffer.end());
- // Extract metadata if present
- if (headerJson.contains("__metadata__")) {
- auto metadataJson = headerJson["__metadata__"];
- for (auto it = metadataJson.begin(); it != metadataJson.end(); ++it) {
- metadata[it.key()] = it.value().get<std::string>();
- }
- }
- // Extract tensor information
- for (auto it = headerJson.begin(); it != headerJson.end(); ++it) {
- if (it.key() == "__metadata__")
- continue;
- if (it.value().contains("shape")) {
- std::vector<int64_t> shape;
- for (const auto& dim : it.value()["shape"]) {
- shape.push_back(dim.get<int64_t>());
- }
- tensorInfo[it.key()] = shape;
- }
- }
- return true;
- } catch (const std::exception& e) {
- return false;
- }
- }
- ModelArchitecture ModelDetector::analyzeArchitecture(
- const std::map<std::string, std::vector<int64_t>>& tensorInfo,
- const std::map<std::string, std::string>& metadata,
- const std::string& filename) {
- // Check metadata first for explicit architecture hints
- auto modelTypeIt = metadata.find("modelspec.architecture");
- if (modelTypeIt != metadata.end()) {
- const std::string& archName = modelTypeIt->second;
- if (archName.find("stable-diffusion-xl") != std::string::npos) {
- return ModelArchitecture::SDXL_BASE;
- } else if (archName.find("stable-diffusion-v2") != std::string::npos) {
- return ModelArchitecture::SD_2_1;
- } else if (archName.find("stable-diffusion-v1") != std::string::npos) {
- return ModelArchitecture::SD_1_5;
- }
- }
- // Check filename for special variants
- std::string lowerFilename = filename;
- std::transform(lowerFilename.begin(), lowerFilename.end(), lowerFilename.begin(), ::tolower);
- // Analyze tensor structure for architecture detection
- bool hasConditioner = false;
- bool hasTextEncoder2 = false;
- bool hasFluxStructure = false;
- bool hasSD3Structure = false;
- int maxUNetChannels = 0;
- int textEncoderOutputDim = 0;
- for (const auto& [name, shape] : tensorInfo) {
- // Check for SDXL-specific components
- if (name.find("conditioner") != std::string::npos) {
- hasConditioner = true;
- }
- if (name.find("text_encoder_2") != std::string::npos ||
- name.find("cond_stage_model.1") != std::string::npos) {
- hasTextEncoder2 = true;
- }
- // Check for Flux-specific patterns
- if (name.find("double_blocks") != std::string::npos ||
- name.find("single_blocks") != std::string::npos) {
- hasFluxStructure = true;
- }
- // Check for SD3-specific patterns
- if (name.find("joint_blocks") != std::string::npos) {
- hasSD3Structure = true;
- }
- // Analyze UNet structure
- if (name.find("model.diffusion_model") != std::string::npos ||
- name.find("unet") != std::string::npos) {
- if (shape.size() >= 2) {
- maxUNetChannels = std::max(maxUNetChannels, static_cast<int>(shape[0]));
- }
- }
- // Check text encoder dimensions
- if (name.find("cond_stage_model") != std::string::npos ||
- name.find("text_encoder") != std::string::npos) {
- if (name.find("proj") != std::string::npos && shape.size() >= 2) {
- textEncoderOutputDim = std::max(textEncoderOutputDim, static_cast<int>(shape[1]));
- }
- }
- }
- // Determine architecture based on analysis
- if (hasFluxStructure) {
- // Check for Chroma variant (unlocked Flux)
- if (lowerFilename.find("chroma") != std::string::npos) {
- return ModelArchitecture::FLUX_CHROMA;
- }
- // Check if it's Schnell or Dev based on step count hints
- auto stepsIt = metadata.find("diffusion_steps");
- if (stepsIt != metadata.end() && stepsIt->second.find("4") != std::string::npos) {
- return ModelArchitecture::FLUX_SCHNELL;
- }
- return ModelArchitecture::FLUX_DEV;
- }
- if (hasSD3Structure) {
- return ModelArchitecture::SD_3;
- }
- if (hasConditioner || hasTextEncoder2) {
- // SDXL architecture
- bool hasRefinerMarkers = false;
- for (const auto& [name, _] : tensorInfo) {
- if (name.find("refiner") != std::string::npos) {
- hasRefinerMarkers = true;
- break;
- }
- }
- return hasRefinerMarkers ? ModelArchitecture::SDXL_REFINER : ModelArchitecture::SDXL_BASE;
- }
- if (maxUNetChannels >= 2048) {
- return ModelArchitecture::SDXL_BASE;
- }
- // Distinguish between SD1.x and SD2.x by text encoder dimension
- if (textEncoderOutputDim >= 1024 || maxUNetChannels == 1280) {
- return ModelArchitecture::SD_2_1;
- }
- if (textEncoderOutputDim == 768 || maxUNetChannels <= 1280) {
- return ModelArchitecture::SD_1_5;
- }
- return ModelArchitecture::UNKNOWN;
- }
- std::string ModelDetector::getArchitectureName(ModelArchitecture arch) {
- switch (arch) {
- case ModelArchitecture::SD_1_5:
- return "Stable Diffusion 1.5";
- case ModelArchitecture::SD_2_1:
- return "Stable Diffusion 2.1";
- case ModelArchitecture::SDXL_BASE:
- return "Stable Diffusion XL Base";
- case ModelArchitecture::SDXL_REFINER:
- return "Stable Diffusion XL Refiner";
- case ModelArchitecture::FLUX_SCHNELL:
- return "Flux Schnell";
- case ModelArchitecture::FLUX_DEV:
- return "Flux Dev";
- case ModelArchitecture::FLUX_CHROMA:
- return "Flux Chroma (Unlocked)";
- case ModelArchitecture::SD_3:
- return "Stable Diffusion 3";
- case ModelArchitecture::QWEN2VL:
- return "Qwen2-VL";
- default:
- return "Unknown";
- }
- }
- std::map<std::string, std::string> ModelDetector::getRecommendedParams(ModelArchitecture arch) {
- std::map<std::string, std::string> params;
- switch (arch) {
- case ModelArchitecture::SD_1_5:
- params["width"] = "512";
- params["height"] = "512";
- params["cfg_scale"] = "7.5";
- params["steps"] = "20";
- params["sampler"] = "euler_a";
- break;
- case ModelArchitecture::SD_2_1:
- params["width"] = "768";
- params["height"] = "768";
- params["cfg_scale"] = "7.0";
- params["steps"] = "25";
- params["sampler"] = "euler_a";
- break;
- case ModelArchitecture::SDXL_BASE:
- case ModelArchitecture::SDXL_REFINER:
- params["width"] = "1024";
- params["height"] = "1024";
- params["cfg_scale"] = "7.0";
- params["steps"] = "30";
- params["sampler"] = "dpm++2m";
- break;
- case ModelArchitecture::FLUX_SCHNELL:
- params["width"] = "1024";
- params["height"] = "1024";
- params["cfg_scale"] = "1.0";
- params["steps"] = "4";
- params["sampler"] = "euler";
- break;
- case ModelArchitecture::FLUX_DEV:
- params["width"] = "1024";
- params["height"] = "1024";
- params["cfg_scale"] = "1.0";
- params["steps"] = "20";
- params["sampler"] = "euler";
- break;
- case ModelArchitecture::FLUX_CHROMA:
- params["width"] = "1024";
- params["height"] = "1024";
- params["cfg_scale"] = "1.0";
- params["steps"] = "20";
- params["sampler"] = "euler";
- break;
- case ModelArchitecture::SD_3:
- params["width"] = "1024";
- params["height"] = "1024";
- params["cfg_scale"] = "5.0";
- params["steps"] = "28";
- params["sampler"] = "dpm++2m";
- break;
- default:
- break;
- }
- return params;
- }
- bool ModelDetector::parseGGUFHeader(
- const std::string& filePath,
- std::map<std::string, std::string>& metadata,
- std::map<std::string, std::vector<int64_t>>& tensorInfo) {
- std::ifstream file(filePath, std::ios::binary);
- if (!file.is_open()) {
- return false;
- }
- // Read and verify magic number "GGUF"
- char magic[4];
- file.read(magic, 4);
- if (file.gcount() != 4 || std::memcmp(magic, "GGUF", 4) != 0) {
- return false;
- }
- // Read version (uint32)
- uint32_t version;
- file.read(reinterpret_cast<char*>(&version), 4);
- if (file.gcount() != 4) {
- return false;
- }
- // Read tensor count (uint64)
- uint64_t tensorCount;
- file.read(reinterpret_cast<char*>(&tensorCount), 8);
- if (file.gcount() != 8) {
- return false;
- }
- // Read metadata KV count (uint64)
- uint64_t metadataCount;
- file.read(reinterpret_cast<char*>(&metadataCount), 8);
- if (file.gcount() != 8) {
- return false;
- }
- // Helper function to read string
- auto readString = [&file]() -> std::string {
- uint64_t length;
- file.read(reinterpret_cast<char*>(&length), 8);
- if (file.gcount() != 8 || length == 0 || length > 10000) {
- return "";
- }
- std::vector<char> buffer(length);
- file.read(buffer.data(), length);
- if (file.gcount() != static_cast<std::streamsize>(length)) {
- return "";
- }
- return std::string(buffer.begin(), buffer.end());
- };
- // Read metadata key-value pairs
- for (uint64_t i = 0; i < metadataCount && file.good(); ++i) {
- std::string key = readString();
- if (key.empty())
- break;
- // Read value type (uint32)
- uint32_t valueType;
- file.read(reinterpret_cast<char*>(&valueType), 4);
- if (file.gcount() != 4)
- break;
- // Parse value based on type
- std::string value;
- switch (valueType) {
- case 8: // String
- value = readString();
- break;
- case 4: { // Uint32
- uint32_t val;
- file.read(reinterpret_cast<char*>(&val), 4);
- value = std::to_string(val);
- break;
- }
- case 5: { // Int32
- int32_t val;
- file.read(reinterpret_cast<char*>(&val), 4);
- value = std::to_string(val);
- break;
- }
- case 6: { // Float32
- float val;
- file.read(reinterpret_cast<char*>(&val), 4);
- value = std::to_string(val);
- break;
- }
- case 0: { // Uint8
- uint8_t val;
- file.read(reinterpret_cast<char*>(&val), 1);
- value = std::to_string(val);
- break;
- }
- case 1: { // Int8
- int8_t val;
- file.read(reinterpret_cast<char*>(&val), 1);
- value = std::to_string(val);
- break;
- }
- default:
- // Skip unknown types
- file.seekg(8, std::ios::cur);
- continue;
- }
- if (!value.empty()) {
- metadata[key] = value;
- }
- }
- // Read tensor information
- for (uint64_t i = 0; i < tensorCount && file.good(); ++i) {
- std::string tensorName = readString();
- if (tensorName.empty())
- break;
- // Read number of dimensions (uint32)
- uint32_t nDims;
- file.read(reinterpret_cast<char*>(&nDims), 4);
- if (file.gcount() != 4 || nDims > 10)
- break;
- // Read dimensions (uint64 array)
- std::vector<int64_t> shape(nDims);
- for (uint32_t d = 0; d < nDims; ++d) {
- uint64_t dim;
- file.read(reinterpret_cast<char*>(&dim), 8);
- if (file.gcount() != 8)
- break;
- shape[d] = static_cast<int64_t>(dim);
- }
- // Skip type (uint32) and offset (uint64)
- file.seekg(12, std::ios::cur);
- tensorInfo[tensorName] = shape;
- }
- return !tensorInfo.empty();
- }
|