model_manager.cpp 39 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062
#include "model_manager.h"
#include "model_detector.h"
#include "stable_diffusion_wrapper.h"

#include <algorithm>
#include <atomic>
#include <chrono>
#include <filesystem>
#include <fstream>
#include <future>
#include <iomanip>
#include <iostream>
#include <set>
#include <shared_mutex>
#include <sstream>
#include <system_error>
#include <thread>

#include <nlohmann/json.hpp>
#include <openssl/evp.h>
  17. namespace fs = std::filesystem;
  18. // File extension mappings for each model type
  19. const std::vector<std::string> CHECKPOINT_FILE_EXTENSIONS = {"safetensors", "ckpt", "gguf"};
  20. const std::vector<std::string> EMBEDDING_FILE_EXTENSIONS = {"safetensors", "pt"};
  21. const std::vector<std::string> LORA_FILE_EXTENSIONS = {"safetensors", "ckpt"};
  22. const std::vector<std::string> VAE_FILE_EXTENSIONS = {"safetensors", "pt", "ckpt", "gguf"};
  23. const std::vector<std::string> TAESD_FILE_EXTENSIONS = {"safetensors", "pth", "gguf"};
  24. const std::vector<std::string> ESRGAN_FILE_EXTENSIONS = {"pth", "pt"};
  25. const std::vector<std::string> CONTROLNET_FILE_EXTENSIONS = {"safetensors", "pth"};
  26. class ModelManager::Impl {
  27. public:
  28. std::string modelsDirectory = "./models";
  29. std::map<ModelType, std::string> modelTypeDirectories;
  30. std::map<std::string, ModelInfo> availableModels;
  31. std::map<std::string, std::unique_ptr<StableDiffusionWrapper>> loadedModels;
  32. mutable std::shared_mutex modelsMutex;
  33. std::atomic<bool> scanCancelled{false};
  34. /**
  35. * @brief Validate a directory path
  36. *
  37. * @param path The directory path to validate
  38. * @return true if the directory exists and is valid, false otherwise
  39. */
  40. bool validateDirectory(const std::string& path) const {
  41. if (path.empty()) {
  42. return false;
  43. }
  44. std::filesystem::path dirPath(path);
  45. if (!std::filesystem::exists(dirPath)) {
  46. std::cerr << "Directory does not exist: " << path << std::endl;
  47. return false;
  48. }
  49. if (!std::filesystem::is_directory(dirPath)) {
  50. std::cerr << "Path is not a directory: " << path << std::endl;
  51. return false;
  52. }
  53. return true;
  54. }
  55. /**
  56. * @brief Get default directory name for a model type
  57. *
  58. * @param type The model type
  59. * @return std::string Default directory name
  60. */
  61. std::string getDefaultDirectoryName(ModelType type) const {
  62. switch (type) {
  63. case ModelType::CHECKPOINT:
  64. return "checkpoints";
  65. case ModelType::CONTROLNET:
  66. return "controlnet";
  67. case ModelType::EMBEDDING:
  68. return "embeddings";
  69. case ModelType::ESRGAN:
  70. return "esrgan";
  71. case ModelType::LORA:
  72. return "lora";
  73. case ModelType::TAESD:
  74. return "taesd";
  75. case ModelType::VAE:
  76. return "vae";
  77. case ModelType::DIFFUSION_MODELS:
  78. return "diffusion_models";
  79. default:
  80. return "";
  81. }
  82. }
  83. /**
  84. * @brief Get directory path for a model type
  85. *
  86. * @param type The model type
  87. * @return std::string Directory path, empty if not set
  88. */
  89. std::string getModelTypeDirectory(ModelType type) const {
  90. auto it = modelTypeDirectories.find(type);
  91. if (it != modelTypeDirectories.end()) {
  92. return it->second;
  93. }
  94. // Always use explicit directory configuration
  95. std::string defaultDir = getDefaultDirectoryName(type);
  96. if (!defaultDir.empty()) {
  97. return modelsDirectory + "/" + defaultDir;
  98. }
  99. return "";
  100. }
  101. /**
  102. * @brief Get file extensions for a specific model type
  103. *
  104. * @param type The model type
  105. * @return const std::vector<std::string>& Vector of file extensions
  106. */
  107. const std::vector<std::string>& getFileExtensions(ModelType type) const {
  108. switch (type) {
  109. case ModelType::CHECKPOINT:
  110. case ModelType::DIFFUSION_MODELS:
  111. return CHECKPOINT_FILE_EXTENSIONS;
  112. case ModelType::EMBEDDING:
  113. return EMBEDDING_FILE_EXTENSIONS;
  114. case ModelType::LORA:
  115. return LORA_FILE_EXTENSIONS;
  116. case ModelType::VAE:
  117. return VAE_FILE_EXTENSIONS;
  118. case ModelType::TAESD:
  119. return TAESD_FILE_EXTENSIONS;
  120. case ModelType::ESRGAN:
  121. return ESRGAN_FILE_EXTENSIONS;
  122. case ModelType::CONTROLNET:
  123. return CONTROLNET_FILE_EXTENSIONS;
  124. default:
  125. static const std::vector<std::string> empty;
  126. return empty;
  127. }
  128. }
  129. /**
  130. * @brief Check if a file extension matches a model type
  131. *
  132. * @param extension The file extension
  133. * @param type The model type
  134. * @return true if the extension matches the model type
  135. */
  136. bool isExtensionMatch(const std::string& extension, ModelType type) const {
  137. const auto& extensions = getFileExtensions(type);
  138. return std::find(extensions.begin(), extensions.end(), extension) != extensions.end();
  139. }
  140. /**
  141. * @brief FIXED: Determine model type based on file path and extension
  142. *
  143. * THIS IS THE FIXED VERSION - no extension-based fallback!
  144. * Only returns a model type if the file is actually in the right directory.
  145. *
  146. * @param filePath The file path
  147. * @return ModelType The determined model type
  148. */
  149. ModelType determineModelType(const fs::path& filePath) const {
  150. std::string extension = filePath.extension().string();
  151. if (extension.empty()) {
  152. return ModelType::NONE;
  153. }
  154. // Remove the dot from extension
  155. if (extension[0] == '.') {
  156. extension = extension.substr(1);
  157. }
  158. // Convert to lowercase for comparison
  159. std::transform(extension.begin(), extension.end(), extension.begin(), ::tolower);
  160. // Check if the file resides under a directory registered for a given ModelType
  161. fs::path absoluteFilePath = fs::absolute(filePath);
  162. // First check configured directories (if any)
  163. for (const auto& [type, directory] : modelTypeDirectories) {
  164. if (!directory.empty()) {
  165. fs::path absoluteDirPath = fs::absolute(directory).lexically_normal();
  166. fs::path normalizedFilePath = absoluteFilePath.lexically_normal();
  167. // Check if the file is under this directory (directly or in subdirectories)
  168. // Get the relative path from directory to file
  169. auto relativePath = normalizedFilePath.lexically_relative(absoluteDirPath);
  170. // If relative path doesn't start with "..", then file is under the directory
  171. std::string relPathStr = relativePath.string();
  172. bool isUnderDirectory = !relPathStr.empty() &&
  173. relPathStr.substr(0, 2) != ".." &&
  174. relPathStr[0] != '/';
  175. if (isUnderDirectory && isExtensionMatch(extension, type)) {
  176. return type;
  177. }
  178. }
  179. }
  180. // Check default directory structure when no explicit directories configured
  181. if (modelTypeDirectories.empty()) {
  182. // Check the entire path hierarchy for model type directories
  183. fs::path currentPath = filePath.parent_path();
  184. while (currentPath.has_filename()) {
  185. std::string dirName = currentPath.filename().string();
  186. std::transform(dirName.begin(), dirName.end(), dirName.begin(), ::tolower);
  187. // Check default directory names
  188. if (dirName == "checkpoints" || dirName == "stable-diffusion") {
  189. if (isExtensionMatch(extension, ModelType::CHECKPOINT)) {
  190. return ModelType::CHECKPOINT;
  191. }
  192. } else if (dirName == "controlnet") {
  193. if (isExtensionMatch(extension, ModelType::CONTROLNET)) {
  194. return ModelType::CONTROLNET;
  195. }
  196. } else if (dirName == "lora") {
  197. if (isExtensionMatch(extension, ModelType::LORA)) {
  198. return ModelType::LORA;
  199. }
  200. } else if (dirName == "vae") {
  201. if (isExtensionMatch(extension, ModelType::VAE)) {
  202. return ModelType::VAE;
  203. }
  204. } else if (dirName == "taesd") {
  205. if (isExtensionMatch(extension, ModelType::TAESD)) {
  206. return ModelType::TAESD;
  207. }
  208. } else if (dirName == "esrgan" || dirName == "upscaler") {
  209. if (isExtensionMatch(extension, ModelType::ESRGAN)) {
  210. return ModelType::ESRGAN;
  211. }
  212. } else if (dirName == "embeddings" || dirName == "textual-inversion") {
  213. } else if (dirName == "diffusion_models" || dirName == "diffusion") {
  214. if (isExtensionMatch(extension, ModelType::DIFFUSION_MODELS)) {
  215. return ModelType::DIFFUSION_MODELS;
  216. }
  217. if (isExtensionMatch(extension, ModelType::EMBEDDING)) {
  218. return ModelType::EMBEDDING;
  219. }
  220. }
  221. // Move up to parent directory
  222. currentPath = currentPath.parent_path();
  223. }
  224. }
  225. // NO EXTENSION-BASED FALLBACK - this was the bug!
  226. // Files must be in the correct directory to be recognized
  227. return ModelType::NONE;
  228. }
  229. /**
  230. * @brief Get file information with timeout
  231. *
  232. * @param filePath The file path to get info for
  233. * @param timeoutMs Timeout in milliseconds
  234. * @return std::pair<bool, std::pair<uintmax_t, fs::file_time_type>> Success flag and file info
  235. */
  236. std::pair<bool, std::pair<uintmax_t, fs::file_time_type>> getFileInfoWithTimeout(
  237. const fs::path& filePath, int timeoutMs = 5000) {
  238. auto future = std::async(std::launch::async, [&filePath]() -> std::pair<uintmax_t, fs::file_time_type> {
  239. try {
  240. uintmax_t fileSize = fs::file_size(filePath);
  241. fs::file_time_type modifiedAt = fs::last_write_time(filePath);
  242. return {fileSize, modifiedAt};
  243. } catch (const fs::filesystem_error&) {
  244. return {0, fs::file_time_type{}};
  245. }
  246. });
  247. if (future.wait_for(std::chrono::milliseconds(timeoutMs)) == std::future_status::timeout) {
  248. std::cerr << "Timeout getting file info for " << filePath << std::endl;
  249. return {false, {0, fs::file_time_type{}}};
  250. }
  251. return {true, future.get()};
  252. }
  253. /**
  254. * @brief Scan a directory for models recursively (all types, without holding mutex)
  255. *
  256. * Recursively walks the directory tree to find all model files. For each model
  257. * found, constructs the display name as 'relative_path/model_name' where
  258. * relative_path is the path from the models root directory to the file's
  259. * containing folder (using forward slashes). Models in the root directory
  260. * appear without a prefix.
  261. *
  262. * @param directory The directory to scan
  263. * @param modelsMap Reference to the map to store results
  264. * @return bool True if scanning completed without cancellation
  265. */
  266. bool scanDirectory(const fs::path& directory, std::map<std::string, ModelInfo>& modelsMap) {
  267. if (scanCancelled.load()) {
  268. return false;
  269. }
  270. if (!fs::exists(directory) || !fs::is_directory(directory)) {
  271. return true;
  272. }
  273. try {
  274. for (const auto& entry : fs::recursive_directory_iterator(directory)) {
  275. if (scanCancelled.load()) {
  276. return false;
  277. }
  278. if (entry.is_regular_file()) {
  279. fs::path filePath = entry.path();
  280. ModelType detectedType = determineModelType(filePath);
  281. // Only add files that have a valid model type (not NONE)
  282. if (detectedType != ModelType::NONE) {
  283. ModelInfo info;
  284. // Calculate relative path from the model type directory to exclude model type folder names
  285. fs::path relativePath;
  286. try {
  287. // Get the specific model type directory for this detected type
  288. std::string modelTypeDir = getModelTypeDirectory(detectedType);
  289. if (!modelTypeDir.empty()) {
  290. fs::path typeBaseDir(modelTypeDir);
  291. // Get relative path from the model type directory
  292. relativePath = fs::relative(filePath, typeBaseDir);
  293. } else {
  294. // Fallback: use the base models directory
  295. if (!modelsDirectory.empty()) {
  296. fs::path baseDir(modelsDirectory);
  297. relativePath = fs::relative(filePath, baseDir);
  298. } else {
  299. relativePath = fs::relative(filePath, directory);
  300. }
  301. }
  302. } catch (const fs::filesystem_error&) {
  303. // If relative path calculation fails, use filename only
  304. relativePath = filePath.filename();
  305. }
  306. std::string modelName = relativePath.string();
  307. // Normalize path separators for consistency
  308. std::replace(modelName.begin(), modelName.end(), '\\', '/');
  309. // Check if model already exists to avoid duplicates
  310. if (modelsMap.find(modelName) == modelsMap.end()) {
  311. info.name = modelName;
  312. info.path = filePath.string();
  313. info.fullPath = fs::absolute(filePath).string();
  314. info.type = detectedType;
  315. info.isLoaded = false;
  316. info.description = ""; // Initialize description
  317. info.metadata = {}; // Initialize metadata
  318. // Get file info with timeout
  319. auto [success, fileInfo] = getFileInfoWithTimeout(filePath);
  320. if (success) {
  321. info.fileSize = fileInfo.first;
  322. info.modifiedAt = fileInfo.second;
  323. info.createdAt = fileInfo.second; // Use modified time as creation time for now
  324. } else {
  325. info.fileSize = 0;
  326. info.modifiedAt = fs::file_time_type{};
  327. info.createdAt = fs::file_time_type{};
  328. }
  329. // Try to load cached hash from .json file
  330. std::string hashFile = info.fullPath + ".json";
  331. if (fs::exists(hashFile)) {
  332. try {
  333. std::ifstream file(hashFile);
  334. nlohmann::json hashData = nlohmann::json::parse(file);
  335. if (hashData.contains("sha256") && hashData["sha256"].is_string()) {
  336. info.sha256 = hashData["sha256"];
  337. } else {
  338. info.sha256 = "";
  339. }
  340. } catch (...) {
  341. info.sha256 = ""; // If parsing fails, leave empty
  342. }
  343. } else {
  344. info.sha256 = ""; // No cached hash file
  345. }
  346. // Detect architecture for checkpoint models (including diffusion_models)
  347. if (detectedType == ModelType::CHECKPOINT || detectedType == ModelType::DIFFUSION_MODELS) {
  348. try {
  349. ModelDetectionResult detection = ModelDetector::detectModel(info.fullPath);
  350. // For .ckpt files that can't be detected, default to SD1.5
  351. if (detection.architecture == ModelArchitecture::UNKNOWN &&
  352. (filePath.extension() == ".ckpt" || filePath.extension() == ".pt")) {
  353. info.architecture = "Stable Diffusion 1.5 (assumed)";
  354. info.recommendedVAE = "vae-ft-mse-840000-ema-pruned.safetensors";
  355. info.recommendedWidth = 512;
  356. info.recommendedHeight = 512;
  357. info.recommendedSteps = 20;
  358. info.recommendedSampler = "euler_a";
  359. } else {
  360. info.architecture = detection.architectureName;
  361. info.recommendedVAE = detection.recommendedVAE;
  362. // Parse recommended parameters
  363. if (detection.suggestedParams.count("width")) {
  364. info.recommendedWidth = std::stoi(detection.suggestedParams["width"]);
  365. }
  366. if (detection.suggestedParams.count("height")) {
  367. info.recommendedHeight = std::stoi(detection.suggestedParams["height"]);
  368. }
  369. if (detection.suggestedParams.count("steps")) {
  370. info.recommendedSteps = std::stoi(detection.suggestedParams["steps"]);
  371. }
  372. if (detection.suggestedParams.count("sampler")) {
  373. info.recommendedSampler = detection.suggestedParams["sampler"];
  374. }
  375. }
  376. // Build list of required models based on architecture
  377. if (detection.needsVAE && !detection.recommendedVAE.empty()) {
  378. info.requiredModels.push_back("VAE: " + detection.recommendedVAE);
  379. }
  380. // Add CLIP-L if required
  381. if (detection.suggestedParams.count("clip_l_required")) {
  382. info.requiredModels.push_back("CLIP-L: " + detection.suggestedParams.at("clip_l_required"));
  383. }
  384. // Add CLIP-G if required
  385. if (detection.suggestedParams.count("clip_g_required")) {
  386. info.requiredModels.push_back("CLIP-G: " + detection.suggestedParams.at("clip_g_required"));
  387. }
  388. // Add T5XXL if required
  389. if (detection.suggestedParams.count("t5xxl_required")) {
  390. info.requiredModels.push_back("T5XXL: " + detection.suggestedParams.at("t5xxl_required"));
  391. }
  392. // Add Qwen models if required
  393. if (detection.suggestedParams.count("qwen2vl_required")) {
  394. info.requiredModels.push_back("Qwen2-VL: " + detection.suggestedParams.at("qwen2vl_required"));
  395. }
  396. if (detection.suggestedParams.count("qwen2vl_vision_required")) {
  397. info.requiredModels.push_back("Qwen2-VL-Vision: " + detection.suggestedParams.at("qwen2vl_vision_required"));
  398. }
  399. } catch (const std::exception& e) {
  400. // If detection fails completely, default to SD1.5
  401. info.architecture = "Stable Diffusion 1.5 (assumed)";
  402. info.recommendedVAE = "vae-ft-mse-840000-ema-pruned.safetensors";
  403. info.recommendedWidth = 512;
  404. info.recommendedHeight = 512;
  405. info.recommendedSteps = 20;
  406. info.recommendedSampler = "euler_a";
  407. }
  408. }
  409. modelsMap[info.name] = info;
  410. }
  411. }
  412. }
  413. }
  414. } catch (const fs::filesystem_error& e) {
  415. // Silently handle filesystem errors
  416. }
  417. return !scanCancelled.load();
  418. }
  419. };
  420. ModelManager::ModelManager() : pImpl(std::make_unique<Impl>()) {
  421. }
