model_manager.cpp 39 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065
  #include "model_manager.h"
  #include "model_detector.h"
  #include "stable_diffusion_wrapper.h"
  #include <algorithm>
  #include <atomic>
  #include <cctype>
  #include <chrono>
  #include <filesystem>
  #include <fstream>
  #include <future>
  #include <iomanip>
  #include <iostream>
  #include <map>
  #include <memory>
  #include <shared_mutex>
  #include <sstream>
  #include <string>
  #include <system_error>
  #include <thread>
  #include <utility>
  #include <vector>
  #include <nlohmann/json.hpp>
  #include <openssl/evp.h>
  namespace fs = std::filesystem;

  // Accepted file extensions (lowercase, without the leading dot) per model type.
  // Impl::getFileExtensions() maps each ModelType onto one of these tables.
  const std::vector<std::string> CHECKPOINT_FILE_EXTENSIONS{"safetensors", "ckpt", "gguf"};
  const std::vector<std::string> EMBEDDING_FILE_EXTENSIONS{"safetensors", "pt"};
  const std::vector<std::string> LORA_FILE_EXTENSIONS{"safetensors", "ckpt"};
  const std::vector<std::string> VAE_FILE_EXTENSIONS{"safetensors", "pt", "ckpt", "gguf"};
  const std::vector<std::string> TAESD_FILE_EXTENSIONS{"safetensors", "pth", "gguf"};
  const std::vector<std::string> ESRGAN_FILE_EXTENSIONS{"pth", "pt"};
  const std::vector<std::string> CONTROLNET_FILE_EXTENSIONS{"safetensors", "pth"};
  25. class ModelManager::Impl {
  26. public:
  27. std::string modelsDirectory = "./models";
  28. std::map<ModelType, std::string> modelTypeDirectories;
  29. std::map<std::string, ModelInfo> availableModels;
  30. std::map<std::string, std::unique_ptr<StableDiffusionWrapper>> loadedModels;
  31. mutable std::shared_mutex modelsMutex;
  32. std::atomic<bool> scanCancelled{false};
  33. bool legacyMode = true;
  34. /**
  35. * @brief Validate a directory path
  36. *
  37. * @param path The directory path to validate
  38. * @return true if the directory exists and is valid, false otherwise
  39. */
  40. bool validateDirectory(const std::string& path) const {
  41. if (path.empty()) {
  42. return false;
  43. }
  44. std::filesystem::path dirPath(path);
  45. if (!std::filesystem::exists(dirPath)) {
  46. std::cerr << "Directory does not exist: " << path << std::endl;
  47. return false;
  48. }
  49. if (!std::filesystem::is_directory(dirPath)) {
  50. std::cerr << "Path is not a directory: " << path << std::endl;
  51. return false;
  52. }
  53. return true;
  54. }
  55. /**
  56. * @brief Get default directory name for a model type
  57. *
  58. * @param type The model type
  59. * @return std::string Default directory name
  60. */
  61. std::string getDefaultDirectoryName(ModelType type) const {
  62. switch (type) {
  63. case ModelType::CHECKPOINT:
  64. return "checkpoints";
  65. case ModelType::CONTROLNET:
  66. return "controlnet";
  67. case ModelType::EMBEDDING:
  68. return "embeddings";
  69. case ModelType::ESRGAN:
  70. return "esrgan";
  71. case ModelType::LORA:
  72. return "lora";
  73. case ModelType::TAESD:
  74. return "taesd";
  75. case ModelType::VAE:
  76. return "vae";
  77. default:
  78. return "";
  79. }
  80. }
  81. /**
  82. * @brief Get directory path for a model type
  83. *
  84. * @param type The model type
  85. * @return std::string Directory path, empty if not set
  86. */
  87. std::string getModelTypeDirectory(ModelType type) const {
  88. auto it = modelTypeDirectories.find(type);
  89. if (it != modelTypeDirectories.end()) {
  90. return it->second;
  91. }
  92. // If in legacy mode, construct default path
  93. if (legacyMode) {
  94. std::string defaultDir = getDefaultDirectoryName(type);
  95. if (!defaultDir.empty()) {
  96. return modelsDirectory + "/" + defaultDir;
  97. }
  98. }
  99. return "";
  100. }
  101. /**
  102. * @brief Get file extensions for a specific model type
  103. *
  104. * @param type The model type
  105. * @return const std::vector<std::string>& Vector of file extensions
  106. */
  107. const std::vector<std::string>& getFileExtensions(ModelType type) const {
  108. switch (type) {
  109. case ModelType::CHECKPOINT:
  110. return CHECKPOINT_FILE_EXTENSIONS;
  111. case ModelType::EMBEDDING:
  112. return EMBEDDING_FILE_EXTENSIONS;
  113. case ModelType::LORA:
  114. return LORA_FILE_EXTENSIONS;
  115. case ModelType::VAE:
  116. return VAE_FILE_EXTENSIONS;
  117. case ModelType::TAESD:
  118. return TAESD_FILE_EXTENSIONS;
  119. case ModelType::ESRGAN:
  120. return ESRGAN_FILE_EXTENSIONS;
  121. case ModelType::CONTROLNET:
  122. return CONTROLNET_FILE_EXTENSIONS;
  123. default:
  124. static const std::vector<std::string> empty;
  125. return empty;
  126. }
  127. }
  128. /**
  129. * @brief Check if a file extension matches a model type
  130. *
  131. * @param extension The file extension
  132. * @param type The model type
  133. * @return true if the extension matches the model type
  134. */
  135. bool isExtensionMatch(const std::string& extension, ModelType type) const {
  136. const auto& extensions = getFileExtensions(type);
  137. return std::find(extensions.begin(), extensions.end(), extension) != extensions.end();
  138. }
  139. /**
  140. * @brief Determine model type based on file path and extension
  141. *
  142. * @param filePath The file path
  143. * @return ModelType The determined model type
  144. */
  145. ModelType determineModelType(const fs::path& filePath) const {
  146. std::string extension = filePath.extension().string();
  147. if (extension.empty()) {
  148. return ModelType::NONE;
  149. }
  150. // Remove the dot from extension
  151. if (extension[0] == '.') {
  152. extension = extension.substr(1);
  153. }
  154. // Convert to lowercase for comparison
  155. std::transform(extension.begin(), extension.end(), extension.begin(), ::tolower);
  156. // Check if the file resides under a directory registered for a given ModelType
  157. fs::path absoluteFilePath = fs::absolute(filePath);
  158. // First check configured directories (if any)
  159. for (const auto& [type, directory] : modelTypeDirectories) {
  160. if (!directory.empty()) {
  161. fs::path absoluteDirPath = fs::absolute(directory).lexically_normal();
  162. fs::path normalizedFilePath = absoluteFilePath.lexically_normal();
  163. // Check if the file is under this directory (directly or in subdirectories)
  164. // Get the relative path from directory to file
  165. auto relativePath = normalizedFilePath.lexically_relative(absoluteDirPath);
  166. // If relative path doesn't start with "..", then file is under the directory
  167. std::string relPathStr = relativePath.string();
  168. bool isUnderDirectory = !relPathStr.empty() &&
  169. relPathStr.substr(0, 2) != ".." &&
  170. relPathStr[0] != '/';
  171. if (isUnderDirectory && isExtensionMatch(extension, type)) {
  172. return type;
  173. }
  174. }
  175. }
  176. // If in legacy mode or no configured directories matched, check default directory structure
  177. if (legacyMode || modelTypeDirectories.empty()) {
  178. std::string parentPath = filePath.parent_path().filename().string();
  179. std::transform(parentPath.begin(), parentPath.end(), parentPath.begin(), ::tolower);
  180. // Check default directory names
  181. if (parentPath == "checkpoints" || parentPath == "stable-diffusion") {
  182. if (isExtensionMatch(extension, ModelType::CHECKPOINT)) {
  183. return ModelType::CHECKPOINT;
  184. }
  185. } else if (parentPath == "controlnet") {
  186. if (isExtensionMatch(extension, ModelType::CONTROLNET)) {
  187. return ModelType::CONTROLNET;
  188. }
  189. } else if (parentPath == "lora") {
  190. if (isExtensionMatch(extension, ModelType::LORA)) {
  191. return ModelType::LORA;
  192. }
  193. } else if (parentPath == "vae") {
  194. if (isExtensionMatch(extension, ModelType::VAE)) {
  195. return ModelType::VAE;
  196. }
  197. } else if (parentPath == "taesd") {
  198. if (isExtensionMatch(extension, ModelType::TAESD)) {
  199. return ModelType::TAESD;
  200. }
  201. } else if (parentPath == "esrgan" || parentPath == "upscaler") {
  202. if (isExtensionMatch(extension, ModelType::ESRGAN)) {
  203. return ModelType::ESRGAN;
  204. }
  205. } else if (parentPath == "embeddings" || parentPath == "textual-inversion") {
  206. if (isExtensionMatch(extension, ModelType::EMBEDDING)) {
  207. return ModelType::EMBEDDING;
  208. }
  209. }
  210. }
  211. // Fall back to extension-based detection
  212. // Only return a model type if the extension matches expected extensions for that type
  213. if (isExtensionMatch(extension, ModelType::CHECKPOINT)) {
  214. return ModelType::CHECKPOINT;
  215. } else if (isExtensionMatch(extension, ModelType::LORA)) {
  216. return ModelType::LORA;
  217. } else if (isExtensionMatch(extension, ModelType::VAE)) {
  218. return ModelType::VAE;
  219. } else if (isExtensionMatch(extension, ModelType::TAESD)) {
  220. return ModelType::TAESD;
  221. } else if (isExtensionMatch(extension, ModelType::ESRGAN)) {
  222. return ModelType::ESRGAN;
  223. } else if (isExtensionMatch(extension, ModelType::CONTROLNET)) {
  224. return ModelType::CONTROLNET;
  225. } else if (isExtensionMatch(extension, ModelType::EMBEDDING)) {
  226. return ModelType::EMBEDDING;
  227. }
  228. return ModelType::NONE;
  229. }
  230. /**
  231. * @brief Get file information with timeout
  232. *
  233. * @param filePath The file path to get info for
  234. * @param timeoutMs Timeout in milliseconds
  235. * @return std::pair<bool, std::pair<uintmax_t, fs::file_time_type>> Success flag and file info
  236. */
  237. std::pair<bool, std::pair<uintmax_t, fs::file_time_type>> getFileInfoWithTimeout(
  238. const fs::path& filePath, int timeoutMs = 5000) {
  239. auto future = std::async(std::launch::async, [&filePath]() -> std::pair<uintmax_t, fs::file_time_type> {
  240. try {
  241. uintmax_t fileSize = fs::file_size(filePath);
  242. fs::file_time_type modifiedAt = fs::last_write_time(filePath);
  243. return {fileSize, modifiedAt};
  244. } catch (const fs::filesystem_error&) {
  245. return {0, fs::file_time_type{}};
  246. }
  247. });
  248. if (future.wait_for(std::chrono::milliseconds(timeoutMs)) == std::future_status::timeout) {
  249. std::cerr << "Timeout getting file info for " << filePath << std::endl;
  250. return {false, {0, fs::file_time_type{}}};
  251. }
  252. return {true, future.get()};
  253. }
  254. /**
  255. * @brief Scan a directory for models of a specific type (without holding mutex)
  256. *
  257. * @param directory The directory to scan
  258. * @param type The model type to look for
  259. * @param modelsMap Reference to the map to store results
  260. * @return bool True if scanning completed without cancellation
  261. */
  262. bool scanDirectory(const fs::path& directory, ModelType type, std::map<std::string, ModelInfo>& modelsMap) {
  263. if (scanCancelled.load()) {
  264. return false;
  265. }
  266. if (!fs::exists(directory) || !fs::is_directory(directory)) {
  267. return true;
  268. }
  269. try {
  270. for (const auto& entry : fs::recursive_directory_iterator(directory)) {
  271. if (scanCancelled.load()) {
  272. return false;
  273. }
  274. if (entry.is_regular_file()) {
  275. fs::path filePath = entry.path();
  276. ModelType detectedType = determineModelType(filePath);
  277. // Only add files that have a valid model type (not NONE)
  278. if (detectedType != ModelType::NONE && (type == ModelType::NONE || detectedType == type)) {
  279. ModelInfo info;
  280. // Calculate relative path from the scanned directory (not base models directory)
  281. fs::path relativePath = fs::relative(filePath, directory);
  282. std::string modelName = relativePath.string();
  283. // Check if model already exists to avoid duplicates
  284. if (modelsMap.find(modelName) == modelsMap.end()) {
  285. info.name = modelName;
  286. info.path = filePath.string();
  287. info.fullPath = fs::absolute(filePath).string();
  288. info.type = detectedType;
  289. info.isLoaded = false;
  290. info.description = ""; // Initialize description
  291. info.metadata = {}; // Initialize metadata
  292. // Get file info with timeout
  293. auto [success, fileInfo] = getFileInfoWithTimeout(filePath);
  294. if (success) {
  295. info.fileSize = fileInfo.first;
  296. info.modifiedAt = fileInfo.second;
  297. info.createdAt = fileInfo.second; // Use modified time as creation time for now
  298. } else {
  299. info.fileSize = 0;
  300. info.modifiedAt = fs::file_time_type{};
  301. info.createdAt = fs::file_time_type{};
  302. }
  303. // Try to load cached hash from .json file
  304. std::string hashFile = info.fullPath + ".json";
  305. if (fs::exists(hashFile)) {
  306. try {
  307. std::ifstream file(hashFile);
  308. nlohmann::json hashData = nlohmann::json::parse(file);
  309. if (hashData.contains("sha256") && hashData["sha256"].is_string()) {
  310. info.sha256 = hashData["sha256"];
  311. } else {
  312. info.sha256 = "";
  313. }
  314. } catch (...) {
  315. info.sha256 = ""; // If parsing fails, leave empty
  316. }
  317. } else {
  318. info.sha256 = ""; // No cached hash file
  319. }
  320. // Detect architecture for checkpoint models
  321. if (detectedType == ModelType::CHECKPOINT) {
  322. try {
  323. ModelDetectionResult detection = ModelDetector::detectModel(info.fullPath);
  324. // For .ckpt files that can't be detected, default to SD1.5
  325. if (detection.architecture == ModelArchitecture::UNKNOWN &&
  326. (filePath.extension() == ".ckpt" || filePath.extension() == ".pt")) {
  327. info.architecture = "Stable Diffusion 1.5 (assumed)";
  328. info.recommendedVAE = "vae-ft-mse-840000-ema-pruned.safetensors";
  329. info.recommendedWidth = 512;
  330. info.recommendedHeight = 512;
  331. info.recommendedSteps = 20;
  332. info.recommendedSampler = "euler_a";
  333. } else {
  334. info.architecture = detection.architectureName;
  335. info.recommendedVAE = detection.recommendedVAE;
  336. // Parse recommended parameters
  337. if (detection.suggestedParams.count("width")) {
  338. info.recommendedWidth = std::stoi(detection.suggestedParams["width"]);
  339. }
  340. if (detection.suggestedParams.count("height")) {
  341. info.recommendedHeight = std::stoi(detection.suggestedParams["height"]);
  342. }
  343. if (detection.suggestedParams.count("steps")) {
  344. info.recommendedSteps = std::stoi(detection.suggestedParams["steps"]);
  345. }
  346. if (detection.suggestedParams.count("sampler")) {
  347. info.recommendedSampler = detection.suggestedParams["sampler"];
  348. }
  349. }
  350. // Build list of required models based on architecture
  351. if (detection.needsVAE && !detection.recommendedVAE.empty()) {
  352. info.requiredModels.push_back("VAE: " + detection.recommendedVAE);
  353. }
  354. // Add CLIP-L if required
  355. if (detection.suggestedParams.count("clip_l_required")) {
  356. info.requiredModels.push_back("CLIP-L: " + detection.suggestedParams.at("clip_l_required"));
  357. }
  358. // Add CLIP-G if required
  359. if (detection.suggestedParams.count("clip_g_required")) {
  360. info.requiredModels.push_back("CLIP-G: " + detection.suggestedParams.at("clip_g_required"));
  361. }
  362. // Add T5XXL if required
  363. if (detection.suggestedParams.count("t5xxl_required")) {
  364. info.requiredModels.push_back("T5XXL: " + detection.suggestedParams.at("t5xxl_required"));
  365. }
  366. // Add Qwen models if required
  367. if (detection.suggestedParams.count("qwen2vl_required")) {
  368. info.requiredModels.push_back("Qwen2-VL: " + detection.suggestedParams.at("qwen2vl_required"));
  369. }
  370. if (detection.suggestedParams.count("qwen2vl_vision_required")) {
  371. info.requiredModels.push_back("Qwen2-VL-Vision: " + detection.suggestedParams.at("qwen2vl_vision_required"));
  372. }
  373. } catch (const std::exception& e) {
  374. // If detection fails completely, default to SD1.5
  375. info.architecture = "Stable Diffusion 1.5 (assumed)";
  376. info.recommendedVAE = "vae-ft-mse-840000-ema-pruned.safetensors";
  377. info.recommendedWidth = 512;
  378. info.recommendedHeight = 512;
  379. info.recommendedSteps = 20;
  380. info.recommendedSampler = "euler_a";
  381. }
  382. }
  383. modelsMap[info.name] = info;
  384. }
  385. }
  386. }
  387. }
  388. } catch (const fs::filesystem_error& e) {
  389. // Silently handle filesystem errors
  390. }
  391. return !scanCancelled.load();
  392. }
  393. };
  394. ModelManager::ModelManager() : pImpl(std::make_unique<Impl>()) {
  395. }
  396. ModelManager::~ModelManager() = default;
  397. bool ModelManager::scanModelsDirectory() {
  398. // Reset cancellation flag
  399. pImpl->scanCancelled.store(false);
  400. // Create temporary map to store scan results (outside of lock)
  401. std::map<std::string, ModelInfo> tempModels;
  402. if (pImpl->legacyMode) {
  403. // Legacy mode: scan the models directory itself and its subdirectories
  404. fs::path modelsPath(pImpl->modelsDirectory);
  405. if (!fs::exists(modelsPath) || !fs::is_directory(modelsPath)) {
  406. std::cerr << "Models directory does not exist: " << pImpl->modelsDirectory << std::endl;
  407. return false;
  408. }
  409. // First, scan the main models directory itself for any model files
  410. // This handles the case where models are directly in the specified directory
  411. if (!pImpl->scanDirectory(modelsPath, ModelType::NONE, tempModels)) {
  412. return false;
  413. }
  414. // Then scan known subdirectories for organized models
  415. std::vector<std::pair<fs::path, ModelType>> directoriesToScan = {
  416. {modelsPath / "stable-diffusion", ModelType::CHECKPOINT},
  417. {modelsPath / "controlnet", ModelType::CONTROLNET},
  418. {modelsPath / "lora", ModelType::LORA},
  419. {modelsPath / "vae", ModelType::VAE},
  420. {modelsPath / "taesd", ModelType::TAESD},
  421. {modelsPath / "esrgan", ModelType::ESRGAN},
  422. {modelsPath / "upscaler", ModelType::ESRGAN},
  423. {modelsPath / "embeddings", ModelType::EMBEDDING},
  424. {modelsPath / "textual-inversion", ModelType::EMBEDDING},
  425. {modelsPath / "checkpoints", ModelType::CHECKPOINT}, // Also scan checkpoints subdirectory
  426. {modelsPath / "other", ModelType::NONE} // Scan for any type
  427. };
  428. for (const auto& [dirPath, type] : directoriesToScan) {
  429. if (!pImpl->scanDirectory(dirPath, type, tempModels)) {
  430. return false;
  431. }
  432. }
  433. } else {
  434. // Explicit mode: scan configured directories for each model type
  435. std::vector<std::pair<ModelType, std::string>> directoriesToScan = {
  436. {ModelType::CHECKPOINT, pImpl->getModelTypeDirectory(ModelType::CHECKPOINT)},
  437. {ModelType::CONTROLNET, pImpl->getModelTypeDirectory(ModelType::CONTROLNET)},
  438. {ModelType::LORA, pImpl->getModelTypeDirectory(ModelType::LORA)},
  439. {ModelType::VAE, pImpl->getModelTypeDirectory(ModelType::VAE)},
  440. {ModelType::TAESD, pImpl->getModelTypeDirectory(ModelType::TAESD)},
  441. {ModelType::ESRGAN, pImpl->getModelTypeDirectory(ModelType::ESRGAN)},
  442. {ModelType::EMBEDDING, pImpl->getModelTypeDirectory(ModelType::EMBEDDING)}
  443. };
  444. for (const auto& [type, dirPath] : directoriesToScan) {
  445. if (!dirPath.empty()) {
  446. if (!pImpl->scanDirectory(dirPath, type, tempModels)) {
  447. return false;
  448. }
  449. }
  450. }
  451. }
  452. // Brief exclusive lock only to swap the data
  453. {
  454. std::unique_lock<std::shared_mutex> lock(pImpl->modelsMutex);
  455. pImpl->availableModels.swap(tempModels);
  456. }
  457. return true;
  458. }
  459. bool ModelManager::loadModel(const std::string& name, const std::string& path, ModelType type) {
  460. std::unique_lock<std::shared_mutex> lock(pImpl->modelsMutex);
  461. // Check if model is already loaded
  462. if (pImpl->loadedModels.find(name) != pImpl->loadedModels.end()) {
  463. return true;
  464. }
  465. // Check if file exists
  466. if (!fs::exists(path)) {
  467. std::cerr << "Model file does not exist: " << path << std::endl;
  468. return false;
  469. }
  470. // Create and initialize the stable-diffusion wrapper
  471. auto wrapper = std::make_unique<StableDiffusionWrapper>();
  472. // Set up generation parameters for model loading
  473. StableDiffusionWrapper::GenerationParams loadParams;
  474. loadParams.modelPath = path;
  475. loadParams.modelType = "f16"; // Default to f16 for better performance
  476. // Try to load the model
  477. if (!wrapper->loadModel(path, loadParams)) {
  478. std::cerr << "Failed to load model '" << name << "': " << wrapper->getLastError() << std::endl;
  479. return false;
  480. }
  481. pImpl->loadedModels[name] = std::move(wrapper);
  482. // Update model info
  483. if (pImpl->availableModels.find(name) != pImpl->availableModels.end()) {
  484. pImpl->availableModels[name].isLoaded = true;
  485. } else {
  486. // Create a new model info entry
  487. ModelInfo info;
  488. info.name = name;
  489. info.path = path;
  490. info.fullPath = fs::absolute(path).string();
  491. info.type = type;
  492. info.isLoaded = true;
  493. info.sha256 = "";
  494. info.description = ""; // Initialize description
  495. info.metadata = {}; // Initialize metadata
  496. try {
  497. info.fileSize = fs::file_size(path);
  498. info.modifiedAt = fs::last_write_time(path);
  499. info.createdAt = info.modifiedAt; // Use modified time as creation time for now
  500. } catch (const fs::filesystem_error& e) {
  501. std::cerr << "Error getting file info for " << path << ": " << e.what() << std::endl;
  502. info.fileSize = 0;
  503. info.modifiedAt = fs::file_time_type{};
  504. info.createdAt = fs::file_time_type{};
  505. }
  506. pImpl->availableModels[name] = info;
  507. }
  508. return true;
  509. }
  510. bool ModelManager::loadModel(const std::string& name) {
  511. std::string path;
  512. ModelType type;
  513. {
  514. std::unique_lock<std::shared_mutex> lock(pImpl->modelsMutex);
  515. // Check if model exists in available models
  516. auto it = pImpl->availableModels.find(name);
  517. if (it == pImpl->availableModels.end()) {
  518. std::cerr << "Model '" << name << "' not found in available models" << std::endl;
  519. return false;
  520. }
  521. // Check if already loaded
  522. if (pImpl->loadedModels.find(name) != pImpl->loadedModels.end()) {
  523. return true;
  524. }
  525. // Extract path and type while we have the lock
  526. path = it->second.path;
  527. type = it->second.type;
  528. } // Release lock here
  529. // Load the model without holding the lock
  530. return loadModel(name, path, type);
  531. }
  532. bool ModelManager::unloadModel(const std::string& name) {
  533. std::unique_lock<std::shared_mutex> lock(pImpl->modelsMutex);
  534. // Check if model is loaded
  535. auto loadedIt = pImpl->loadedModels.find(name);
  536. if (loadedIt == pImpl->loadedModels.end()) {
  537. return false;
  538. }
  539. // Unload the model properly
  540. if (loadedIt->second) {
  541. loadedIt->second->unloadModel();
  542. }
  543. pImpl->loadedModels.erase(loadedIt);
  544. // Update model info
  545. auto availableIt = pImpl->availableModels.find(name);
  546. if (availableIt != pImpl->availableModels.end()) {
  547. availableIt->second.isLoaded = false;
  548. }
  549. return true;
  550. }
  551. StableDiffusionWrapper* ModelManager::getModel(const std::string& name) {
  552. std::shared_lock<std::shared_mutex> lock(pImpl->modelsMutex);
  553. auto it = pImpl->loadedModels.find(name);
  554. if (it == pImpl->loadedModels.end()) {
  555. return nullptr;
  556. }
  557. return it->second.get();
  558. }
  559. std::map<std::string, ModelManager::ModelInfo> ModelManager::getAllModels() const {
  560. std::shared_lock<std::shared_mutex> lock(pImpl->modelsMutex);
  561. return pImpl->availableModels;
  562. }
  563. std::vector<ModelManager::ModelInfo> ModelManager::getModelsByType(ModelType type) const {
  564. std::shared_lock<std::shared_mutex> lock(pImpl->modelsMutex);
  565. std::vector<ModelInfo> result;
  566. for (const auto& pair : pImpl->availableModels) {
  567. if (pair.second.type == type) {
  568. result.push_back(pair.second);
  569. }
  570. }
  571. return result;
  572. }
  573. ModelManager::ModelInfo ModelManager::getModelInfo(const std::string& name) const {
  574. std::unique_lock<std::shared_mutex> lock(pImpl->modelsMutex);
  575. auto it = pImpl->availableModels.find(name);
  576. if (it == pImpl->availableModels.end()) {
  577. return ModelInfo{}; // Return empty ModelInfo if not found
  578. }
  579. return it->second;
  580. }
  581. bool ModelManager::isModelLoaded(const std::string& name) const {
  582. std::unique_lock<std::shared_mutex> lock(pImpl->modelsMutex);
  583. auto it = pImpl->loadedModels.find(name);
  584. return it != pImpl->loadedModels.end();
  585. }
  586. size_t ModelManager::getLoadedModelsCount() const {
  587. std::shared_lock<std::shared_mutex> lock(pImpl->modelsMutex);
  588. return pImpl->loadedModels.size();
  589. }
  590. size_t ModelManager::getAvailableModelsCount() const {
  591. std::shared_lock<std::shared_mutex> lock(pImpl->modelsMutex);
  592. return pImpl->availableModels.size();
  593. }
  594. void ModelManager::setModelsDirectory(const std::string& path) {
  595. pImpl->modelsDirectory = path;
  596. }
  597. std::string ModelManager::getModelsDirectory() const {
  598. return pImpl->modelsDirectory;
  599. }
  600. std::string ModelManager::modelTypeToString(ModelType type) {
  601. switch (type) {
  602. case ModelType::LORA:
  603. return "lora";
  604. case ModelType::CHECKPOINT:
  605. return "checkpoint";
  606. case ModelType::VAE:
  607. return "vae";
  608. case ModelType::PRESETS:
  609. return "presets";
  610. case ModelType::PROMPTS:
  611. return "prompts";
  612. case ModelType::NEG_PROMPTS:
  613. return "neg_prompts";
  614. case ModelType::TAESD:
  615. return "taesd";
  616. case ModelType::ESRGAN:
  617. return "esrgan";
  618. case ModelType::CONTROLNET:
  619. return "controlnet";
  620. case ModelType::UPSCALER:
  621. return "upscaler";
  622. case ModelType::EMBEDDING:
  623. return "embedding";
  624. default:
  625. return "unknown";
  626. }
  627. }
  628. ModelType ModelManager::stringToModelType(const std::string& typeStr) {
  629. std::string lowerType = typeStr;
  630. std::transform(lowerType.begin(), lowerType.end(), lowerType.begin(), ::tolower);
  631. if (lowerType == "lora") {
  632. return ModelType::LORA;
  633. } else if (lowerType == "checkpoint" || lowerType == "stable-diffusion") {
  634. return ModelType::CHECKPOINT;
  635. } else if (lowerType == "vae") {
  636. return ModelType::VAE;
  637. } else if (lowerType == "presets") {
  638. return ModelType::PRESETS;
  639. } else if (lowerType == "prompts") {
  640. return ModelType::PROMPTS;
  641. } else if (lowerType == "neg_prompts" || lowerType == "negative_prompts") {
  642. return ModelType::NEG_PROMPTS;
  643. } else if (lowerType == "taesd") {
  644. return ModelType::TAESD;
  645. } else if (lowerType == "esrgan") {
  646. return ModelType::ESRGAN;
  647. } else if (lowerType == "controlnet") {
  648. return ModelType::CONTROLNET;
  649. } else if (lowerType == "upscaler") {
  650. return ModelType::UPSCALER;
  651. } else if (lowerType == "embedding" || lowerType == "textual-inversion") {
  652. return ModelType::EMBEDDING;
  653. }
  654. return ModelType::NONE;
  655. }
  656. bool ModelManager::setModelTypeDirectory(ModelType type, const std::string& path) {
  657. std::unique_lock<std::shared_mutex> lock(pImpl->modelsMutex);
  658. if (!pImpl->validateDirectory(path)) {
  659. return false;
  660. }
  661. pImpl->modelTypeDirectories[type] = path;
  662. pImpl->legacyMode = false;
  663. return true;
  664. }
  665. std::string ModelManager::getModelTypeDirectory(ModelType type) const {
  666. std::shared_lock<std::shared_mutex> lock(pImpl->modelsMutex);
  667. return pImpl->getModelTypeDirectory(type);
  668. }
  669. bool ModelManager::setAllModelTypeDirectories(const std::map<ModelType, std::string>& directories) {
  670. std::unique_lock<std::shared_mutex> lock(pImpl->modelsMutex);
  671. // Validate all directories first
  672. for (const auto& [type, path] : directories) {
  673. if (!path.empty() && !pImpl->validateDirectory(path)) {
  674. return false;
  675. }
  676. }
  677. // Set all directories
  678. pImpl->modelTypeDirectories = directories;
  679. pImpl->legacyMode = false;
  680. return true;
  681. }
  682. std::map<ModelType, std::string> ModelManager::getAllModelTypeDirectories() const {
  683. std::shared_lock<std::shared_mutex> lock(pImpl->modelsMutex);
  684. return pImpl->modelTypeDirectories;
  685. }
  686. void ModelManager::resetToLegacyDirectories() {
  687. // Note: This method should be called with modelsMutex already locked
  688. pImpl->modelTypeDirectories.clear();
  689. pImpl->legacyMode = true;
  690. }
  691. bool ModelManager::configureFromServerConfig(const ServerConfig& config) {
  692. std::unique_lock<std::shared_mutex> lock(pImpl->modelsMutex);
  693. // Set the base models directory
  694. pImpl->modelsDirectory = config.modelsDir;
  695. if (config.legacyMode) {
  696. // Legacy mode: use single models directory
  697. resetToLegacyDirectories();
  698. return true;
  699. } else {
  700. // Explicit mode: set per-type directories
  701. std::map<ModelType, std::string> directories;
  702. if (!config.checkpoints.empty()) {
  703. directories[ModelType::CHECKPOINT] = config.checkpoints;
  704. }
  705. if (!config.controlnetDir.empty()) {
  706. directories[ModelType::CONTROLNET] = config.controlnetDir;
  707. }
  708. if (!config.embeddingsDir.empty()) {
  709. directories[ModelType::EMBEDDING] = config.embeddingsDir;
  710. }
  711. if (!config.esrganDir.empty()) {
  712. directories[ModelType::ESRGAN] = config.esrganDir;
  713. }
  714. if (!config.loraDir.empty()) {
  715. directories[ModelType::LORA] = config.loraDir;
  716. }
  717. if (!config.taesdDir.empty()) {
  718. directories[ModelType::TAESD] = config.taesdDir;
  719. }
  720. if (!config.vaeDir.empty()) {
  721. directories[ModelType::VAE] = config.vaeDir;
  722. }
  723. // Validate all directories first
  724. for (const auto& [type, path] : directories) {
  725. if (!path.empty() && !pImpl->validateDirectory(path)) {
  726. return false;
  727. }
  728. }
  729. // Set all directories (inlined to avoid deadlock from calling setAllModelTypeDirectories)
  730. pImpl->modelTypeDirectories = directories;
  731. pImpl->legacyMode = false;
  732. return true;
  733. }
  734. }
  735. void ModelManager::cancelScan() {
  736. pImpl->scanCancelled.store(true);
  737. }
  738. // SHA256 Hashing Implementation
  739. std::string ModelManager::computeModelHash(const std::string& modelName) {
  740. std::shared_lock<std::shared_mutex> lock(pImpl->modelsMutex);
  741. auto it = pImpl->availableModels.find(modelName);
  742. if (it == pImpl->availableModels.end()) {
  743. std::cerr << "Model not found: " << modelName << std::endl;
  744. return "";
  745. }
  746. std::string filePath = it->second.fullPath;
  747. lock.unlock();
  748. std::ifstream file(filePath, std::ios::binary);
  749. if (!file.is_open()) {
  750. std::cerr << "Failed to open file for hashing: " << filePath << std::endl;
  751. return "";
  752. }
  753. // Create and initialize EVP context for SHA256
  754. EVP_MD_CTX* mdctx = EVP_MD_CTX_new();
  755. if (mdctx == nullptr) {
  756. std::cerr << "Failed to create EVP context" << std::endl;
  757. return "";
  758. }
  759. if (EVP_DigestInit_ex(mdctx, EVP_sha256(), nullptr) != 1) {
  760. std::cerr << "Failed to initialize SHA256 digest" << std::endl;
  761. EVP_MD_CTX_free(mdctx);
  762. return "";
  763. }
  764. const size_t bufferSize = 8192;
  765. char buffer[bufferSize];
  766. std::cout << "Computing SHA256 for: " << modelName << std::endl;
  767. size_t totalRead = 0;
  768. size_t lastReportedMB = 0;
  769. while (file.read(buffer, bufferSize) || file.gcount() > 0) {
  770. size_t bytesRead = file.gcount();
  771. if (EVP_DigestUpdate(mdctx, buffer, bytesRead) != 1) {
  772. std::cerr << "Failed to update digest" << std::endl;
  773. EVP_MD_CTX_free(mdctx);
  774. return "";
  775. }
  776. totalRead += bytesRead;
  777. // Progress reporting every 100MB
  778. size_t currentMB = totalRead / (1024 * 1024);
  779. if (currentMB >= lastReportedMB + 100) {
  780. std::cout << " Hashed " << currentMB << " MB..." << std::endl;
  781. lastReportedMB = currentMB;
  782. }
  783. }
  784. file.close();
  785. unsigned char hash[EVP_MAX_MD_SIZE];
  786. unsigned int hashLen = 0;
  787. if (EVP_DigestFinal_ex(mdctx, hash, &hashLen) != 1) {
  788. std::cerr << "Failed to finalize digest" << std::endl;
  789. EVP_MD_CTX_free(mdctx);
  790. return "";
  791. }
  792. EVP_MD_CTX_free(mdctx);
  793. // Convert to hex string
  794. std::ostringstream oss;
  795. for (unsigned int i = 0; i < hashLen; i++) {
  796. oss << std::hex << std::setw(2) << std::setfill('0') << static_cast<int>(hash[i]);
  797. }
  798. std::string hashStr = oss.str();
  799. std::cout << "Hash computed: " << hashStr.substr(0, 16) << "..." << std::endl;
  800. return hashStr;
  801. }
  802. std::string ModelManager::loadModelHashFromFile(const std::string& modelName) {
  803. std::shared_lock<std::shared_mutex> lock(pImpl->modelsMutex);
  804. auto it = pImpl->availableModels.find(modelName);
  805. if (it == pImpl->availableModels.end()) {
  806. return "";
  807. }
  808. std::string jsonPath = it->second.fullPath + ".json";
  809. lock.unlock();
  810. if (!fs::exists(jsonPath)) {
  811. return "";
  812. }
  813. try {
  814. std::ifstream jsonFile(jsonPath);
  815. if (!jsonFile.is_open()) {
  816. return "";
  817. }
  818. nlohmann::json j;
  819. jsonFile >> j;
  820. jsonFile.close();
  821. if (j.contains("sha256") && j["sha256"].is_string()) {
  822. return j["sha256"].get<std::string>();
  823. }
  824. } catch (const std::exception& e) {
  825. std::cerr << "Error loading hash from JSON: " << e.what() << std::endl;
  826. }
  827. return "";
  828. }
  829. bool ModelManager::saveModelHashToFile(const std::string& modelName, const std::string& hash) {
  830. std::shared_lock<std::shared_mutex> lock(pImpl->modelsMutex);
  831. auto it = pImpl->availableModels.find(modelName);
  832. if (it == pImpl->availableModels.end()) {
  833. return false;
  834. }
  835. std::string jsonPath = it->second.fullPath + ".json";
  836. size_t fileSize = it->second.fileSize;
  837. lock.unlock();
  838. try {
  839. nlohmann::json j;
  840. j["sha256"] = hash;
  841. j["file_size"] = fileSize;
  842. j["computed_at"] = std::chrono::system_clock::now().time_since_epoch().count();
  843. std::ofstream jsonFile(jsonPath);
  844. if (!jsonFile.is_open()) {
  845. std::cerr << "Failed to open file for writing: " << jsonPath << std::endl;
  846. return false;
  847. }
  848. jsonFile << j.dump(2);
  849. jsonFile.close();
  850. std::cout << "Saved hash to: " << jsonPath << std::endl;
  851. return true;
  852. } catch (const std::exception& e) {
  853. std::cerr << "Error saving hash to JSON: " << e.what() << std::endl;
  854. return false;
  855. }
  856. }
  857. std::string ModelManager::findModelByHash(const std::string& hash) {
  858. if (hash.length() < 10) {
  859. std::cerr << "Hash must be at least 10 characters" << std::endl;
  860. return "";
  861. }
  862. std::shared_lock<std::shared_mutex> lock(pImpl->modelsMutex);
  863. for (const auto& [name, info] : pImpl->availableModels) {
  864. if (info.sha256.empty()) {
  865. continue;
  866. }
  867. // Support full or partial match (minimum 10 chars)
  868. if (info.sha256 == hash || info.sha256.substr(0, hash.length()) == hash) {
  869. return name;
  870. }
  871. }
  872. return "";
  873. }
  874. std::string ModelManager::ensureModelHash(const std::string& modelName, bool forceCompute) {
  875. // Try to load existing hash if not forcing recompute
  876. if (!forceCompute) {
  877. std::string existingHash = loadModelHashFromFile(modelName);
  878. if (!existingHash.empty()) {
  879. // Update in-memory model info
  880. std::unique_lock<std::shared_mutex> lock(pImpl->modelsMutex);
  881. auto it = pImpl->availableModels.find(modelName);
  882. if (it != pImpl->availableModels.end()) {
  883. it->second.sha256 = existingHash;
  884. }
  885. return existingHash;
  886. }
  887. }
  888. // Compute new hash
  889. std::string hash = computeModelHash(modelName);
  890. if (hash.empty()) {
  891. return "";
  892. }
  893. // Save to file
  894. saveModelHashToFile(modelName, hash);
  895. // Update in-memory model info
  896. std::unique_lock<std::shared_mutex> lock(pImpl->modelsMutex);
  897. auto it = pImpl->availableModels.find(modelName);
  898. if (it != pImpl->availableModels.end()) {
  899. it->second.sha256 = hash;
  900. }
  901. return hash;
  902. }