// model_manager.h
  1. #ifndef MODEL_MANAGER_H
  2. #define MODEL_MANAGER_H
  3. #include <string>
  4. #include <memory>
  5. #include <map>
  6. #include <shared_mutex>
  7. #include <vector>
  8. #include <filesystem>
  9. #include <cstdint>
  10. // Forward declarations
  11. class StableDiffusionWrapper;
  12. #include "server_config.h"
  /**
   * @brief Model type enumeration
   *
   * NONE is the empty value (no bits set). Every other enumerator occupies
   * exactly one bit of the underlying uint32_t, so values can be combined as
   * bit flags to filter model types.
   */
  enum class ModelType : uint32_t {
      NONE        = 0u,        ///< No bits set
      LORA        = 1u << 0,   ///< 1
      CHECKPOINT  = 1u << 1,   ///< 2
      VAE         = 1u << 2,   ///< 4
      PRESETS     = 1u << 3,   ///< 8
      PROMPTS     = 1u << 4,   ///< 16
      NEG_PROMPTS = 1u << 5,   ///< 32
      TAESD       = 1u << 6,   ///< 64
      ESRGAN      = 1u << 7,   ///< 128
      CONTROLNET  = 1u << 8,   ///< 256
      UPSCALER    = 1u << 9,   ///< 512
      EMBEDDING   = 1u << 10   ///< 1024
  };
  32. // Enable bitwise operations for ModelType
  33. inline ModelType operator|(ModelType a, ModelType b) {
  34. return static_cast<ModelType>(static_cast<uint32_t>(a) | static_cast<uint32_t>(b));
  35. }
  36. inline ModelType operator&(ModelType a, ModelType b) {
  37. return static_cast<ModelType>(static_cast<uint32_t>(a) & static_cast<uint32_t>(b));
  38. }
  39. inline ModelType& operator|=(ModelType& a, ModelType b) {
  40. a = a | b;
  41. return a;
  42. }
  43. /**
  44. * @brief Model manager class for loading and managing stable-diffusion models
  45. *
  46. * This class handles loading, unloading, and managing multiple stable-diffusion models.
  47. * It provides thread-safe access to models and manages model resources efficiently.
  48. */
  49. class ModelManager {
  50. public:
  51. /**
  52. * @brief Model information structure
  53. */
  54. struct ModelInfo {
  55. std::string name; ///< Model name
  56. std::string path; ///< Model file path
  57. std::string fullPath; ///< Absolute path to the model file
  58. ModelType type; ///< Model type
  59. bool isLoaded; ///< Whether the model is currently loaded
  60. size_t fileSize; ///< File size in bytes
  61. std::string sha256; ///< SHA256 hash of the file
  62. std::filesystem::file_time_type createdAt; ///< File creation time
  63. std::filesystem::file_time_type modifiedAt; ///< Last modification time
  64. std::string description; ///< Model description
  65. std::map<std::string, std::string> metadata; ///< Additional metadata
  66. // Architecture detection fields
  67. std::string architecture; ///< Detected architecture (e.g., "Stable Diffusion XL Base", "Flux Dev")
  68. std::string recommendedVAE; ///< Recommended VAE for this model
  69. int recommendedWidth = 0; ///< Recommended image width
  70. int recommendedHeight = 0; ///< Recommended image height
  71. int recommendedSteps = 0; ///< Recommended number of steps
  72. std::string recommendedSampler; ///< Recommended sampler
  73. std::vector<std::string> requiredModels; ///< List of required auxiliary models (VAE, CLIP, etc.)
  74. std::vector<std::string> missingModels; ///< List of missing required models
  75. };
  76. /**
  77. * @brief Construct a new Model Manager object
  78. */
  79. ModelManager();
  80. /**
  81. * @brief Destroy the Model Manager object
  82. */
  83. virtual ~ModelManager();
  84. /**
  85. * @brief Scan the models directory to discover available models
  86. *
  87. * @return true if scanning was successful, false otherwise
  88. */
  89. bool scanModelsDirectory();
  90. /**
  91. * @brief Cancel any ongoing model directory scanning
  92. */
  93. void cancelScan();
  94. /**
  95. * @brief Load a model from the specified path
  96. *
  97. * @param name The name to assign to the model
  98. * @param path The file path to the model
  99. * @param type The type of model
  100. * @return true if the model was loaded successfully, false otherwise
  101. */
  102. bool loadModel(const std::string& name, const std::string& path, ModelType type);
  103. /**
  104. * @brief Load a model by name (must be discovered first)
  105. *
  106. * @param name The name of the model to load
  107. * @return true if the model was loaded successfully, false otherwise
  108. */
  109. bool loadModel(const std::string& name);
  110. /**
  111. * @brief Unload a model
  112. *
  113. * @param name The name of the model to unload
  114. * @return true if the model was unloaded successfully, false otherwise
  115. */
  116. bool unloadModel(const std::string& name);
  117. /**
  118. * @brief Get a pointer to a loaded model
  119. *
  120. * @param name The name of the model
  121. * @return StableDiffusionWrapper* Pointer to the model wrapper, or nullptr if not found
  122. */
  123. StableDiffusionWrapper* getModel(const std::string& name);
  124. /**
  125. * @brief Get information about all models
  126. *
  127. * @return std::map<std::string, ModelInfo> Map of model names to their information
  128. */
  129. std::map<std::string, ModelInfo> getAllModels() const;
  130. /**
  131. * @brief Get information about models of a specific type
  132. *
  133. * @param type The model type to filter by
  134. * @return std::vector<ModelInfo> List of model information
  135. */
  136. std::vector<ModelInfo> getModelsByType(ModelType type) const;
  137. /**
  138. * @brief Get information about a specific model
  139. *
  140. * @param name The name of the model
  141. * @return ModelInfo Model information, or empty if not found
  142. */
  143. ModelInfo getModelInfo(const std::string& name) const;
  144. /**
  145. * @brief Check if a model is loaded
  146. *
  147. * @param name The name of the model
  148. * @return true if the model is loaded, false otherwise
  149. */
  150. bool isModelLoaded(const std::string& name) const;
  151. /**
  152. * @brief Get the number of loaded models
  153. *
  154. * @return size_t Number of loaded models
  155. */
  156. size_t getLoadedModelsCount() const;
  157. /**
  158. * @brief Get the number of available models
  159. *
  160. * @return size_t Number of available models
  161. */
  162. size_t getAvailableModelsCount() const;
  163. /**
  164. * @brief Set the models directory path
  165. *
  166. * @param path The path to the models directory
  167. */
  168. void setModelsDirectory(const std::string& path);
  169. /**
  170. * @brief Get the models directory path
  171. *
  172. * @return std::string The models directory path
  173. */
  174. std::string getModelsDirectory() const;
  175. /**
  176. * @brief Set directory for a specific model type
  177. *
  178. * @param type The model type
  179. * @param path The directory path
  180. * @return true if the directory was set successfully, false otherwise
  181. */
  182. bool setModelTypeDirectory(ModelType type, const std::string& path);
  183. /**
  184. * @brief Get directory for a specific model type
  185. *
  186. * @param type The model type
  187. * @return std::string The directory path, empty if not set
  188. */
  189. std::string getModelTypeDirectory(ModelType type) const;
  190. /**
  191. * @brief Set all model type directories at once
  192. *
  193. * @param directories Map of model types to directory paths
  194. * @return true if all directories were set successfully, false otherwise
  195. */
  196. bool setAllModelTypeDirectories(const std::map<ModelType, std::string>& directories);
  197. /**
  198. * @brief Get all model type directories
  199. *
  200. * @return std::map<ModelType, std::string> Map of model types to directory paths
  201. */
  202. std::map<ModelType, std::string> getAllModelTypeDirectories() const;
  203. /**
  204. * @brief Reset to legacy directory mode (single models directory)
  205. */
  206. void resetToLegacyDirectories();
  207. /**
  208. * @brief Configure ModelManager with ServerConfig
  209. *
  210. * @param config The server configuration
  211. * @return true if configuration was successful, false otherwise
  212. */
  213. bool configureFromServerConfig(const struct ServerConfig& config);
  214. /**
  215. * @brief Convert ModelType to string
  216. *
  217. * @param type The model type
  218. * @return std::string String representation of the model type
  219. */
  220. static std::string modelTypeToString(ModelType type);
  221. /**
  222. * @brief Convert string to ModelType
  223. *
  224. * @param typeStr String representation of the model type
  225. * @return ModelType The model type
  226. */
  227. static ModelType stringToModelType(const std::string& typeStr);
  228. /**
  229. * @brief Compute SHA256 hash of a model file
  230. *
  231. * @param modelName The name of the model
  232. * @return std::string The SHA256 hash, or empty string on error
  233. */
  234. std::string computeModelHash(const std::string& modelName);
  235. /**
  236. * @brief Load hash from JSON file for a model
  237. *
  238. * @param modelName The name of the model
  239. * @return std::string The loaded hash, or empty string if not found
  240. */
  241. std::string loadModelHashFromFile(const std::string& modelName);
  242. /**
  243. * @brief Save hash to JSON file for a model
  244. *
  245. * @param modelName The name of the model
  246. * @param hash The SHA256 hash to save
  247. * @return true if saved successfully, false otherwise
  248. */
  249. bool saveModelHashToFile(const std::string& modelName, const std::string& hash);
  250. /**
  251. * @brief Find model by hash (full or partial - minimum 10 chars)
  252. *
  253. * @param hash Full or partial SHA256 hash (minimum 10 characters)
  254. * @return std::string Model name, or empty string if not found
  255. */
  256. std::string findModelByHash(const std::string& hash);
  257. /**
  258. * @brief Load hash for a model (from file or compute if missing)
  259. *
  260. * @param modelName The name of the model
  261. * @param forceCompute Force recomputation even if hash file exists
  262. * @return std::string The SHA256 hash, or empty string on error
  263. */
  264. std::string ensureModelHash(const std::string& modelName, bool forceCompute = false);
  265. private:
  266. class Impl;
  267. std::unique_ptr<Impl> pImpl; // Pimpl idiom
  268. };
  269. #endif // MODEL_MANAGER_H