#ifndef MODEL_DETECTOR_H
#define MODEL_DETECTOR_H

#include <cstdint>
#include <map>
#include <string>
#include <vector>

// NOTE(review): in the revision under review the header names of the
// #include directives and all template arguments of std::map / std::vector
// were missing (text between '<' and '>' had been stripped), so the header
// could not compile. They are reconstructed here as string-to-string maps,
// a vector of strings, and a map from tensor name to a vector of 64-bit
// integer dimensions. Confirm the element types against the implementation
// file before merging.

/**
 * @brief Detected model architecture types
 */
enum class ModelArchitecture {
    UNKNOWN,
    SD_1_5,        // Stable Diffusion 1.x (768 text encoder)
    SD_2_1,        // Stable Diffusion 2.x (1024 text encoder)
    SDXL_BASE,     // Stable Diffusion XL Base
    SDXL_REFINER,  // Stable Diffusion XL Refiner
    FLUX_SCHNELL,  // Flux Schnell
    FLUX_DEV,      // Flux Dev
    FLUX_CHROMA,   // Flux Chroma (unlocked variant)
    SD_3,          // Stable Diffusion 3
    QWEN2VL,       // Qwen2-VL vision-language model
};

/**
 * @brief Model detection result with detailed information
 *
 * Every scalar member has an in-class default, so a default-constructed
 * result reports ModelArchitecture::UNKNOWN with all flags cleared.
 */
struct ModelDetectionResult {
    ModelArchitecture architecture = ModelArchitecture::UNKNOWN;
    std::string architectureName;

    // Model properties
    int textEncoderDim = 0;       // Text encoder dimension (768, 1024, etc.)
    int unetChannels = 0;         // UNet channel count
    bool hasConditioner = false;  // SDXL conditioner
    bool hasRefiner = false;      // SDXL refiner components

    // Required auxiliary models
    bool needsVAE = false;
    std::string recommendedVAE;   // Recommended VAE model
    bool needsTAESD = false;      // Can use TAESD for faster preview

    // Loading parameters
    // NOTE(review): key/value types reconstructed as strings -- confirm.
    std::map<std::string, std::string> suggestedParams;

    // Raw metadata
    // NOTE(review): key/value and element types reconstructed -- confirm.
    std::map<std::string, std::string> metadata;
    std::vector<std::string> tensorNames;
};

/**
 * @brief Model detector class for analyzing model files
 *
 * All members are static; the class only groups the detection helpers.
 */
class ModelDetector {
public:
    /**
     * @brief Detect model architecture from file
     *
     * @param modelPath Path to model file (.safetensors, .ckpt, etc.)
     * @return ModelDetectionResult Detection result with architecture info
     */
    static ModelDetectionResult detectModel(const std::string& modelPath);

    /**
     * @brief Parse safetensors file header
     *
     * @param filePath Path to safetensors file
     * @param metadata Output: metadata from header
     * @param tensorInfo Output: tensor names and shapes
     * @return true if successfully parsed, false otherwise
     */
    static bool parseSafetensorsHeader(
        const std::string& filePath,
        std::map<std::string, std::string>& metadata,
        std::map<std::string, std::vector<std::int64_t>>& tensorInfo
    );

    /**
     * @brief Parse GGUF file header
     *
     * @param filePath Path to GGUF file
     * @param metadata Output: metadata from header
     * @param tensorInfo Output: tensor names and dimensions
     * @return true if successfully parsed, false otherwise
     */
    static bool parseGGUFHeader(
        const std::string& filePath,
        std::map<std::string, std::string>& metadata,
        std::map<std::string, std::vector<std::int64_t>>& tensorInfo
    );

    /**
     * @brief Analyze tensor structure to determine architecture
     *
     * @param tensorInfo Map of tensor names to shapes
     * @param metadata Model metadata
     * @param filename Model filename (for special variant detection);
     *                 defaults to an empty string
     * @return ModelArchitecture Detected architecture
     */
    static ModelArchitecture analyzeArchitecture(
        const std::map<std::string, std::vector<std::int64_t>>& tensorInfo,
        const std::map<std::string, std::string>& metadata,
        const std::string& filename = ""
    );

    /**
     * @brief Get architecture name as string
     *
     * @param arch Model architecture enum
     * @return std::string Human-readable architecture name
     */
    static std::string getArchitectureName(ModelArchitecture arch);

    /**
     * @brief Get recommended loading parameters for architecture
     *
     * @param arch Model architecture
     * @return std::map<std::string, std::string> Recommended parameters
     */
    static std::map<std::string, std::string> getRecommendedParams(ModelArchitecture arch);
};

#endif // MODEL_DETECTOR_H