// model_manager.h
  1. #ifndef MODEL_MANAGER_H
  2. #define MODEL_MANAGER_H
  #include <cstddef>
  #include <cstdint>
  #include <filesystem>
  #include <functional>
  #include <map>
  #include <memory>
  #include <mutex>
  #include <shared_mutex>
  #include <string>
  #include <vector>

  // Forward declarations
  // NOTE(review): ModelPathSelector and ModelDetectionCache are declared here at
  // namespace scope, but ModelManager also declares nested classes with the same
  // names. The two sets are distinct types - confirm the namespace-scope
  // declarations are still needed.
  class StableDiffusionWrapper;
  class ModelPathSelector;
  class ModelDetectionCache;

  #include "model_detector.h"
  #include "server_config.h"
  /**
   * @brief Model type enumeration
   *
   * Every enumerator other than NONE occupies its own bit, so values can be
   * combined with the bitwise operators declared below to filter model types.
   */
  enum class ModelType : uint32_t {
      NONE             = 0,
      LORA             = 1,
      CHECKPOINT       = 2,
      VAE              = 4,
      PRESETS          = 8,
      PROMPTS          = 16,
      NEG_PROMPTS      = 32,
      TAESD            = 64,
      ESRGAN           = 128,
      CONTROLNET       = 256,
      UPSCALER         = 512,
      EMBEDDING        = 1024,
      DIFFUSION_MODELS = 2048
  };

  // Bitwise operations for ModelType.
  // They are constexpr so flag combinations can be formed in constant
  // expressions, and noexcept because integer bit operations cannot throw.

  /**
   * @brief Combine two model type flag sets
   *
   * @param a First flag set
   * @param b Second flag set
   * @return ModelType Flags present in either operand
   */
  constexpr ModelType operator|(ModelType a, ModelType b) noexcept {
      return static_cast<ModelType>(static_cast<uint32_t>(a) | static_cast<uint32_t>(b));
  }

  /**
   * @brief Intersect two model type flag sets
   *
   * @param a First flag set
   * @param b Second flag set
   * @return ModelType Flags present in both operands (ModelType::NONE if disjoint)
   */
  constexpr ModelType operator&(ModelType a, ModelType b) noexcept {
      return static_cast<ModelType>(static_cast<uint32_t>(a) & static_cast<uint32_t>(b));
  }

  /**
   * @brief Add the flags of b to a in place
   *
   * @param a Flag set to modify
   * @param b Flags to add
   * @return ModelType& Reference to a after the update
   */
  constexpr ModelType& operator|=(ModelType& a, ModelType b) noexcept {
      a = a | b;
      return a;
  }
  48. /**
  49. * @brief Model manager class for loading and managing stable-diffusion models
  50. *
  51. * This class handles loading, unloading, and managing multiple stable-diffusion models.
  52. * It provides thread-safe access to models and manages model resources efficiently.
  53. */
  54. class ModelManager {
  55. public:
  56. /**
  57. * @brief Model information structure
  58. */
  59. struct ModelInfo {
  60. std::string name; ///< Model name
  61. std::string path; ///< Model file path
  62. std::string fullPath; ///< Absolute path to the model file
  63. ModelType type; ///< Model type
  64. bool isLoaded; ///< Whether the model is currently loaded
  65. size_t fileSize; ///< File size in bytes
  66. std::string sha256; ///< SHA256 hash of the file
  67. std::filesystem::file_time_type createdAt; ///< File creation time
  68. std::filesystem::file_time_type modifiedAt; ///< Last modification time
  69. std::string description; ///< Model description
  70. std::map<std::string, std::string> metadata; ///< Additional metadata
  71. // Architecture detection fields
  72. std::string architecture; ///< Detected architecture (e.g., "Stable Diffusion XL Base", "Flux Dev")
  73. std::string recommendedVAE; ///< Recommended VAE for this model
  74. int recommendedWidth = 0; ///< Recommended image width
  75. int recommendedHeight = 0; ///< Recommended image height
  76. int recommendedSteps = 0; ///< Recommended number of steps
  77. std::string recommendedSampler; ///< Recommended sampler
  78. std::vector<std::string> requiredModels; ///< List of required auxiliary models (VAE, CLIP, etc.)
  79. std::vector<std::string> missingModels; ///< List of missing required models
  80. // Caching-related fields
  81. bool cacheValid = false; ///< Whether cached detection results are valid
  82. std::filesystem::file_time_type cacheModifiedAt; ///< Modification time when cache was created
  83. std::string cachePathType; ///< Path type used for this model ("model_path" or "diffusion_model_path")
  84. bool useFolderBasedDetection = false; ///< Whether folder-based detection was used
  85. std::string detectionSource; ///< Source of detection: "folder", "architecture", "fallback"
  86. };
  87. /**
  88. * @brief Model details structure for existence checking
  89. */
  90. struct ModelDetails {
  91. std::string name; ///< Model name
  92. bool exists; ///< Whether the model exists
  93. std::string type; ///< Model type ("VAE", "CLIP-L", "CLIP-G", "T5XXL", "CLIP-Vision", "Qwen2VL")
  94. std::string path; ///< Absolute path to the model file (empty if doesn't exist)
  95. size_t file_size; ///< File size in bytes (0 if doesn't exist)
  96. std::string sha256; ///< SHA256 hash (empty if doesn't exist)
  97. bool is_required; ///< True for required models
  98. bool is_recommended; ///< True for recommended models
  99. };
  100. /**
  101. * @brief Construct a new Model Manager object
  102. */
  103. ModelManager();
  104. /**
  105. * @brief Destroy the Model Manager object
  106. */
  107. virtual ~ModelManager();
  108. /**
  109. * @brief Scan the models directory to discover available models
  110. *
  111. * Recursively scans all subdirectories within the models directory to find
  112. * model files. For each model found, constructs the display name as
  113. * 'relative_path/model_name' where relative_path is the path from the models
  114. * root directory to the file's containing folder (using forward slashes).
  115. * Models in the root directory appear without a prefix.
  116. *
  117. * @return true if scanning was successful, false otherwise
  118. */
  119. bool scanModelsDirectory();
  120. /**
  121. * @brief Cancel any ongoing model directory scanning
  122. */
  123. void cancelScan();
  124. /**
  125. * @brief Load a model from the specified path
  126. *
  127. * @param name The name to assign to the model
  128. * @param path The file path to the model
  129. * @param type The type of model
  130. * @return true if the model was loaded successfully, false otherwise
  131. */
  132. bool loadModel(const std::string& name, const std::string& path, ModelType type);
  133. /**
  134. * @brief Load a model by name (must be discovered first)
  135. *
  136. * @param name The name of the model to load
  137. * @return true if the model was loaded successfully, false otherwise
  138. */
  139. bool loadModel(const std::string& name);
  140. /**
  141. * @brief Load a model by name with progress callback
  142. *
  143. * @param name The name of the model to load
  144. * @param progressCallback Callback function for progress updates (0.0-1.0)
  145. * @return true if the model was loaded successfully, false otherwise
  146. */
  147. bool loadModel(const std::string& name, std::function<void(float)> progressCallback);
  148. /**
  149. * @brief Unload a model
  150. *
  151. * @param name The name of the model to unload
  152. * @return true if the model was unloaded successfully, false otherwise
  153. */
  154. bool unloadModel(const std::string& name);
  155. /**
  156. * @brief Unload all currently loaded models
  157. *
  158. * This method unloads all models that are currently loaded, ensuring that
  159. * all contexts and parameters are properly freed. This is useful during
  160. * graceful shutdown to prevent memory leaks.
  161. */
  162. void unloadAllModels();
  163. /**
  164. * @brief Get a pointer to a loaded model
  165. *
  166. * @param name The name of the model
  167. * @return StableDiffusionWrapper* Pointer to the model wrapper, or nullptr if not found
  168. */
  169. StableDiffusionWrapper* getModel(const std::string& name);
  170. /**
  171. * @brief Get information about all models
  172. *
  173. * @return std::map<std::string, ModelInfo> Map of model names to their information
  174. */
  175. std::map<std::string, ModelInfo> getAllModels() const;
  176. /**
  177. * @brief Get information about models of a specific type
  178. *
  179. * @param type The model type to filter by
  180. * @return std::vector<ModelInfo> List of model information
  181. */
  182. std::vector<ModelInfo> getModelsByType(ModelType type) const;
  183. /**
  184. * @brief Get information about a specific model
  185. *
  186. * @param name The name of the model
  187. * @return ModelInfo Model information, or empty if not found
  188. */
  189. ModelInfo getModelInfo(const std::string& name) const;
  190. /**
  191. * @brief Check if a model is loaded
  192. *
  193. * @param name The name of the model
  194. * @return true if the model is loaded, false otherwise
  195. */
  196. bool isModelLoaded(const std::string& name) const;
  197. /**
  198. * @brief Get the number of loaded models
  199. *
  200. * @return size_t Number of loaded models
  201. */
  202. size_t getLoadedModelsCount() const;
  203. /**
  204. * @brief Get the number of available models
  205. *
  206. * @return size_t Number of available models
  207. */
  208. size_t getAvailableModelsCount() const;
  209. /**
  210. * @brief Set the models directory path
  211. *
  212. * @param path The path to the models directory
  213. */
  214. void setModelsDirectory(const std::string& path);
  215. /**
  216. * @brief Get the models directory path
  217. *
  218. * @return std::string The models directory path
  219. */
  220. std::string getModelsDirectory() const;
  221. /**
  222. * @brief Set directory for a specific model type
  223. *
  224. * @param type The model type
  225. * @param path The directory path
  226. * @return true if the directory was set successfully, false otherwise
  227. */
  228. bool setModelTypeDirectory(ModelType type, const std::string& path);
  229. /**
  230. * @brief Get directory for a specific model type
  231. *
  232. * @param type The model type
  233. * @return std::string The directory path, empty if not set
  234. */
  235. std::string getModelTypeDirectory(ModelType type) const;
  236. /**
  237. * @brief Set all model type directories at once
  238. *
  239. * @param directories Map of model types to directory paths
  240. * @return true if all directories were set successfully, false otherwise
  241. */
  242. bool setAllModelTypeDirectories(const std::map<ModelType, std::string>& directories);
  243. /**
  244. * @brief Get all model type directories
  245. *
  246. * @return std::map<ModelType, std::string> Map of model types to directory paths
  247. */
  248. std::map<ModelType, std::string> getAllModelTypeDirectories() const;
  249. // Legacy methods removed - using explicit directory configuration only
  250. /**
  251. * @brief Configure ModelManager with ServerConfig
  252. *
  253. * @param config The server configuration
  254. * @return true if configuration was successful, false otherwise
  255. */
  256. bool configureFromServerConfig(const struct ServerConfig& config);
  257. /**
  258. * @brief Convert ModelType to string
  259. *
  260. * @param type The model type
  261. * @return std::string String representation of the model type
  262. */
  263. static std::string modelTypeToString(ModelType type);
  264. /**
  265. * @brief Convert string to ModelType
  266. *
  267. * @param typeStr String representation of the model type
  268. * @return ModelType The model type
  269. */
  270. static ModelType stringToModelType(const std::string& typeStr);
  271. /**
  272. * @brief Compute SHA256 hash of a model file
  273. *
  274. * @param modelName The name of the model
  275. * @return std::string The SHA256 hash, or empty string on error
  276. */
  277. std::string computeModelHash(const std::string& modelName);
  278. /**
  279. * @brief Load hash from JSON file for a model
  280. *
  281. * @param modelName The name of the model
  282. * @return std::string The loaded hash, or empty string if not found
  283. */
  284. std::string loadModelHashFromFile(const std::string& modelName);
  285. /**
  286. * @brief Save hash to JSON file for a model
  287. *
  288. * @param modelName The name of the model
  289. * @param hash The SHA256 hash to save
  290. * @return true if saved successfully, false otherwise
  291. */
  292. bool saveModelHashToFile(const std::string& modelName, const std::string& hash);
  293. /**
  294. * @brief Find model by hash (full or partial - minimum 10 chars)
  295. *
  296. * @param hash Full or partial SHA256 hash (minimum 10 characters)
  297. * @return std::string Model name, or empty string if not found
  298. */
  299. std::string findModelByHash(const std::string& hash);
  300. /**
  301. * @brief Load hash for a model (from file or compute if missing)
  302. *
  303. * @param modelName The name of the model
  304. * @param forceCompute Force recomputation even if hash file exists
  305. * @return std::string The SHA256 hash, or empty string on error
  306. */
  307. std::string ensureModelHash(const std::string& modelName, bool forceCompute = false);
  308. /**
  309. * @brief Check if required models exist in the appropriate directories
  310. *
  311. * @param requiredModels List of required model names with types (e.g., "VAE: model.safetensors")
  312. * @return std::vector<ModelDetails> List of model details with existence information
  313. */
  314. std::vector<ModelDetails> checkRequiredModelsExistence(const std::vector<std::string>& requiredModels);
  315. /**
  316. * @brief Hash all models in the models directory with progress reporting
  317. *
  318. * @param progressCallback Callback function for progress updates (0.0-1.0)
  319. * @return true if all models were hashed successfully, false otherwise
  320. */
  321. bool hashAllModels(std::function<void(int, int, const std::string&)> progressCallback = nullptr);
  322. private:
  323. class Impl;
  324. std::unique_ptr<Impl> pImpl; // Pimpl idiom
  325. /**
  326. * @brief Model path selector class for folder-based detection
  327. */
  328. class ModelPathSelector {
  329. public:
  330. /**
  331. * @brief Determine which path to use based on folder location
  332. *
  333. * @param modelPath The absolute path to the model file
  334. * @param checkpointsDir The checkpoints directory path
  335. * @param diffusionModelsDir The diffusion_models directory path
  336. * @return std::string "model_path" for checkpoints, "diffusion_model_path" for diffusion_models
  337. */
  338. static std::string selectPathType(
  339. const std::string& modelPath,
  340. const std::string& checkpointsDir,
  341. const std::string& diffusionModelsDir,
  342. bool verbose = false
  343. );
  344. /**
  345. * @brief Check if model is in a specific directory
  346. *
  347. * @param modelPath The absolute path to the model file
  348. * @param directory The directory to check
  349. * @return true if model is in the specified directory
  350. */
  351. static bool isModelInDirectory(const std::string& modelPath, const std::string& directory);
  352. };
  353. /**
  354. * @brief Model detection cache class for caching detection results
  355. */
  356. class ModelDetectionCache {
  357. public:
  358. /**
  359. * @brief Cache entry for model detection results
  360. */
  361. struct CacheEntry {
  362. std::string architecture;
  363. std::string recommendedVAE;
  364. int recommendedWidth = 0;
  365. int recommendedHeight = 0;
  366. int recommendedSteps = 0;
  367. std::string recommendedSampler;
  368. std::vector<std::string> requiredModels;
  369. std::vector<std::string> missingModels;
  370. std::string pathType;
  371. std::string detectionSource;
  372. std::filesystem::file_time_type cachedAt;
  373. std::filesystem::file_time_type fileModifiedAt;
  374. bool isValid = false;
  375. };
  376. /**
  377. * @brief Get cached detection result for a model
  378. *
  379. * @param modelPath The path to the model file
  380. * @param currentModifiedTime Current file modification time
  381. * @return CacheEntry Cached entry, or invalid if not found/expired
  382. */
  383. static CacheEntry getCachedResult(const std::string& modelPath,
  384. const std::filesystem::file_time_type& currentModifiedTime);
  385. /**
  386. * @brief Cache detection result for a model
  387. *
  388. * @param modelPath The path to the model file
  389. * @param detection The detection result to cache
  390. * @param pathType The path type used
  391. * @param detectionSource The source of detection
  392. * @param fileModifiedTime Current file modification time
  393. */
  394. static void cacheDetectionResult(
  395. const std::string& modelPath,
  396. const ModelDetectionResult& detection,
  397. const std::string& pathType,
  398. const std::string& detectionSource,
  399. const std::filesystem::file_time_type& fileModifiedTime
  400. );
  401. /**
  402. * @brief Invalidate cache for a model
  403. *
  404. * @param modelPath The path to the model file
  405. */
  406. static void invalidateCache(const std::string& modelPath);
  407. /**
  408. * @brief Clear all cached results
  409. */
  410. static void clearAllCache();
  411. private:
  412. static std::map<std::string, CacheEntry> cache_;
  413. static std::mutex cacheMutex_;
  414. };
  415. };
  416. #endif // MODEL_MANAGER_H