// model_manager.h
  1. #ifndef MODEL_MANAGER_H
  2. #define MODEL_MANAGER_H
  3. #include <string>
  4. #include <memory>
  5. #include <map>
  6. #include <shared_mutex>
  7. #include <vector>
  8. #include <filesystem>
  9. #include <cstdint>
  10. // Forward declarations
  11. class StableDiffusionWrapper;
  12. #include "server_config.h"
  /**
   * @brief Model type enumeration
   *
   * Every enumerator other than NONE is a distinct single bit, so values can be
   * combined with the operators below to build a filter mask, e.g.
   * `ModelType::CHECKPOINT | ModelType::DIFFUSION_MODELS`. NONE is the empty set.
   */
  enum class ModelType : uint32_t {
      NONE = 0,
      LORA = 1u << 0,              // 1
      CHECKPOINT = 1u << 1,        // 2
      VAE = 1u << 2,               // 4
      PRESETS = 1u << 3,           // 8
      PROMPTS = 1u << 4,           // 16
      NEG_PROMPTS = 1u << 5,       // 32
      TAESD = 1u << 6,             // 64
      ESRGAN = 1u << 7,            // 128
      CONTROLNET = 1u << 8,        // 256
      UPSCALER = 1u << 9,          // 512
      EMBEDDING = 1u << 10,        // 1024
      DIFFUSION_MODELS = 1u << 11  // 2048
  };

  // Bitwise operations for ModelType. They work on the uint32_t underlying type and
  // are constexpr so that filter masks can be built at compile time.

  /// @brief Union of two flag sets.
  constexpr ModelType operator|(ModelType a, ModelType b) noexcept {
      return static_cast<ModelType>(static_cast<uint32_t>(a) | static_cast<uint32_t>(b));
  }

  /// @brief Intersection of two flag sets; compare the result with ModelType::NONE
  ///        to test whether a flag is present.
  constexpr ModelType operator&(ModelType a, ModelType b) noexcept {
      return static_cast<ModelType>(static_cast<uint32_t>(a) & static_cast<uint32_t>(b));
  }

  /// @brief Bitwise complement of all 32 bits; combine with & to clear flags,
  ///        e.g. `mask &= ~ModelType::VAE`.
  constexpr ModelType operator~(ModelType a) noexcept {
      return static_cast<ModelType>(~static_cast<uint32_t>(a));
  }

  /// @brief Adds the flags in @p b to @p a; returns @p a.
  constexpr ModelType& operator|=(ModelType& a, ModelType b) noexcept {
      a = a | b;
      return a;
  }

  /// @brief Keeps only the flags of @p a that are also in @p b; returns @p a.
  constexpr ModelType& operator&=(ModelType& a, ModelType b) noexcept {
      a = a & b;
      return a;
  }
  44. /**
  45. * @brief Model manager class for loading and managing stable-diffusion models
  46. *
  47. * This class handles loading, unloading, and managing multiple stable-diffusion models.
  48. * It provides thread-safe access to models and manages model resources efficiently.
  49. */
  50. class ModelManager {
  51. public:
  52. /**
  53. * @brief Model information structure
  54. */
  55. struct ModelInfo {
  56. std::string name; ///< Model name
  57. std::string path; ///< Model file path
  58. std::string fullPath; ///< Absolute path to the model file
  59. ModelType type; ///< Model type
  60. bool isLoaded; ///< Whether the model is currently loaded
  61. size_t fileSize; ///< File size in bytes
  62. std::string sha256; ///< SHA256 hash of the file
  63. std::filesystem::file_time_type createdAt; ///< File creation time
  64. std::filesystem::file_time_type modifiedAt; ///< Last modification time
  65. std::string description; ///< Model description
  66. std::map<std::string, std::string> metadata; ///< Additional metadata
  67. // Architecture detection fields
  68. std::string architecture; ///< Detected architecture (e.g., "Stable Diffusion XL Base", "Flux Dev")
  69. std::string recommendedVAE; ///< Recommended VAE for this model
  70. int recommendedWidth = 0; ///< Recommended image width
  71. int recommendedHeight = 0; ///< Recommended image height
  72. int recommendedSteps = 0; ///< Recommended number of steps
  73. std::string recommendedSampler; ///< Recommended sampler
  74. std::vector<std::string> requiredModels; ///< List of required auxiliary models (VAE, CLIP, etc.)
  75. std::vector<std::string> missingModels; ///< List of missing required models
  76. };
  77. /**
  78. * @brief Construct a new Model Manager object
  79. */
  80. ModelManager();
  81. /**
  82. * @brief Destroy the Model Manager object
  83. */
  84. virtual ~ModelManager();
  85. /**
  86. * @brief Scan the models directory to discover available models
  87. *
  88. * Recursively scans all subdirectories within the models directory to find
  89. * model files. For each model found, constructs the display name as
  90. * 'relative_path/model_name' where relative_path is the path from the models
  91. * root directory to the file's containing folder (using forward slashes).
  92. * Models in the root directory appear without a prefix.
  93. *
  94. * @return true if scanning was successful, false otherwise
  95. */
  96. bool scanModelsDirectory();
  97. /**
  98. * @brief Cancel any ongoing model directory scanning
  99. */
  100. void cancelScan();
  101. /**
  102. * @brief Load a model from the specified path
  103. *
  104. * @param name The name to assign to the model
  105. * @param path The file path to the model
  106. * @param type The type of model
  107. * @return true if the model was loaded successfully, false otherwise
  108. */
  109. bool loadModel(const std::string& name, const std::string& path, ModelType type);
  110. /**
  111. * @brief Load a model by name (must be discovered first)
  112. *
  113. * @param name The name of the model to load
  114. * @return true if the model was loaded successfully, false otherwise
  115. */
  116. bool loadModel(const std::string& name);
  117. /**
  118. * @brief Unload a model
  119. *
  120. * @param name The name of the model to unload
  121. * @return true if the model was unloaded successfully, false otherwise
  122. */
  123. bool unloadModel(const std::string& name);
  124. /**
  125. * @brief Get a pointer to a loaded model
  126. *
  127. * @param name The name of the model
  128. * @return StableDiffusionWrapper* Pointer to the model wrapper, or nullptr if not found
  129. */
  130. StableDiffusionWrapper* getModel(const std::string& name);
  131. /**
  132. * @brief Get information about all models
  133. *
  134. * @return std::map<std::string, ModelInfo> Map of model names to their information
  135. */
  136. std::map<std::string, ModelInfo> getAllModels() const;
  137. /**
  138. * @brief Get information about models of a specific type
  139. *
  140. * @param type The model type to filter by
  141. * @return std::vector<ModelInfo> List of model information
  142. */
  143. std::vector<ModelInfo> getModelsByType(ModelType type) const;
  144. /**
  145. * @brief Get information about a specific model
  146. *
  147. * @param name The name of the model
  148. * @return ModelInfo Model information, or empty if not found
  149. */
  150. ModelInfo getModelInfo(const std::string& name) const;
  151. /**
  152. * @brief Check if a model is loaded
  153. *
  154. * @param name The name of the model
  155. * @return true if the model is loaded, false otherwise
  156. */
  157. bool isModelLoaded(const std::string& name) const;
  158. /**
  159. * @brief Get the number of loaded models
  160. *
  161. * @return size_t Number of loaded models
  162. */
  163. size_t getLoadedModelsCount() const;
  164. /**
  165. * @brief Get the number of available models
  166. *
  167. * @return size_t Number of available models
  168. */
  169. size_t getAvailableModelsCount() const;
  170. /**
  171. * @brief Set the models directory path
  172. *
  173. * @param path The path to the models directory
  174. */
  175. void setModelsDirectory(const std::string& path);
  176. /**
  177. * @brief Get the models directory path
  178. *
  179. * @return std::string The models directory path
  180. */
  181. std::string getModelsDirectory() const;
  182. /**
  183. * @brief Set directory for a specific model type
  184. *
  185. * @param type The model type
  186. * @param path The directory path
  187. * @return true if the directory was set successfully, false otherwise
  188. */
  189. bool setModelTypeDirectory(ModelType type, const std::string& path);
  190. /**
  191. * @brief Get directory for a specific model type
  192. *
  193. * @param type The model type
  194. * @return std::string The directory path, empty if not set
  195. */
  196. std::string getModelTypeDirectory(ModelType type) const;
  197. /**
  198. * @brief Set all model type directories at once
  199. *
  200. * @param directories Map of model types to directory paths
  201. * @return true if all directories were set successfully, false otherwise
  202. */
  203. bool setAllModelTypeDirectories(const std::map<ModelType, std::string>& directories);
  204. /**
  205. * @brief Get all model type directories
  206. *
  207. * @return std::map<ModelType, std::string> Map of model types to directory paths
  208. */
  209. std::map<ModelType, std::string> getAllModelTypeDirectories() const;
  210. // Legacy methods removed - using explicit directory configuration only
  211. /**
  212. * @brief Configure ModelManager with ServerConfig
  213. *
  214. * @param config The server configuration
  215. * @return true if configuration was successful, false otherwise
  216. */
  217. bool configureFromServerConfig(const struct ServerConfig& config);
  218. /**
  219. * @brief Convert ModelType to string
  220. *
  221. * @param type The model type
  222. * @return std::string String representation of the model type
  223. */
  224. static std::string modelTypeToString(ModelType type);
  225. /**
  226. * @brief Convert string to ModelType
  227. *
  228. * @param typeStr String representation of the model type
  229. * @return ModelType The model type
  230. */
  231. static ModelType stringToModelType(const std::string& typeStr);
  232. /**
  233. * @brief Compute SHA256 hash of a model file
  234. *
  235. * @param modelName The name of the model
  236. * @return std::string The SHA256 hash, or empty string on error
  237. */
  238. std::string computeModelHash(const std::string& modelName);
  239. /**
  240. * @brief Load hash from JSON file for a model
  241. *
  242. * @param modelName The name of the model
  243. * @return std::string The loaded hash, or empty string if not found
  244. */
  245. std::string loadModelHashFromFile(const std::string& modelName);
  246. /**
  247. * @brief Save hash to JSON file for a model
  248. *
  249. * @param modelName The name of the model
  250. * @param hash The SHA256 hash to save
  251. * @return true if saved successfully, false otherwise
  252. */
  253. bool saveModelHashToFile(const std::string& modelName, const std::string& hash);
  254. /**
  255. * @brief Find model by hash (full or partial - minimum 10 chars)
  256. *
  257. * @param hash Full or partial SHA256 hash (minimum 10 characters)
  258. * @return std::string Model name, or empty string if not found
  259. */
  260. std::string findModelByHash(const std::string& hash);
  261. /**
  262. * @brief Load hash for a model (from file or compute if missing)
  263. *
  264. * @param modelName The name of the model
  265. * @param forceCompute Force recomputation even if hash file exists
  266. * @return std::string The SHA256 hash, or empty string on error
  267. */
  268. std::string ensureModelHash(const std::string& modelName, bool forceCompute = false);
  269. private:
  270. class Impl;
  271. std::unique_ptr<Impl> pImpl; // Pimpl idiom
  272. };
  273. #endif // MODEL_MANAGER_H