model_detector.h 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123
  1. #ifndef MODEL_DETECTOR_H
  2. #define MODEL_DETECTOR_H
#include <cstdint>
#include <map>
#include <optional>
#include <string>
#include <vector>
  7. /**
  8. * @brief Detected model architecture types
  9. */
  10. enum class ModelArchitecture {
  11. UNKNOWN,
  12. SD_1_5, // Stable Diffusion 1.x (768 text encoder)
  13. SD_2_1, // Stable Diffusion 2.x (1024 text encoder)
  14. SDXL_BASE, // Stable Diffusion XL Base
  15. SDXL_REFINER,// Stable Diffusion XL Refiner
  16. FLUX_SCHNELL,// Flux Schnell
  17. FLUX_DEV, // Flux Dev
  18. FLUX_CHROMA, // Flux Chroma (unlocked variant)
  19. SD_3, // Stable Diffusion 3
  20. QWEN2VL, // Qwen2-VL vision-language model
  21. };
  22. /**
  23. * @brief Model detection result with detailed information
  24. */
  25. struct ModelDetectionResult {
  26. ModelArchitecture architecture = ModelArchitecture::UNKNOWN;
  27. std::string architectureName;
  28. // Model properties
  29. int textEncoderDim = 0; // Text encoder dimension (768, 1024, etc.)
  30. int unetChannels = 0; // UNet channel count
  31. bool hasConditioner = false; // SDXL conditioner
  32. bool hasRefiner = false; // SDXL refiner components
  33. // Required auxiliary models
  34. bool needsVAE = false;
  35. std::string recommendedVAE; // Recommended VAE model
  36. bool needsTAESD = false; // Can use TAESD for faster preview
  37. // Loading parameters
  38. std::map<std::string, std::string> suggestedParams;
  39. // Raw metadata
  40. std::map<std::string, std::string> metadata;
  41. std::vector<std::string> tensorNames;
  42. };
  43. /**
  44. * @brief Model detector class for analyzing model files
  45. */
  46. class ModelDetector {
  47. public:
  48. /**
  49. * @brief Detect model architecture from file
  50. *
  51. * @param modelPath Path to model file (.safetensors, .ckpt, etc.)
  52. * @return ModelDetectionResult Detection result with architecture info
  53. */
  54. static ModelDetectionResult detectModel(const std::string& modelPath);
  55. /**
  56. * @brief Parse safetensors file header
  57. *
  58. * @param filePath Path to safetensors file
  59. * @param metadata Output: metadata from header
  60. * @param tensorInfo Output: tensor names and shapes
  61. * @return true if successfully parsed, false otherwise
  62. */
  63. static bool parseSafetensorsHeader(
  64. const std::string& filePath,
  65. std::map<std::string, std::string>& metadata,
  66. std::map<std::string, std::vector<int64_t>>& tensorInfo
  67. );
  68. /**
  69. * @brief Parse GGUF file header
  70. *
  71. * @param filePath Path to GGUF file
  72. * @param metadata Output: metadata from header
  73. * @param tensorInfo Output: tensor names and dimensions
  74. * @return true if successfully parsed, false otherwise
  75. */
  76. static bool parseGGUFHeader(
  77. const std::string& filePath,
  78. std::map<std::string, std::string>& metadata,
  79. std::map<std::string, std::vector<int64_t>>& tensorInfo
  80. );
  81. /**
  82. * @brief Analyze tensor structure to determine architecture
  83. *
  84. * @param tensorInfo Map of tensor names to shapes
  85. * @param metadata Model metadata
  86. * @param filename Model filename (for special variant detection)
  87. * @return ModelArchitecture Detected architecture
  88. */
  89. static ModelArchitecture analyzeArchitecture(
  90. const std::map<std::string, std::vector<int64_t>>& tensorInfo,
  91. const std::map<std::string, std::string>& metadata,
  92. const std::string& filename = ""
  93. );
  94. /**
  95. * @brief Get architecture name as string
  96. *
  97. * @param arch Model architecture enum
  98. * @return std::string Human-readable architecture name
  99. */
  100. static std::string getArchitectureName(ModelArchitecture arch);
  101. /**
  102. * @brief Get recommended loading parameters for architecture
  103. *
  104. * @param arch Model architecture
  105. * @return std::map<std::string, std::string> Recommended parameters
  106. */
  107. static std::map<std::string, std::string> getRecommendedParams(ModelArchitecture arch);
  108. };
  109. #endif // MODEL_DETECTOR_H