model_detector.h 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122
  1. #ifndef MODEL_DETECTOR_H
  2. #define MODEL_DETECTOR_H
  #include <cstdint>
  #include <map>
  #include <string>
  #include <vector>
  /**
   * @brief Model families that the detector can report.
   *
   * Enumerator values are implicit and sequential, starting at 0 for UNKNOWN.
   * Reordering or inserting entries therefore changes their numeric values.
   */
  enum class ModelArchitecture : int {
      UNKNOWN,       ///< Architecture could not be determined
      SD_1_5,        ///< Stable Diffusion 1.x (768 text encoder)
      SD_2_1,        ///< Stable Diffusion 2.x (1024 text encoder)
      SDXL_BASE,     ///< Stable Diffusion XL Base
      SDXL_REFINER,  ///< Stable Diffusion XL Refiner
      FLUX_SCHNELL,  ///< Flux Schnell
      FLUX_DEV,      ///< Flux Dev
      FLUX_CHROMA,   ///< Flux Chroma (unlocked variant)
      SD_3,          ///< Stable Diffusion 3
      QWEN2VL,       ///< Qwen2-VL vision-language model
  };
  21. /**
  22. * @brief Model detection result with detailed information
  23. */
  24. struct ModelDetectionResult {
  25. ModelArchitecture architecture = ModelArchitecture::UNKNOWN;
  26. std::string architectureName;
  27. // Model properties
  28. int textEncoderDim = 0; // Text encoder dimension (768, 1024, etc.)
  29. int unetChannels = 0; // UNet channel count
  30. bool hasConditioner = false; // SDXL conditioner
  31. bool hasRefiner = false; // SDXL refiner components
  32. // Required auxiliary models
  33. bool needsVAE = false;
  34. std::string recommendedVAE; // Recommended VAE model
  35. bool needsTAESD = false; // Can use TAESD for faster preview
  36. // Loading parameters
  37. std::map<std::string, std::string> suggestedParams;
  38. // Raw metadata
  39. std::map<std::string, std::string> metadata;
  40. std::vector<std::string> tensorNames;
  41. };
  42. /**
  43. * @brief Model detector class for analyzing model files
  44. */
  45. class ModelDetector {
  46. public:
  47. /**
  48. * @brief Detect model architecture from file
  49. *
  50. * @param modelPath Path to model file (.safetensors, .ckpt, etc.)
  51. * @return ModelDetectionResult Detection result with architecture info
  52. */
  53. static ModelDetectionResult detectModel(const std::string& modelPath);
  54. /**
  55. * @brief Parse safetensors file header
  56. *
  57. * @param filePath Path to safetensors file
  58. * @param metadata Output: metadata from header
  59. * @param tensorInfo Output: tensor names and shapes
  60. * @return true if successfully parsed, false otherwise
  61. */
  62. static bool parseSafetensorsHeader(
  63. const std::string& filePath,
  64. std::map<std::string, std::string>& metadata,
  65. std::map<std::string, std::vector<int64_t>>& tensorInfo
  66. );
  67. /**
  68. * @brief Parse GGUF file header
  69. *
  70. * @param filePath Path to GGUF file
  71. * @param metadata Output: metadata from header
  72. * @param tensorInfo Output: tensor names and dimensions
  73. * @return true if successfully parsed, false otherwise
  74. */
  75. static bool parseGGUFHeader(
  76. const std::string& filePath,
  77. std::map<std::string, std::string>& metadata,
  78. std::map<std::string, std::vector<int64_t>>& tensorInfo
  79. );
  80. /**
  81. * @brief Analyze tensor structure to determine architecture
  82. *
  83. * @param tensorInfo Map of tensor names to shapes
  84. * @param metadata Model metadata
  85. * @param filename Model filename (for special variant detection)
  86. * @return ModelArchitecture Detected architecture
  87. */
  88. static ModelArchitecture analyzeArchitecture(
  89. const std::map<std::string, std::vector<int64_t>>& tensorInfo,
  90. const std::map<std::string, std::string>& metadata,
  91. const std::string& filename = ""
  92. );
  93. /**
  94. * @brief Get architecture name as string
  95. *
  96. * @param arch Model architecture enum
  97. * @return std::string Human-readable architecture name
  98. */
  99. static std::string getArchitectureName(ModelArchitecture arch);
  100. /**
  101. * @brief Get recommended loading parameters for architecture
  102. *
  103. * @param arch Model architecture
  104. * @return std::map<std::string, std::string> Recommended parameters
  105. */
  106. static std::map<std::string, std::string> getRecommendedParams(ModelArchitecture arch);
  107. };
  108. #endif // MODEL_DETECTOR_H