// model_manager.h
  1. #ifndef MODEL_MANAGER_H
  2. #define MODEL_MANAGER_H
#include <cstdint>
#include <filesystem>
#include <map>
#include <memory>
#include <mutex>
#include <shared_mutex>
#include <string>
#include <vector>
  10. // Forward declarations
  11. class StableDiffusionWrapper;
  12. class ModelPathSelector;
  13. class ModelDetectionCache;
  14. #include "model_detector.h"
  15. #include "server_config.h"
  16. /**
  17. * @brief Model type enumeration
  18. *
  19. * These values are bit flags that can be combined to filter model types.
  20. */
  21. enum class ModelType : uint32_t {
  22. NONE = 0,
  23. LORA = 1,
  24. CHECKPOINT = 2,
  25. VAE = 4,
  26. PRESETS = 8,
  27. PROMPTS = 16,
  28. NEG_PROMPTS = 32,
  29. TAESD = 64,
  30. ESRGAN = 128,
  31. CONTROLNET = 256,
  32. UPSCALER = 512,
  33. EMBEDDING = 1024,
  34. DIFFUSION_MODELS = 2048
  35. };
  36. // Enable bitwise operations for ModelType
  37. inline ModelType operator|(ModelType a, ModelType b) {
  38. return static_cast<ModelType>(static_cast<uint32_t>(a) | static_cast<uint32_t>(b));
  39. }
  40. inline ModelType operator&(ModelType a, ModelType b) {
  41. return static_cast<ModelType>(static_cast<uint32_t>(a) & static_cast<uint32_t>(b));
  42. }
  43. inline ModelType& operator|=(ModelType& a, ModelType b) {
  44. a = a | b;
  45. return a;
  46. }
  47. /**
  48. * @brief Model manager class for loading and managing stable-diffusion models
  49. *
  50. * This class handles loading, unloading, and managing multiple stable-diffusion models.
  51. * It provides thread-safe access to models and manages model resources efficiently.
  52. */
  53. class ModelManager {
  54. public:
  55. /**
  56. * @brief Model information structure
  57. */
  58. struct ModelInfo {
  59. std::string name; ///< Model name
  60. std::string path; ///< Model file path
  61. std::string fullPath; ///< Absolute path to the model file
  62. ModelType type; ///< Model type
  63. bool isLoaded; ///< Whether the model is currently loaded
  64. size_t fileSize; ///< File size in bytes
  65. std::string sha256; ///< SHA256 hash of the file
  66. std::filesystem::file_time_type createdAt; ///< File creation time
  67. std::filesystem::file_time_type modifiedAt; ///< Last modification time
  68. std::string description; ///< Model description
  69. std::map<std::string, std::string> metadata; ///< Additional metadata
  70. // Architecture detection fields
  71. std::string architecture; ///< Detected architecture (e.g., "Stable Diffusion XL Base", "Flux Dev")
  72. std::string recommendedVAE; ///< Recommended VAE for this model
  73. int recommendedWidth = 0; ///< Recommended image width
  74. int recommendedHeight = 0; ///< Recommended image height
  75. int recommendedSteps = 0; ///< Recommended number of steps
  76. std::string recommendedSampler; ///< Recommended sampler
  77. std::vector<std::string> requiredModels; ///< List of required auxiliary models (VAE, CLIP, etc.)
  78. std::vector<std::string> missingModels; ///< List of missing required models
  79. // Caching-related fields
  80. bool cacheValid = false; ///< Whether cached detection results are valid
  81. std::filesystem::file_time_type cacheModifiedAt; ///< Modification time when cache was created
  82. std::string cachePathType; ///< Path type used for this model ("model_path" or "diffusion_model_path")
  83. bool useFolderBasedDetection = false; ///< Whether folder-based detection was used
  84. std::string detectionSource; ///< Source of detection: "folder", "architecture", "fallback"
  85. };
  86. /**
  87. * @brief Model details structure for existence checking
  88. */
  89. struct ModelDetails {
  90. std::string name; ///< Model name
  91. bool exists; ///< Whether the model exists
  92. std::string type; ///< Model type ("VAE", "CLIP-L", "CLIP-G", "T5XXL", "CLIP-Vision", "Qwen2VL")
  93. std::string path; ///< Absolute path to the model file (empty if doesn't exist)
  94. size_t file_size; ///< File size in bytes (0 if doesn't exist)
  95. std::string sha256; ///< SHA256 hash (empty if doesn't exist)
  96. bool is_required; ///< True for required models
  97. bool is_recommended; ///< True for recommended models
  98. };
  99. /**
  100. * @brief Construct a new Model Manager object
  101. */
  102. ModelManager();
  103. /**
  104. * @brief Destroy the Model Manager object
  105. */
  106. virtual ~ModelManager();
  107. /**
  108. * @brief Scan the models directory to discover available models
  109. *
  110. * Recursively scans all subdirectories within the models directory to find
  111. * model files. For each model found, constructs the display name as
  112. * 'relative_path/model_name' where relative_path is the path from the models
  113. * root directory to the file's containing folder (using forward slashes).
  114. * Models in the root directory appear without a prefix.
  115. *
  116. * @return true if scanning was successful, false otherwise
  117. */
  118. bool scanModelsDirectory();
  119. /**
  120. * @brief Cancel any ongoing model directory scanning
  121. */
  122. void cancelScan();
  123. /**
  124. * @brief Load a model from the specified path
  125. *
  126. * @param name The name to assign to the model
  127. * @param path The file path to the model
  128. * @param type The type of model
  129. * @return true if the model was loaded successfully, false otherwise
  130. */
  131. bool loadModel(const std::string& name, const std::string& path, ModelType type);
  132. /**
  133. * @brief Load a model by name (must be discovered first)
  134. *
  135. * @param name The name of the model to load
  136. * @return true if the model was loaded successfully, false otherwise
  137. */
  138. bool loadModel(const std::string& name);
  139. /**
  140. * @brief Unload a model
  141. *
  142. * @param name The name of the model to unload
  143. * @return true if the model was unloaded successfully, false otherwise
  144. */
  145. bool unloadModel(const std::string& name);
  146. /**
  147. * @brief Get a pointer to a loaded model
  148. *
  149. * @param name The name of the model
  150. * @return StableDiffusionWrapper* Pointer to the model wrapper, or nullptr if not found
  151. */
  152. StableDiffusionWrapper* getModel(const std::string& name);
  153. /**
  154. * @brief Get information about all models
  155. *
  156. * @return std::map<std::string, ModelInfo> Map of model names to their information
  157. */
  158. std::map<std::string, ModelInfo> getAllModels() const;
  159. /**
  160. * @brief Get information about models of a specific type
  161. *
  162. * @param type The model type to filter by
  163. * @return std::vector<ModelInfo> List of model information
  164. */
  165. std::vector<ModelInfo> getModelsByType(ModelType type) const;
  166. /**
  167. * @brief Get information about a specific model
  168. *
  169. * @param name The name of the model
  170. * @return ModelInfo Model information, or empty if not found
  171. */
  172. ModelInfo getModelInfo(const std::string& name) const;
  173. /**
  174. * @brief Check if a model is loaded
  175. *
  176. * @param name The name of the model
  177. * @return true if the model is loaded, false otherwise
  178. */
  179. bool isModelLoaded(const std::string& name) const;
  180. /**
  181. * @brief Get the number of loaded models
  182. *
  183. * @return size_t Number of loaded models
  184. */
  185. size_t getLoadedModelsCount() const;
  186. /**
  187. * @brief Get the number of available models
  188. *
  189. * @return size_t Number of available models
  190. */
  191. size_t getAvailableModelsCount() const;
  192. /**
  193. * @brief Set the models directory path
  194. *
  195. * @param path The path to the models directory
  196. */
  197. void setModelsDirectory(const std::string& path);
  198. /**
  199. * @brief Get the models directory path
  200. *
  201. * @return std::string The models directory path
  202. */
  203. std::string getModelsDirectory() const;
  204. /**
  205. * @brief Set directory for a specific model type
  206. *
  207. * @param type The model type
  208. * @param path The directory path
  209. * @return true if the directory was set successfully, false otherwise
  210. */
  211. bool setModelTypeDirectory(ModelType type, const std::string& path);
  212. /**
  213. * @brief Get directory for a specific model type
  214. *
  215. * @param type The model type
  216. * @return std::string The directory path, empty if not set
  217. */
  218. std::string getModelTypeDirectory(ModelType type) const;
  219. /**
  220. * @brief Set all model type directories at once
  221. *
  222. * @param directories Map of model types to directory paths
  223. * @return true if all directories were set successfully, false otherwise
  224. */
  225. bool setAllModelTypeDirectories(const std::map<ModelType, std::string>& directories);
  226. /**
  227. * @brief Get all model type directories
  228. *
  229. * @return std::map<ModelType, std::string> Map of model types to directory paths
  230. */
  231. std::map<ModelType, std::string> getAllModelTypeDirectories() const;
  232. // Legacy methods removed - using explicit directory configuration only
  233. /**
  234. * @brief Configure ModelManager with ServerConfig
  235. *
  236. * @param config The server configuration
  237. * @return true if configuration was successful, false otherwise
  238. */
  239. bool configureFromServerConfig(const struct ServerConfig& config);
  240. /**
  241. * @brief Convert ModelType to string
  242. *
  243. * @param type The model type
  244. * @return std::string String representation of the model type
  245. */
  246. static std::string modelTypeToString(ModelType type);
  247. /**
  248. * @brief Convert string to ModelType
  249. *
  250. * @param typeStr String representation of the model type
  251. * @return ModelType The model type
  252. */
  253. static ModelType stringToModelType(const std::string& typeStr);
  254. /**
  255. * @brief Compute SHA256 hash of a model file
  256. *
  257. * @param modelName The name of the model
  258. * @return std::string The SHA256 hash, or empty string on error
  259. */
  260. std::string computeModelHash(const std::string& modelName);
  261. /**
  262. * @brief Load hash from JSON file for a model
  263. *
  264. * @param modelName The name of the model
  265. * @return std::string The loaded hash, or empty string if not found
  266. */
  267. std::string loadModelHashFromFile(const std::string& modelName);
  268. /**
  269. * @brief Save hash to JSON file for a model
  270. *
  271. * @param modelName The name of the model
  272. * @param hash The SHA256 hash to save
  273. * @return true if saved successfully, false otherwise
  274. */
  275. bool saveModelHashToFile(const std::string& modelName, const std::string& hash);
  276. /**
  277. * @brief Find model by hash (full or partial - minimum 10 chars)
  278. *
  279. * @param hash Full or partial SHA256 hash (minimum 10 characters)
  280. * @return std::string Model name, or empty string if not found
  281. */
  282. std::string findModelByHash(const std::string& hash);
  283. /**
  284. * @brief Load hash for a model (from file or compute if missing)
  285. *
  286. * @param modelName The name of the model
  287. * @param forceCompute Force recomputation even if hash file exists
  288. * @return std::string The SHA256 hash, or empty string on error
  289. */
  290. std::string ensureModelHash(const std::string& modelName, bool forceCompute = false);
  291. /**
  292. * @brief Check if required models exist in the appropriate directories
  293. *
  294. * @param requiredModels List of required model names with types (e.g., "VAE: model.safetensors")
  295. * @return std::vector<ModelDetails> List of model details with existence information
  296. */
  297. std::vector<ModelDetails> checkRequiredModelsExistence(const std::vector<std::string>& requiredModels);
  298. private:
  299. class Impl;
  300. std::unique_ptr<Impl> pImpl; // Pimpl idiom
  301. /**
  302. * @brief Model path selector class for folder-based detection
  303. */
  304. class ModelPathSelector {
  305. public:
  306. /**
  307. * @brief Determine which path to use based on folder location
  308. *
  309. * @param modelPath The absolute path to the model file
  310. * @param checkpointsDir The checkpoints directory path
  311. * @param diffusionModelsDir The diffusion_models directory path
  312. * @return std::string "model_path" for checkpoints, "diffusion_model_path" for diffusion_models
  313. */
  314. static std::string selectPathType(
  315. const std::string& modelPath,
  316. const std::string& checkpointsDir,
  317. const std::string& diffusionModelsDir
  318. );
  319. /**
  320. * @brief Check if model is in a specific directory
  321. *
  322. * @param modelPath The absolute path to the model file
  323. * @param directory The directory to check
  324. * @return true if model is in the specified directory
  325. */
  326. static bool isModelInDirectory(const std::string& modelPath, const std::string& directory);
  327. };
  328. /**
  329. * @brief Model detection cache class for caching detection results
  330. */
  331. class ModelDetectionCache {
  332. public:
  333. /**
  334. * @brief Cache entry for model detection results
  335. */
  336. struct CacheEntry {
  337. std::string architecture;
  338. std::string recommendedVAE;
  339. int recommendedWidth = 0;
  340. int recommendedHeight = 0;
  341. int recommendedSteps = 0;
  342. std::string recommendedSampler;
  343. std::vector<std::string> requiredModels;
  344. std::vector<std::string> missingModels;
  345. std::string pathType;
  346. std::string detectionSource;
  347. std::filesystem::file_time_type cachedAt;
  348. std::filesystem::file_time_type fileModifiedAt;
  349. bool isValid = false;
  350. };
  351. /**
  352. * @brief Get cached detection result for a model
  353. *
  354. * @param modelPath The path to the model file
  355. * @param currentModifiedTime Current file modification time
  356. * @return CacheEntry Cached entry, or invalid if not found/expired
  357. */
  358. static CacheEntry getCachedResult(const std::string& modelPath,
  359. const std::filesystem::file_time_type& currentModifiedTime);
  360. /**
  361. * @brief Cache detection result for a model
  362. *
  363. * @param modelPath The path to the model file
  364. * @param detection The detection result to cache
  365. * @param pathType The path type used
  366. * @param detectionSource The source of detection
  367. * @param fileModifiedTime Current file modification time
  368. */
  369. static void cacheDetectionResult(
  370. const std::string& modelPath,
  371. const ModelDetectionResult& detection,
  372. const std::string& pathType,
  373. const std::string& detectionSource,
  374. const std::filesystem::file_time_type& fileModifiedTime
  375. );
  376. /**
  377. * @brief Invalidate cache for a model
  378. *
  379. * @param modelPath The path to the model file
  380. */
  381. static void invalidateCache(const std::string& modelPath);
  382. /**
  383. * @brief Clear all cached results
  384. */
  385. static void clearAllCache();
  386. private:
  387. static std::map<std::string, CacheEntry> cache_;
  388. static std::mutex cacheMutex_;
  389. };
  390. };
  391. #endif // MODEL_MANAGER_H