// Model architecture detection: result types and the ModelDetector interface.
#ifndef MODEL_DETECTOR_H
#define MODEL_DETECTOR_H

#include <cstdint>
#include <map>
#include <optional>
#include <string>
#include <vector>
/**
 * @brief Model architectures that detection can report.
 *
 * UNKNOWN is declared first; it is the default value of
 * ModelDetectionResult::architecture.
 */
enum class ModelArchitecture {
    UNKNOWN,       ///< Default value of ModelDetectionResult::architecture
    SD_1_5,        ///< Stable Diffusion 1.x (768 text encoder)
    SD_2_1,        ///< Stable Diffusion 2.x (1024 text encoder)
    SDXL_BASE,     ///< Stable Diffusion XL Base
    SDXL_REFINER,  ///< Stable Diffusion XL Refiner
    FLUX_SCHNELL,  ///< Flux Schnell
    FLUX_DEV,      ///< Flux Dev
    FLUX_CHROMA,   ///< Flux Chroma (unlocked variant)
    SD_3,          ///< Stable Diffusion 3
    QWEN2VL,       ///< Qwen2-VL vision-language model
};
- /**
- * @brief Model detection result with detailed information
- */
- struct ModelDetectionResult {
- ModelArchitecture architecture = ModelArchitecture::UNKNOWN;
- std::string architectureName;
- // Model properties
- int textEncoderDim = 0; // Text encoder dimension (768, 1024, etc.)
- int unetChannels = 0; // UNet channel count
- bool hasConditioner = false; // SDXL conditioner
- bool hasRefiner = false; // SDXL refiner components
- // Required auxiliary models
- bool needsVAE = false;
- std::string recommendedVAE; // Recommended VAE model
- bool needsTAESD = false; // Can use TAESD for faster preview
- // Loading parameters
- std::map<std::string, std::string> suggestedParams;
- // Raw metadata
- std::map<std::string, std::string> metadata;
- std::vector<std::string> tensorNames;
- };
- /**
- * @brief Model detector class for analyzing model files
- */
- class ModelDetector {
- public:
- /**
- * @brief Detect model architecture from file
- *
- * @param modelPath Path to model file (.safetensors, .ckpt, etc.)
- * @return ModelDetectionResult Detection result with architecture info
- */
- static ModelDetectionResult detectModel(const std::string& modelPath);
- /**
- * @brief Parse safetensors file header
- *
- * @param filePath Path to safetensors file
- * @param metadata Output: metadata from header
- * @param tensorInfo Output: tensor names and shapes
- * @return true if successfully parsed, false otherwise
- */
- static bool parseSafetensorsHeader(
- const std::string& filePath,
- std::map<std::string, std::string>& metadata,
- std::map<std::string, std::vector<int64_t>>& tensorInfo
- );
- /**
- * @brief Parse GGUF file header
- *
- * @param filePath Path to GGUF file
- * @param metadata Output: metadata from header
- * @param tensorInfo Output: tensor names and dimensions
- * @return true if successfully parsed, false otherwise
- */
- static bool parseGGUFHeader(
- const std::string& filePath,
- std::map<std::string, std::string>& metadata,
- std::map<std::string, std::vector<int64_t>>& tensorInfo
- );
- /**
- * @brief Analyze tensor structure to determine architecture
- *
- * @param tensorInfo Map of tensor names to shapes
- * @param metadata Model metadata
- * @param filename Model filename (for special variant detection)
- * @return ModelArchitecture Detected architecture
- */
- static ModelArchitecture analyzeArchitecture(
- const std::map<std::string, std::vector<int64_t>>& tensorInfo,
- const std::map<std::string, std::string>& metadata,
- const std::string& filename = ""
- );
- /**
- * @brief Get architecture name as string
- *
- * @param arch Model architecture enum
- * @return std::string Human-readable architecture name
- */
- static std::string getArchitectureName(ModelArchitecture arch);
- /**
- * @brief Get recommended loading parameters for architecture
- *
- * @param arch Model architecture
- * @return std::map<std::string, std::string> Recommended parameters
- */
- static std::map<std::string, std::string> getRecommendedParams(ModelArchitecture arch);
- };
#endif // MODEL_DETECTOR_H