model_manager.h 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420
  1. #ifndef MODEL_MANAGER_H
  2. #define MODEL_MANAGER_H
#include <cstddef>
#include <cstdint>
#include <filesystem>
#include <map>
#include <memory>
#include <mutex>
#include <shared_mutex>
#include <string>
#include <vector>

// Forward declarations
class StableDiffusionWrapper;
// NOTE(review): ModelManager also declares private nested classes named
// ModelPathSelector and ModelDetectionCache; these global-scope declarations
// name different, unrelated types (::ModelPathSelector, ::ModelDetectionCache).
class ModelPathSelector;
class ModelDetectionCache;

#include "model_detector.h"
#include "server_config.h"
  16. /**
  17. * @brief Model type enumeration
  18. *
  19. * These values are bit flags that can be combined to filter model types.
  20. */
  21. enum class ModelType : uint32_t {
  22. NONE = 0,
  23. LORA = 1,
  24. CHECKPOINT = 2,
  25. VAE = 4,
  26. PRESETS = 8,
  27. PROMPTS = 16,
  28. NEG_PROMPTS = 32,
  29. TAESD = 64,
  30. ESRGAN = 128,
  31. CONTROLNET = 256,
  32. UPSCALER = 512,
  33. EMBEDDING = 1024,
  34. DIFFUSION_MODELS = 2048
  35. };
  36. // Enable bitwise operations for ModelType
  37. inline ModelType operator|(ModelType a, ModelType b) {
  38. return static_cast<ModelType>(static_cast<uint32_t>(a) | static_cast<uint32_t>(b));
  39. }
  40. inline ModelType operator&(ModelType a, ModelType b) {
  41. return static_cast<ModelType>(static_cast<uint32_t>(a) & static_cast<uint32_t>(b));
  42. }
  43. inline ModelType& operator|=(ModelType& a, ModelType b) {
  44. a = a | b;
  45. return a;
  46. }
  47. /**
  48. * @brief Model manager class for loading and managing stable-diffusion models
  49. *
  50. * This class handles loading, unloading, and managing multiple stable-diffusion models.
  51. * It provides thread-safe access to models and manages model resources efficiently.
  52. */
  53. class ModelManager {
  54. public:
  55. /**
  56. * @brief Model information structure
  57. */
  58. struct ModelInfo {
  59. std::string name; ///< Model name
  60. std::string path; ///< Model file path
  61. std::string fullPath; ///< Absolute path to the model file
  62. ModelType type; ///< Model type
  63. bool isLoaded; ///< Whether the model is currently loaded
  64. size_t fileSize; ///< File size in bytes
  65. std::string sha256; ///< SHA256 hash of the file
  66. std::filesystem::file_time_type createdAt; ///< File creation time
  67. std::filesystem::file_time_type modifiedAt; ///< Last modification time
  68. std::string description; ///< Model description
  69. std::map<std::string, std::string> metadata; ///< Additional metadata
  70. // Architecture detection fields
  71. std::string architecture; ///< Detected architecture (e.g., "Stable Diffusion XL Base", "Flux Dev")
  72. std::string recommendedVAE; ///< Recommended VAE for this model
  73. int recommendedWidth = 0; ///< Recommended image width
  74. int recommendedHeight = 0; ///< Recommended image height
  75. int recommendedSteps = 0; ///< Recommended number of steps
  76. std::string recommendedSampler; ///< Recommended sampler
  77. std::vector<std::string> requiredModels; ///< List of required auxiliary models (VAE, CLIP, etc.)
  78. std::vector<std::string> missingModels; ///< List of missing required models
  79. // Caching-related fields
  80. bool cacheValid = false; ///< Whether cached detection results are valid
  81. std::filesystem::file_time_type cacheModifiedAt; ///< Modification time when cache was created
  82. std::string cachePathType; ///< Path type used for this model ("model_path" or "diffusion_model_path")
  83. bool useFolderBasedDetection = false; ///< Whether folder-based detection was used
  84. std::string detectionSource; ///< Source of detection: "folder", "architecture", "fallback"
  85. };
  86. /**
  87. * @brief Construct a new Model Manager object
  88. */
  89. ModelManager();
  90. /**
  91. * @brief Destroy the Model Manager object
  92. */
  93. virtual ~ModelManager();
  94. /**
  95. * @brief Scan the models directory to discover available models
  96. *
  97. * Recursively scans all subdirectories within the models directory to find
  98. * model files. For each model found, constructs the display name as
  99. * 'relative_path/model_name' where relative_path is the path from the models
  100. * root directory to the file's containing folder (using forward slashes).
  101. * Models in the root directory appear without a prefix.
  102. *
  103. * @return true if scanning was successful, false otherwise
  104. */
  105. bool scanModelsDirectory();
  106. /**
  107. * @brief Cancel any ongoing model directory scanning
  108. */
  109. void cancelScan();
  110. /**
  111. * @brief Load a model from the specified path
  112. *
  113. * @param name The name to assign to the model
  114. * @param path The file path to the model
  115. * @param type The type of model
  116. * @return true if the model was loaded successfully, false otherwise
  117. */
  118. bool loadModel(const std::string& name, const std::string& path, ModelType type);
  119. /**
  120. * @brief Load a model by name (must be discovered first)
  121. *
  122. * @param name The name of the model to load
  123. * @return true if the model was loaded successfully, false otherwise
  124. */
  125. bool loadModel(const std::string& name);
  126. /**
  127. * @brief Unload a model
  128. *
  129. * @param name The name of the model to unload
  130. * @return true if the model was unloaded successfully, false otherwise
  131. */
  132. bool unloadModel(const std::string& name);
  133. /**
  134. * @brief Get a pointer to a loaded model
  135. *
  136. * @param name The name of the model
  137. * @return StableDiffusionWrapper* Pointer to the model wrapper, or nullptr if not found
  138. */
  139. StableDiffusionWrapper* getModel(const std::string& name);
  140. /**
  141. * @brief Get information about all models
  142. *
  143. * @return std::map<std::string, ModelInfo> Map of model names to their information
  144. */
  145. std::map<std::string, ModelInfo> getAllModels() const;
  146. /**
  147. * @brief Get information about models of a specific type
  148. *
  149. * @param type The model type to filter by
  150. * @return std::vector<ModelInfo> List of model information
  151. */
  152. std::vector<ModelInfo> getModelsByType(ModelType type) const;
  153. /**
  154. * @brief Get information about a specific model
  155. *
  156. * @param name The name of the model
  157. * @return ModelInfo Model information, or empty if not found
  158. */
  159. ModelInfo getModelInfo(const std::string& name) const;
  160. /**
  161. * @brief Check if a model is loaded
  162. *
  163. * @param name The name of the model
  164. * @return true if the model is loaded, false otherwise
  165. */
  166. bool isModelLoaded(const std::string& name) const;
  167. /**
  168. * @brief Get the number of loaded models
  169. *
  170. * @return size_t Number of loaded models
  171. */
  172. size_t getLoadedModelsCount() const;
  173. /**
  174. * @brief Get the number of available models
  175. *
  176. * @return size_t Number of available models
  177. */
  178. size_t getAvailableModelsCount() const;
  179. /**
  180. * @brief Set the models directory path
  181. *
  182. * @param path The path to the models directory
  183. */
  184. void setModelsDirectory(const std::string& path);
  185. /**
  186. * @brief Get the models directory path
  187. *
  188. * @return std::string The models directory path
  189. */
  190. std::string getModelsDirectory() const;
  191. /**
  192. * @brief Set directory for a specific model type
  193. *
  194. * @param type The model type
  195. * @param path The directory path
  196. * @return true if the directory was set successfully, false otherwise
  197. */
  198. bool setModelTypeDirectory(ModelType type, const std::string& path);
  199. /**
  200. * @brief Get directory for a specific model type
  201. *
  202. * @param type The model type
  203. * @return std::string The directory path, empty if not set
  204. */
  205. std::string getModelTypeDirectory(ModelType type) const;
  206. /**
  207. * @brief Set all model type directories at once
  208. *
  209. * @param directories Map of model types to directory paths
  210. * @return true if all directories were set successfully, false otherwise
  211. */
  212. bool setAllModelTypeDirectories(const std::map<ModelType, std::string>& directories);
  213. /**
  214. * @brief Get all model type directories
  215. *
  216. * @return std::map<ModelType, std::string> Map of model types to directory paths
  217. */
  218. std::map<ModelType, std::string> getAllModelTypeDirectories() const;
  219. // Legacy methods removed - using explicit directory configuration only
  220. /**
  221. * @brief Configure ModelManager with ServerConfig
  222. *
  223. * @param config The server configuration
  224. * @return true if configuration was successful, false otherwise
  225. */
  226. bool configureFromServerConfig(const struct ServerConfig& config);
  227. /**
  228. * @brief Convert ModelType to string
  229. *
  230. * @param type The model type
  231. * @return std::string String representation of the model type
  232. */
  233. static std::string modelTypeToString(ModelType type);
  234. /**
  235. * @brief Convert string to ModelType
  236. *
  237. * @param typeStr String representation of the model type
  238. * @return ModelType The model type
  239. */
  240. static ModelType stringToModelType(const std::string& typeStr);
  241. /**
  242. * @brief Compute SHA256 hash of a model file
  243. *
  244. * @param modelName The name of the model
  245. * @return std::string The SHA256 hash, or empty string on error
  246. */
  247. std::string computeModelHash(const std::string& modelName);
  248. /**
  249. * @brief Load hash from JSON file for a model
  250. *
  251. * @param modelName The name of the model
  252. * @return std::string The loaded hash, or empty string if not found
  253. */
  254. std::string loadModelHashFromFile(const std::string& modelName);
  255. /**
  256. * @brief Save hash to JSON file for a model
  257. *
  258. * @param modelName The name of the model
  259. * @param hash The SHA256 hash to save
  260. * @return true if saved successfully, false otherwise
  261. */
  262. bool saveModelHashToFile(const std::string& modelName, const std::string& hash);
  263. /**
  264. * @brief Find model by hash (full or partial - minimum 10 chars)
  265. *
  266. * @param hash Full or partial SHA256 hash (minimum 10 characters)
  267. * @return std::string Model name, or empty string if not found
  268. */
  269. std::string findModelByHash(const std::string& hash);
  270. /**
  271. * @brief Load hash for a model (from file or compute if missing)
  272. *
  273. * @param modelName The name of the model
  274. * @param forceCompute Force recomputation even if hash file exists
  275. * @return std::string The SHA256 hash, or empty string on error
  276. */
  277. std::string ensureModelHash(const std::string& modelName, bool forceCompute = false);
  278. private:
  279. class Impl;
  280. std::unique_ptr<Impl> pImpl; // Pimpl idiom
  281. /**
  282. * @brief Model path selector class for folder-based detection
  283. */
  284. class ModelPathSelector {
  285. public:
  286. /**
  287. * @brief Determine which path to use based on folder location
  288. *
  289. * @param modelPath The absolute path to the model file
  290. * @param checkpointsDir The checkpoints directory path
  291. * @param diffusionModelsDir The diffusion_models directory path
  292. * @return std::string "model_path" for checkpoints, "diffusion_model_path" for diffusion_models
  293. */
  294. static std::string selectPathType(
  295. const std::string& modelPath,
  296. const std::string& checkpointsDir,
  297. const std::string& diffusionModelsDir
  298. );
  299. /**
  300. * @brief Check if model is in a specific directory
  301. *
  302. * @param modelPath The absolute path to the model file
  303. * @param directory The directory to check
  304. * @return true if model is in the specified directory
  305. */
  306. static bool isModelInDirectory(const std::string& modelPath, const std::string& directory);
  307. };
  308. /**
  309. * @brief Model detection cache class for caching detection results
  310. */
  311. class ModelDetectionCache {
  312. public:
  313. /**
  314. * @brief Cache entry for model detection results
  315. */
  316. struct CacheEntry {
  317. std::string architecture;
  318. std::string recommendedVAE;
  319. int recommendedWidth = 0;
  320. int recommendedHeight = 0;
  321. int recommendedSteps = 0;
  322. std::string recommendedSampler;
  323. std::vector<std::string> requiredModels;
  324. std::vector<std::string> missingModels;
  325. std::string pathType;
  326. std::string detectionSource;
  327. std::filesystem::file_time_type cachedAt;
  328. std::filesystem::file_time_type fileModifiedAt;
  329. bool isValid = false;
  330. };
  331. /**
  332. * @brief Get cached detection result for a model
  333. *
  334. * @param modelPath The path to the model file
  335. * @param currentModifiedTime Current file modification time
  336. * @return CacheEntry Cached entry, or invalid if not found/expired
  337. */
  338. static CacheEntry getCachedResult(const std::string& modelPath,
  339. const std::filesystem::file_time_type& currentModifiedTime);
  340. /**
  341. * @brief Cache detection result for a model
  342. *
  343. * @param modelPath The path to the model file
  344. * @param detection The detection result to cache
  345. * @param pathType The path type used
  346. * @param detectionSource The source of detection
  347. * @param fileModifiedTime Current file modification time
  348. */
  349. static void cacheDetectionResult(
  350. const std::string& modelPath,
  351. const ModelDetectionResult& detection,
  352. const std::string& pathType,
  353. const std::string& detectionSource,
  354. const std::filesystem::file_time_type& fileModifiedTime
  355. );
  356. /**
  357. * @brief Invalidate cache for a model
  358. *
  359. * @param modelPath The path to the model file
  360. */
  361. static void invalidateCache(const std::string& modelPath);
  362. /**
  363. * @brief Clear all cached results
  364. */
  365. static void clearAllCache();
  366. private:
  367. static std::map<std::string, CacheEntry> cache_;
  368. static std::mutex cacheMutex_;
  369. };
  370. };
  371. #endif // MODEL_MANAGER_H