/**
 * @brief Destroy the manager.
 *
 * Defaulted in this translation unit, after ModelManager::Impl has been defined.
 */
ModelManager::~ModelManager() = default;
  423. bool ModelManager::scanModelsDirectory() {
  424. // Reset cancellation flag
  425. pImpl->scanCancelled.store(false);
  426. // Create temporary map to store scan results (outside of lock)
  427. std::map<std::string, ModelInfo> tempModels;
  428. // Scan all configured directories for all model types recursively
  429. // We scan each directory once and detect all model types within it
  430. std::set<std::string> scannedDirectories; // Avoid scanning the same directory multiple times
  431. std::vector<std::string> directoriesToScan;
  432. // Collect unique directories to scan
  433. std::vector<ModelType> allTypes = {
  434. ModelType::CHECKPOINT, ModelType::CONTROLNET, ModelType::LORA,
  435. ModelType::VAE, ModelType::TAESD, ModelType::ESRGAN, ModelType::EMBEDDING,
  436. ModelType::DIFFUSION_MODELS
  437. };
  438. for (const auto& type : allTypes) {
  439. std::string dirPath = pImpl->getModelTypeDirectory(type);
  440. if (!dirPath.empty() && scannedDirectories.find(dirPath) == scannedDirectories.end()) {
  441. directoriesToScan.push_back(dirPath);
  442. scannedDirectories.insert(dirPath);
  443. }
  444. }
  445. // Also scan the base models directory if it exists and isn't already covered
  446. if (!pImpl->modelsDirectory.empty() &&
  447. fs::exists(pImpl->modelsDirectory) &&
  448. scannedDirectories.find(pImpl->modelsDirectory) == scannedDirectories.end()) {
  449. directoriesToScan.push_back(pImpl->modelsDirectory);
  450. }
  451. // Scan each unique directory recursively for all model types
  452. for (const auto& dirPath : directoriesToScan) {
  453. if (!pImpl->scanDirectory(dirPath, tempModels)) {
  454. return false;
  455. }
  456. }
  457. // Brief exclusive lock only to swap the data
  458. {
  459. std::unique_lock<std::shared_mutex> lock(pImpl->modelsMutex);
  460. pImpl->availableModels.swap(tempModels);
  461. }
  462. return true;
  463. }
  464. bool ModelManager::loadModel(const std::string& name, const std::string& path, ModelType type) {
  465. std::unique_lock<std::shared_mutex> lock(pImpl->modelsMutex);
  466. // Check if model is already loaded
  467. if (pImpl->loadedModels.find(name) != pImpl->loadedModels.end()) {
  468. return true;
  469. }
  470. // Check if file exists
  471. if (!fs::exists(path)) {
  472. std::cerr << "Model file does not exist: " << path << std::endl;
  473. return false;
  474. }
  475. // Create and initialize the stable-diffusion wrapper
  476. auto wrapper = std::make_unique<StableDiffusionWrapper>();
  477. // Set up generation parameters for model loading
  478. StableDiffusionWrapper::GenerationParams loadParams;
  479. loadParams.modelPath = path;
  480. loadParams.modelType = "f16"; // Default to f16 for better performance
  481. // Try to load the model
  482. if (!wrapper->loadModel(path, loadParams)) {
  483. std::cerr << "Failed to load model '" << name << "': " << wrapper->getLastError() << std::endl;
  484. return false;
  485. }
  486. pImpl->loadedModels[name] = std::move(wrapper);
  487. // Update model info
  488. if (pImpl->availableModels.find(name) != pImpl->availableModels.end()) {
  489. pImpl->availableModels[name].isLoaded = true;
  490. } else {
  491. // Create a new model info entry
  492. ModelInfo info;
  493. info.name = name;
  494. info.path = path;
  495. info.fullPath = fs::absolute(path).string();
  496. info.type = type;
  497. info.isLoaded = true;
  498. info.sha256 = "";
  499. info.description = ""; // Initialize description
  500. info.metadata = {}; // Initialize metadata
  501. try {
  502. info.fileSize = fs::file_size(path);
  503. info.modifiedAt = fs::last_write_time(path);
  504. info.createdAt = info.modifiedAt; // Use modified time as creation time for now
  505. } catch (const fs::filesystem_error& e) {
  506. std::cerr << "Error getting file info for " << path << ": " << e.what() << std::endl;
  507. info.fileSize = 0;
  508. info.modifiedAt = fs::file_time_type{};
  509. info.createdAt = fs::file_time_type{};
  510. }
  511. pImpl->availableModels[name] = info;
  512. }
  513. return true;
  514. }
  515. bool ModelManager::loadModel(const std::string& name) {
  516. std::string path;
  517. ModelType type;
  518. {
  519. std::unique_lock<std::shared_mutex> lock(pImpl->modelsMutex);
  520. // Check if model exists in available models
  521. auto it = pImpl->availableModels.find(name);
  522. if (it == pImpl->availableModels.end()) {
  523. std::cerr << "Model '" << name << "' not found in available models" << std::endl;
  524. return false;
  525. }
  526. // Check if already loaded
  527. if (pImpl->loadedModels.find(name) != pImpl->loadedModels.end()) {
  528. return true;
  529. }
  530. // Extract path and type while we have the lock
  531. path = it->second.path;
  532. type = it->second.type;
  533. } // Release lock here
  534. // Load the model without holding the lock
  535. return loadModel(name, path, type);
  536. }
  537. bool ModelManager::unloadModel(const std::string& name) {
  538. std::unique_lock<std::shared_mutex> lock(pImpl->modelsMutex);
  539. // Check if model is loaded
  540. auto loadedIt = pImpl->loadedModels.find(name);
  541. if (loadedIt == pImpl->loadedModels.end()) {
  542. return false;
  543. }
  544. // Unload the model properly
  545. if (loadedIt->second) {
  546. loadedIt->second->unloadModel();
  547. }
  548. pImpl->loadedModels.erase(loadedIt);
  549. // Update model info
  550. auto availableIt = pImpl->availableModels.find(name);
  551. if (availableIt != pImpl->availableModels.end()) {
  552. availableIt->second.isLoaded = false;
  553. }
  554. return true;
  555. }
  556. StableDiffusionWrapper* ModelManager::getModel(const std::string& name) {
  557. std::shared_lock<std::shared_mutex> lock(pImpl->modelsMutex);
  558. auto it = pImpl->loadedModels.find(name);
  559. if (it == pImpl->loadedModels.end()) {
  560. return nullptr;
  561. }
  562. return it->second.get();
  563. }
  564. std::map<std::string, ModelManager::ModelInfo> ModelManager::getAllModels() const {
  565. std::shared_lock<std::shared_mutex> lock(pImpl->modelsMutex);
  566. return pImpl->availableModels;
  567. }
  568. std::vector<ModelManager::ModelInfo> ModelManager::getModelsByType(ModelType type) const {
  569. std::shared_lock<std::shared_mutex> lock(pImpl->modelsMutex);
  570. std::vector<ModelInfo> result;
  571. for (const auto& pair : pImpl->availableModels) {
  572. if (pair.second.type == type) {
  573. result.push_back(pair.second);
  574. }
  575. }
  576. return result;
  577. }
  578. ModelManager::ModelInfo ModelManager::getModelInfo(const std::string& name) const {
  579. std::unique_lock<std::shared_mutex> lock(pImpl->modelsMutex);
  580. auto it = pImpl->availableModels.find(name);
  581. if (it == pImpl->availableModels.end()) {
  582. return ModelInfo{}; // Return empty ModelInfo if not found
  583. }
  584. return it->second;
  585. }
  586. bool ModelManager::isModelLoaded(const std::string& name) const {
  587. std::unique_lock<std::shared_mutex> lock(pImpl->modelsMutex);
  588. auto it = pImpl->loadedModels.find(name);
  589. return it != pImpl->loadedModels.end();
  590. }
  591. size_t ModelManager::getLoadedModelsCount() const {
  592. std::shared_lock<std::shared_mutex> lock(pImpl->modelsMutex);
  593. return pImpl->loadedModels.size();
  594. }
  595. size_t ModelManager::getAvailableModelsCount() const {
  596. std::shared_lock<std::shared_mutex> lock(pImpl->modelsMutex);
  597. return pImpl->availableModels.size();
  598. }
  599. void ModelManager::setModelsDirectory(const std::string& path) {
  600. pImpl->modelsDirectory = path;
  601. }
  602. std::string ModelManager::getModelsDirectory() const {
  603. return pImpl->modelsDirectory;
  604. }
  605. std::string ModelManager::modelTypeToString(ModelType type) {
  606. switch (type) {
  607. case ModelType::LORA:
  608. return "lora";
  609. case ModelType::CHECKPOINT:
  610. return "checkpoint";
  611. case ModelType::VAE:
  612. return "vae";
  613. case ModelType::PRESETS:
  614. return "presets";
  615. case ModelType::PROMPTS:
  616. return "prompts";
  617. case ModelType::NEG_PROMPTS:
  618. return "neg_prompts";
  619. case ModelType::TAESD:
  620. return "taesd";
  621. case ModelType::ESRGAN:
  622. return "esrgan";
  623. case ModelType::CONTROLNET:
  624. return "controlnet";
  625. case ModelType::UPSCALER:
  626. return "upscaler";
  627. case ModelType::EMBEDDING:
  628. return "embedding";
  629. case ModelType::DIFFUSION_MODELS:
  630. return "checkpoint";
  631. default:
  632. return "unknown";
  633. }
  634. }
  635. ModelType ModelManager::stringToModelType(const std::string& typeStr) {
  636. std::string lowerType = typeStr;
  637. std::transform(lowerType.begin(), lowerType.end(), lowerType.begin(), ::tolower);
  638. if (lowerType == "lora") {
  639. return ModelType::LORA;
  640. } else if (lowerType == "checkpoint" || lowerType == "stable-diffusion") {
  641. return ModelType::CHECKPOINT;
  642. } else if (lowerType == "vae") {
  643. return ModelType::VAE;
  644. } else if (lowerType == "presets") {
  645. return ModelType::PRESETS;
  646. } else if (lowerType == "prompts") {
  647. return ModelType::PROMPTS;
  648. } else if (lowerType == "neg_prompts" || lowerType == "negative_prompts") {
  649. return ModelType::NEG_PROMPTS;
  650. } else if (lowerType == "taesd") {
  651. return ModelType::TAESD;
  652. } else if (lowerType == "esrgan") {
  653. return ModelType::ESRGAN;
  654. } else if (lowerType == "controlnet") {
  655. return ModelType::CONTROLNET;
  656. } else if (lowerType == "upscaler") {
  657. return ModelType::UPSCALER;
  658. } else if (lowerType == "embedding" || lowerType == "textual-inversion") {
  659. return ModelType::EMBEDDING;
  660. }
  661. return ModelType::NONE;
  662. }
  663. bool ModelManager::setModelTypeDirectory(ModelType type, const std::string& path) {
  664. std::unique_lock<std::shared_mutex> lock(pImpl->modelsMutex);
  665. if (!pImpl->validateDirectory(path)) {
  666. return false;
  667. }
  668. pImpl->modelTypeDirectories[type] = path;
  669. return true;
  670. }
  671. std::string ModelManager::getModelTypeDirectory(ModelType type) const {
  672. std::shared_lock<std::shared_mutex> lock(pImpl->modelsMutex);
  673. return pImpl->getModelTypeDirectory(type);
  674. }
  675. bool ModelManager::setAllModelTypeDirectories(const std::map<ModelType, std::string>& directories) {
  676. std::unique_lock<std::shared_mutex> lock(pImpl->modelsMutex);
  677. // Validate all directories first
  678. for (const auto& [type, path] : directories) {
  679. if (!path.empty() && !pImpl->validateDirectory(path)) {
  680. return false;
  681. }
  682. }
  683. // Set all directories
  684. pImpl->modelTypeDirectories = directories;
  685. return true;
  686. }
  687. std::map<ModelType, std::string> ModelManager::getAllModelTypeDirectories() const {
  688. std::shared_lock<std::shared_mutex> lock(pImpl->modelsMutex);
  689. return pImpl->modelTypeDirectories;
  690. }
  691. // Legacy resetToLegacyDirectories method removed
  692. // Using explicit directory configuration only
  693. bool ModelManager::configureFromServerConfig(const ServerConfig& config) {
  694. std::unique_lock<std::shared_mutex> lock(pImpl->modelsMutex);
  695. // Set the base models directory
  696. pImpl->modelsDirectory = config.modelsDir;
  697. // Always use explicit directory configuration
  698. std::map<ModelType, std::string> directories;
  699. if (!config.checkpoints.empty()) {
  700. directories[ModelType::CHECKPOINT] = config.checkpoints;
  701. }
  702. if (!config.controlnetDir.empty()) {
  703. directories[ModelType::CONTROLNET] = config.controlnetDir;
  704. }
  705. if (!config.embeddingsDir.empty()) {
  706. directories[ModelType::EMBEDDING] = config.embeddingsDir;
  707. }
  708. if (!config.esrganDir.empty()) {
  709. directories[ModelType::ESRGAN] = config.esrganDir;
  710. }
  711. if (!config.loraDir.empty()) {
  712. directories[ModelType::LORA] = config.loraDir;
  713. }
  714. if (!config.taesdDir.empty()) {
  715. directories[ModelType::TAESD] = config.taesdDir;
  716. }
  717. if (!config.vaeDir.empty()) {
  718. directories[ModelType::VAE] = config.vaeDir;
  719. }
  720. // Validate all directories first
  721. for (const auto& [type, path] : directories) {
  722. if (!path.empty() && !pImpl->validateDirectory(path)) {
  723. return false;
  724. }
  725. }
  726. // Set all directories (inlined to avoid deadlock from calling setAllModelTypeDirectories)
  727. pImpl->modelTypeDirectories = directories;
  728. return true;
  729. }
  730. void ModelManager::cancelScan() {
  731. pImpl->scanCancelled.store(true);
  732. }
  733. // SHA256 Hashing Implementation
  734. std::string ModelManager::computeModelHash(const std::string& modelName) {
  735. std::shared_lock<std::shared_mutex> lock(pImpl->modelsMutex);
  736. auto it = pImpl->availableModels.find(modelName);
  737. if (it == pImpl->availableModels.end()) {
  738. std::cerr << "Model not found: " << modelName << std::endl;
  739. return "";
  740. }
  741. std::string filePath = it->second.fullPath;
  742. lock.unlock();
  743. std::ifstream file(filePath, std::ios::binary);
  744. if (!file.is_open()) {
  745. std::cerr << "Failed to open file for hashing: " << filePath << std::endl;
  746. return "";
  747. }
  748. // Create and initialize EVP context for SHA256
  749. EVP_MD_CTX* mdctx = EVP_MD_CTX_new();
  750. if (mdctx == nullptr) {
  751. std::cerr << "Failed to create EVP context" << std::endl;
  752. return "";
  753. }
  754. if (EVP_DigestInit_ex(mdctx, EVP_sha256(), nullptr) != 1) {
  755. std::cerr << "Failed to initialize SHA256 digest" << std::endl;
  756. EVP_MD_CTX_free(mdctx);
  757. return "";
  758. }
  759. const size_t bufferSize = 8192;
  760. char buffer[bufferSize];
  761. std::cout << "Computing SHA256 for: " << modelName << std::endl;
  762. size_t totalRead = 0;
  763. size_t lastReportedMB = 0;
  764. while (file.read(buffer, bufferSize) || file.gcount() > 0) {
  765. size_t bytesRead = file.gcount();
  766. if (EVP_DigestUpdate(mdctx, buffer, bytesRead) != 1) {
  767. std::cerr << "Failed to update digest" << std::endl;
  768. EVP_MD_CTX_free(mdctx);
  769. return "";
  770. }
  771. totalRead += bytesRead;
  772. // Progress reporting every 100MB
  773. size_t currentMB = totalRead / (1024 * 1024);
  774. if (currentMB >= lastReportedMB + 100) {
  775. std::cout << " Hashed " << currentMB << " MB..." << std::endl;
  776. lastReportedMB = currentMB;
  777. }
  778. }
  779. file.close();
  780. unsigned char hash[EVP_MAX_MD_SIZE];
  781. unsigned int hashLen = 0;
  782. if (EVP_DigestFinal_ex(mdctx, hash, &hashLen) != 1) {
  783. std::cerr << "Failed to finalize digest" << std::endl;
  784. EVP_MD_CTX_free(mdctx);
  785. return "";
  786. }
  787. EVP_MD_CTX_free(mdctx);
  788. // Convert to hex string
  789. std::ostringstream oss;
  790. for (unsigned int i = 0; i < hashLen; i++) {
  791. oss << std::hex << std::setw(2) << std::setfill('0') << static_cast<int>(hash[i]);
  792. }
  793. std::string hashStr = oss.str();
  794. std::cout << "Hash computed: " << hashStr.substr(0, 16) << "..." << std::endl;
  795. return hashStr;
  796. }
  797. std::string ModelManager::loadModelHashFromFile(const std::string& modelName) {
  798. std::shared_lock<std::shared_mutex> lock(pImpl->modelsMutex);
  799. auto it = pImpl->availableModels.find(modelName);
  800. if (it == pImpl->availableModels.end()) {
  801. return "";
  802. }
  803. std::string jsonPath = it->second.fullPath + ".json";
  804. lock.unlock();
  805. if (!fs::exists(jsonPath)) {
  806. return "";
  807. }
  808. try {
  809. std::ifstream jsonFile(jsonPath);
  810. if (!jsonFile.is_open()) {
  811. return "";
  812. }
  813. nlohmann::json j;
  814. jsonFile >> j;
  815. jsonFile.close();
  816. if (j.contains("sha256") && j["sha256"].is_string()) {
  817. return j["sha256"].get<std::string>();
  818. }
  819. } catch (const std::exception& e) {
  820. std::cerr << "Error loading hash from JSON: " << e.what() << std::endl;
  821. }
  822. return "";
  823. }
  824. bool ModelManager::saveModelHashToFile(const std::string& modelName, const std::string& hash) {
  825. std::shared_lock<std::shared_mutex> lock(pImpl->modelsMutex);
  826. auto it = pImpl->availableModels.find(modelName);
  827. if (it == pImpl->availableModels.end()) {
  828. return false;
  829. }
  830. std::string jsonPath = it->second.fullPath + ".json";
  831. size_t fileSize = it->second.fileSize;
  832. lock.unlock();
  833. try {
  834. nlohmann::json j;
  835. j["sha256"] = hash;
  836. j["file_size"] = fileSize;
  837. j["computed_at"] = std::chrono::system_clock::now().time_since_epoch().count();
  838. std::ofstream jsonFile(jsonPath);
  839. if (!jsonFile.is_open()) {
  840. std::cerr << "Failed to open file for writing: " << jsonPath << std::endl;
  841. return false;
  842. }
  843. jsonFile << j.dump(2);
  844. jsonFile.close();
  845. std::cout << "Saved hash to: " << jsonPath << std::endl;
  846. return true;
  847. } catch (const std::exception& e) {
  848. std::cerr << "Error saving hash to JSON: " << e.what() << std::endl;
  849. return false;
  850. }
  851. }
  852. std::string ModelManager::findModelByHash(const std::string& hash) {
  853. if (hash.length() < 10) {
  854. std::cerr << "Hash must be at least 10 characters" << std::endl;
  855. return "";
  856. }
  857. std::shared_lock<std::shared_mutex> lock(pImpl->modelsMutex);
  858. for (const auto& [name, info] : pImpl->availableModels) {
  859. if (info.sha256.empty()) {
  860. continue;
  861. }
  862. // Support full or partial match (minimum 10 chars)
  863. if (info.sha256 == hash || info.sha256.substr(0, hash.length()) == hash) {
  864. return name;
  865. }
  866. }
  867. return "";
  868. }
  869. std::string ModelManager::ensureModelHash(const std::string& modelName, bool forceCompute) {
  870. // Try to load existing hash if not forcing recompute
  871. if (!forceCompute) {
  872. std::string existingHash = loadModelHashFromFile(modelName);
  873. if (!existingHash.empty()) {
  874. // Update in-memory model info
  875. std::unique_lock<std::shared_mutex> lock(pImpl->modelsMutex);
  876. auto it = pImpl->availableModels.find(modelName);
  877. if (it != pImpl->availableModels.end()) {
  878. it->second.sha256 = existingHash;
  879. }
  880. return existingHash;
  881. }
  882. }
  883. // Compute new hash
  884. std::string hash = computeModelHash(modelName);
  885. if (hash.empty()) {
  886. return "";
  887. }
  888. // Save to file
  889. saveModelHashToFile(modelName, hash);
  890. // Update in-memory model info
  891. std::unique_lock<std::shared_mutex> lock(pImpl->modelsMutex);
  892. auto it = pImpl->availableModels.find(modelName);
  893. if (it != pImpl->availableModels.end()) {
  894. it->second.sha256 = hash;
  895. }
  896. return hash;
  897. }