| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065 |
#include "model_manager.h"
#include "model_detector.h"
#include "stable_diffusion_wrapper.h"

#include <algorithm>
#include <atomic>
#include <cctype>
#include <chrono>
#include <cstdint>
#include <filesystem>
#include <fstream>
#include <future>
#include <iomanip>
#include <iostream>
#include <map>
#include <memory>
#include <shared_mutex>
#include <sstream>
#include <stdexcept>
#include <string>
#include <system_error>
#include <thread>
#include <utility>
#include <vector>

#include <nlohmann/json.hpp>
#include <openssl/evp.h>

namespace fs = std::filesystem;
// Accepted file extensions per model type: lowercase and without the leading dot.
// determineModelType() lowercases a file's extension before comparing against these lists.
const std::vector<std::string> CHECKPOINT_FILE_EXTENSIONS{"safetensors", "ckpt", "gguf"};
const std::vector<std::string> EMBEDDING_FILE_EXTENSIONS{"safetensors", "pt"};
const std::vector<std::string> LORA_FILE_EXTENSIONS{"safetensors", "ckpt"};
const std::vector<std::string> VAE_FILE_EXTENSIONS{"safetensors", "pt", "ckpt", "gguf"};
const std::vector<std::string> TAESD_FILE_EXTENSIONS{"safetensors", "pth", "gguf"};
const std::vector<std::string> ESRGAN_FILE_EXTENSIONS{"pth", "pt"};
const std::vector<std::string> CONTROLNET_FILE_EXTENSIONS{"safetensors", "pth"};
- class ModelManager::Impl {
- public:
- std::string modelsDirectory = "./models";
- std::map<ModelType, std::string> modelTypeDirectories;
- std::map<std::string, ModelInfo> availableModels;
- std::map<std::string, std::unique_ptr<StableDiffusionWrapper>> loadedModels;
- mutable std::shared_mutex modelsMutex;
- std::atomic<bool> scanCancelled{false};
- bool legacyMode = true;
- /**
- * @brief Validate a directory path
- *
- * @param path The directory path to validate
- * @return true if the directory exists and is valid, false otherwise
- */
- bool validateDirectory(const std::string& path) const {
- if (path.empty()) {
- return false;
- }
- std::filesystem::path dirPath(path);
- if (!std::filesystem::exists(dirPath)) {
- std::cerr << "Directory does not exist: " << path << std::endl;
- return false;
- }
- if (!std::filesystem::is_directory(dirPath)) {
- std::cerr << "Path is not a directory: " << path << std::endl;
- return false;
- }
- return true;
- }
- /**
- * @brief Get default directory name for a model type
- *
- * @param type The model type
- * @return std::string Default directory name
- */
- std::string getDefaultDirectoryName(ModelType type) const {
- switch (type) {
- case ModelType::CHECKPOINT:
- return "checkpoints";
- case ModelType::CONTROLNET:
- return "controlnet";
- case ModelType::EMBEDDING:
- return "embeddings";
- case ModelType::ESRGAN:
- return "esrgan";
- case ModelType::LORA:
- return "lora";
- case ModelType::TAESD:
- return "taesd";
- case ModelType::VAE:
- return "vae";
- default:
- return "";
- }
- }
- /**
- * @brief Get directory path for a model type
- *
- * @param type The model type
- * @return std::string Directory path, empty if not set
- */
- std::string getModelTypeDirectory(ModelType type) const {
- auto it = modelTypeDirectories.find(type);
- if (it != modelTypeDirectories.end()) {
- return it->second;
- }
- // If in legacy mode, construct default path
- if (legacyMode) {
- std::string defaultDir = getDefaultDirectoryName(type);
- if (!defaultDir.empty()) {
- return modelsDirectory + "/" + defaultDir;
- }
- }
- return "";
- }
- /**
- * @brief Get file extensions for a specific model type
- *
- * @param type The model type
- * @return const std::vector<std::string>& Vector of file extensions
- */
- const std::vector<std::string>& getFileExtensions(ModelType type) const {
- switch (type) {
- case ModelType::CHECKPOINT:
- return CHECKPOINT_FILE_EXTENSIONS;
- case ModelType::EMBEDDING:
- return EMBEDDING_FILE_EXTENSIONS;
- case ModelType::LORA:
- return LORA_FILE_EXTENSIONS;
- case ModelType::VAE:
- return VAE_FILE_EXTENSIONS;
- case ModelType::TAESD:
- return TAESD_FILE_EXTENSIONS;
- case ModelType::ESRGAN:
- return ESRGAN_FILE_EXTENSIONS;
- case ModelType::CONTROLNET:
- return CONTROLNET_FILE_EXTENSIONS;
- default:
- static const std::vector<std::string> empty;
- return empty;
- }
- }
- /**
- * @brief Check if a file extension matches a model type
- *
- * @param extension The file extension
- * @param type The model type
- * @return true if the extension matches the model type
- */
- bool isExtensionMatch(const std::string& extension, ModelType type) const {
- const auto& extensions = getFileExtensions(type);
- return std::find(extensions.begin(), extensions.end(), extension) != extensions.end();
- }
- /**
- * @brief Determine model type based on file path and extension
- *
- * @param filePath The file path
- * @return ModelType The determined model type
- */
- ModelType determineModelType(const fs::path& filePath) const {
- std::string extension = filePath.extension().string();
- if (extension.empty()) {
- return ModelType::NONE;
- }
- // Remove the dot from extension
- if (extension[0] == '.') {
- extension = extension.substr(1);
- }
- // Convert to lowercase for comparison
- std::transform(extension.begin(), extension.end(), extension.begin(), ::tolower);
- // Check if the file resides under a directory registered for a given ModelType
- fs::path absoluteFilePath = fs::absolute(filePath);
- // First check configured directories (if any)
- for (const auto& [type, directory] : modelTypeDirectories) {
- if (!directory.empty()) {
- fs::path absoluteDirPath = fs::absolute(directory).lexically_normal();
- fs::path normalizedFilePath = absoluteFilePath.lexically_normal();
- // Check if the file is under this directory (directly or in subdirectories)
- // Get the relative path from directory to file
- auto relativePath = normalizedFilePath.lexically_relative(absoluteDirPath);
- // If relative path doesn't start with "..", then file is under the directory
- std::string relPathStr = relativePath.string();
- bool isUnderDirectory = !relPathStr.empty() &&
- relPathStr.substr(0, 2) != ".." &&
- relPathStr[0] != '/';
- if (isUnderDirectory && isExtensionMatch(extension, type)) {
- return type;
- }
- }
- }
- // If in legacy mode or no configured directories matched, check default directory structure
- if (legacyMode || modelTypeDirectories.empty()) {
- std::string parentPath = filePath.parent_path().filename().string();
- std::transform(parentPath.begin(), parentPath.end(), parentPath.begin(), ::tolower);
- // Check default directory names
- if (parentPath == "checkpoints" || parentPath == "stable-diffusion") {
- if (isExtensionMatch(extension, ModelType::CHECKPOINT)) {
- return ModelType::CHECKPOINT;
- }
- } else if (parentPath == "controlnet") {
- if (isExtensionMatch(extension, ModelType::CONTROLNET)) {
- return ModelType::CONTROLNET;
- }
- } else if (parentPath == "lora") {
- if (isExtensionMatch(extension, ModelType::LORA)) {
- return ModelType::LORA;
- }
- } else if (parentPath == "vae") {
- if (isExtensionMatch(extension, ModelType::VAE)) {
- return ModelType::VAE;
- }
- } else if (parentPath == "taesd") {
- if (isExtensionMatch(extension, ModelType::TAESD)) {
- return ModelType::TAESD;
- }
- } else if (parentPath == "esrgan" || parentPath == "upscaler") {
- if (isExtensionMatch(extension, ModelType::ESRGAN)) {
- return ModelType::ESRGAN;
- }
- } else if (parentPath == "embeddings" || parentPath == "textual-inversion") {
- if (isExtensionMatch(extension, ModelType::EMBEDDING)) {
- return ModelType::EMBEDDING;
- }
- }
- }
- // Fall back to extension-based detection
- // Only return a model type if the extension matches expected extensions for that type
- if (isExtensionMatch(extension, ModelType::CHECKPOINT)) {
- return ModelType::CHECKPOINT;
- } else if (isExtensionMatch(extension, ModelType::LORA)) {
- return ModelType::LORA;
- } else if (isExtensionMatch(extension, ModelType::VAE)) {
- return ModelType::VAE;
- } else if (isExtensionMatch(extension, ModelType::TAESD)) {
- return ModelType::TAESD;
- } else if (isExtensionMatch(extension, ModelType::ESRGAN)) {
- return ModelType::ESRGAN;
- } else if (isExtensionMatch(extension, ModelType::CONTROLNET)) {
- return ModelType::CONTROLNET;
- } else if (isExtensionMatch(extension, ModelType::EMBEDDING)) {
- return ModelType::EMBEDDING;
- }
- return ModelType::NONE;
- }
- /**
- * @brief Get file information with timeout
- *
- * @param filePath The file path to get info for
- * @param timeoutMs Timeout in milliseconds
- * @return std::pair<bool, std::pair<uintmax_t, fs::file_time_type>> Success flag and file info
- */
- std::pair<bool, std::pair<uintmax_t, fs::file_time_type>> getFileInfoWithTimeout(
- const fs::path& filePath, int timeoutMs = 5000) {
- auto future = std::async(std::launch::async, [&filePath]() -> std::pair<uintmax_t, fs::file_time_type> {
- try {
- uintmax_t fileSize = fs::file_size(filePath);
- fs::file_time_type modifiedAt = fs::last_write_time(filePath);
- return {fileSize, modifiedAt};
- } catch (const fs::filesystem_error&) {
- return {0, fs::file_time_type{}};
- }
- });
- if (future.wait_for(std::chrono::milliseconds(timeoutMs)) == std::future_status::timeout) {
- std::cerr << "Timeout getting file info for " << filePath << std::endl;
- return {false, {0, fs::file_time_type{}}};
- }
- return {true, future.get()};
- }
- /**
- * @brief Scan a directory for models of a specific type (without holding mutex)
- *
- * @param directory The directory to scan
- * @param type The model type to look for
- * @param modelsMap Reference to the map to store results
- * @return bool True if scanning completed without cancellation
- */
- bool scanDirectory(const fs::path& directory, ModelType type, std::map<std::string, ModelInfo>& modelsMap) {
- if (scanCancelled.load()) {
- return false;
- }
- if (!fs::exists(directory) || !fs::is_directory(directory)) {
- return true;
- }
- try {
- for (const auto& entry : fs::recursive_directory_iterator(directory)) {
- if (scanCancelled.load()) {
- return false;
- }
- if (entry.is_regular_file()) {
- fs::path filePath = entry.path();
- ModelType detectedType = determineModelType(filePath);
- // Only add files that have a valid model type (not NONE)
- if (detectedType != ModelType::NONE && (type == ModelType::NONE || detectedType == type)) {
- ModelInfo info;
- // Calculate relative path from the scanned directory (not base models directory)
- fs::path relativePath = fs::relative(filePath, directory);
- std::string modelName = relativePath.string();
- // Check if model already exists to avoid duplicates
- if (modelsMap.find(modelName) == modelsMap.end()) {
- info.name = modelName;
- info.path = filePath.string();
- info.fullPath = fs::absolute(filePath).string();
- info.type = detectedType;
- info.isLoaded = false;
- info.description = ""; // Initialize description
- info.metadata = {}; // Initialize metadata
- // Get file info with timeout
- auto [success, fileInfo] = getFileInfoWithTimeout(filePath);
- if (success) {
- info.fileSize = fileInfo.first;
- info.modifiedAt = fileInfo.second;
- info.createdAt = fileInfo.second; // Use modified time as creation time for now
- } else {
- info.fileSize = 0;
- info.modifiedAt = fs::file_time_type{};
- info.createdAt = fs::file_time_type{};
- }
- // Try to load cached hash from .json file
- std::string hashFile = info.fullPath + ".json";
- if (fs::exists(hashFile)) {
- try {
- std::ifstream file(hashFile);
- nlohmann::json hashData = nlohmann::json::parse(file);
- if (hashData.contains("sha256") && hashData["sha256"].is_string()) {
- info.sha256 = hashData["sha256"];
- } else {
- info.sha256 = "";
- }
- } catch (...) {
- info.sha256 = ""; // If parsing fails, leave empty
- }
- } else {
- info.sha256 = ""; // No cached hash file
- }
- // Detect architecture for checkpoint models
- if (detectedType == ModelType::CHECKPOINT) {
- try {
- ModelDetectionResult detection = ModelDetector::detectModel(info.fullPath);
- // For .ckpt files that can't be detected, default to SD1.5
- if (detection.architecture == ModelArchitecture::UNKNOWN &&
- (filePath.extension() == ".ckpt" || filePath.extension() == ".pt")) {
- info.architecture = "Stable Diffusion 1.5 (assumed)";
- info.recommendedVAE = "vae-ft-mse-840000-ema-pruned.safetensors";
- info.recommendedWidth = 512;
- info.recommendedHeight = 512;
- info.recommendedSteps = 20;
- info.recommendedSampler = "euler_a";
- } else {
- info.architecture = detection.architectureName;
- info.recommendedVAE = detection.recommendedVAE;
- // Parse recommended parameters
- if (detection.suggestedParams.count("width")) {
- info.recommendedWidth = std::stoi(detection.suggestedParams["width"]);
- }
- if (detection.suggestedParams.count("height")) {
- info.recommendedHeight = std::stoi(detection.suggestedParams["height"]);
- }
- if (detection.suggestedParams.count("steps")) {
- info.recommendedSteps = std::stoi(detection.suggestedParams["steps"]);
- }
- if (detection.suggestedParams.count("sampler")) {
- info.recommendedSampler = detection.suggestedParams["sampler"];
- }
- }
- // Build list of required models based on architecture
- if (detection.needsVAE && !detection.recommendedVAE.empty()) {
- info.requiredModels.push_back("VAE: " + detection.recommendedVAE);
- }
- // Add CLIP-L if required
- if (detection.suggestedParams.count("clip_l_required")) {
- info.requiredModels.push_back("CLIP-L: " + detection.suggestedParams.at("clip_l_required"));
- }
- // Add CLIP-G if required
- if (detection.suggestedParams.count("clip_g_required")) {
- info.requiredModels.push_back("CLIP-G: " + detection.suggestedParams.at("clip_g_required"));
- }
- // Add T5XXL if required
- if (detection.suggestedParams.count("t5xxl_required")) {
- info.requiredModels.push_back("T5XXL: " + detection.suggestedParams.at("t5xxl_required"));
- }
- // Add Qwen models if required
- if (detection.suggestedParams.count("qwen2vl_required")) {
- info.requiredModels.push_back("Qwen2-VL: " + detection.suggestedParams.at("qwen2vl_required"));
- }
- if (detection.suggestedParams.count("qwen2vl_vision_required")) {
- info.requiredModels.push_back("Qwen2-VL-Vision: " + detection.suggestedParams.at("qwen2vl_vision_required"));
- }
- } catch (const std::exception& e) {
- // If detection fails completely, default to SD1.5
- info.architecture = "Stable Diffusion 1.5 (assumed)";
- info.recommendedVAE = "vae-ft-mse-840000-ema-pruned.safetensors";
- info.recommendedWidth = 512;
- info.recommendedHeight = 512;
- info.recommendedSteps = 20;
- info.recommendedSampler = "euler_a";
- }
- }
- modelsMap[info.name] = info;
- }
- }
- }
- }
- } catch (const fs::filesystem_error& e) {
- // Silently handle filesystem errors
- }
- return !scanCancelled.load();
- }
- };
- ModelManager::ModelManager() : pImpl(std::make_unique<Impl>()) {
- }
- ModelManager::~ModelManager() = default;
- bool ModelManager::scanModelsDirectory() {
- // Reset cancellation flag
- pImpl->scanCancelled.store(false);
- // Create temporary map to store scan results (outside of lock)
- std::map<std::string, ModelInfo> tempModels;
- if (pImpl->legacyMode) {
- // Legacy mode: scan the models directory itself and its subdirectories
- fs::path modelsPath(pImpl->modelsDirectory);
- if (!fs::exists(modelsPath) || !fs::is_directory(modelsPath)) {
- std::cerr << "Models directory does not exist: " << pImpl->modelsDirectory << std::endl;
- return false;
- }
- // First, scan the main models directory itself for any model files
- // This handles the case where models are directly in the specified directory
- if (!pImpl->scanDirectory(modelsPath, ModelType::NONE, tempModels)) {
- return false;
- }
- // Then scan known subdirectories for organized models
- std::vector<std::pair<fs::path, ModelType>> directoriesToScan = {
- {modelsPath / "stable-diffusion", ModelType::CHECKPOINT},
- {modelsPath / "controlnet", ModelType::CONTROLNET},
- {modelsPath / "lora", ModelType::LORA},
- {modelsPath / "vae", ModelType::VAE},
- {modelsPath / "taesd", ModelType::TAESD},
- {modelsPath / "esrgan", ModelType::ESRGAN},
- {modelsPath / "upscaler", ModelType::ESRGAN},
- {modelsPath / "embeddings", ModelType::EMBEDDING},
- {modelsPath / "textual-inversion", ModelType::EMBEDDING},
- {modelsPath / "checkpoints", ModelType::CHECKPOINT}, // Also scan checkpoints subdirectory
- {modelsPath / "other", ModelType::NONE} // Scan for any type
- };
- for (const auto& [dirPath, type] : directoriesToScan) {
- if (!pImpl->scanDirectory(dirPath, type, tempModels)) {
- return false;
- }
- }
- } else {
- // Explicit mode: scan configured directories for each model type
- std::vector<std::pair<ModelType, std::string>> directoriesToScan = {
- {ModelType::CHECKPOINT, pImpl->getModelTypeDirectory(ModelType::CHECKPOINT)},
- {ModelType::CONTROLNET, pImpl->getModelTypeDirectory(ModelType::CONTROLNET)},
- {ModelType::LORA, pImpl->getModelTypeDirectory(ModelType::LORA)},
- {ModelType::VAE, pImpl->getModelTypeDirectory(ModelType::VAE)},
- {ModelType::TAESD, pImpl->getModelTypeDirectory(ModelType::TAESD)},
- {ModelType::ESRGAN, pImpl->getModelTypeDirectory(ModelType::ESRGAN)},
- {ModelType::EMBEDDING, pImpl->getModelTypeDirectory(ModelType::EMBEDDING)}
- };
- for (const auto& [type, dirPath] : directoriesToScan) {
- if (!dirPath.empty()) {
- if (!pImpl->scanDirectory(dirPath, type, tempModels)) {
- return false;
- }
- }
- }
- }
- // Brief exclusive lock only to swap the data
- {
- std::unique_lock<std::shared_mutex> lock(pImpl->modelsMutex);
- pImpl->availableModels.swap(tempModels);
- }
- return true;
- }
- bool ModelManager::loadModel(const std::string& name, const std::string& path, ModelType type) {
- std::unique_lock<std::shared_mutex> lock(pImpl->modelsMutex);
- // Check if model is already loaded
- if (pImpl->loadedModels.find(name) != pImpl->loadedModels.end()) {
- return true;
- }
- // Check if file exists
- if (!fs::exists(path)) {
- std::cerr << "Model file does not exist: " << path << std::endl;
- return false;
- }
- // Create and initialize the stable-diffusion wrapper
- auto wrapper = std::make_unique<StableDiffusionWrapper>();
- // Set up generation parameters for model loading
- StableDiffusionWrapper::GenerationParams loadParams;
- loadParams.modelPath = path;
- loadParams.modelType = "f16"; // Default to f16 for better performance
- // Try to load the model
- if (!wrapper->loadModel(path, loadParams)) {
- std::cerr << "Failed to load model '" << name << "': " << wrapper->getLastError() << std::endl;
- return false;
- }
- pImpl->loadedModels[name] = std::move(wrapper);
- // Update model info
- if (pImpl->availableModels.find(name) != pImpl->availableModels.end()) {
- pImpl->availableModels[name].isLoaded = true;
- } else {
- // Create a new model info entry
- ModelInfo info;
- info.name = name;
- info.path = path;
- info.fullPath = fs::absolute(path).string();
- info.type = type;
- info.isLoaded = true;
- info.sha256 = "";
- info.description = ""; // Initialize description
- info.metadata = {}; // Initialize metadata
- try {
- info.fileSize = fs::file_size(path);
- info.modifiedAt = fs::last_write_time(path);
- info.createdAt = info.modifiedAt; // Use modified time as creation time for now
- } catch (const fs::filesystem_error& e) {
- std::cerr << "Error getting file info for " << path << ": " << e.what() << std::endl;
- info.fileSize = 0;
- info.modifiedAt = fs::file_time_type{};
- info.createdAt = fs::file_time_type{};
- }
- pImpl->availableModels[name] = info;
- }
- return true;
- }
- bool ModelManager::loadModel(const std::string& name) {
- std::string path;
- ModelType type;
- {
- std::unique_lock<std::shared_mutex> lock(pImpl->modelsMutex);
- // Check if model exists in available models
- auto it = pImpl->availableModels.find(name);
- if (it == pImpl->availableModels.end()) {
- std::cerr << "Model '" << name << "' not found in available models" << std::endl;
- return false;
- }
- // Check if already loaded
- if (pImpl->loadedModels.find(name) != pImpl->loadedModels.end()) {
- return true;
- }
- // Extract path and type while we have the lock
- path = it->second.path;
- type = it->second.type;
- } // Release lock here
- // Load the model without holding the lock
- return loadModel(name, path, type);
- }
- bool ModelManager::unloadModel(const std::string& name) {
- std::unique_lock<std::shared_mutex> lock(pImpl->modelsMutex);
- // Check if model is loaded
- auto loadedIt = pImpl->loadedModels.find(name);
- if (loadedIt == pImpl->loadedModels.end()) {
- return false;
- }
- // Unload the model properly
- if (loadedIt->second) {
- loadedIt->second->unloadModel();
- }
- pImpl->loadedModels.erase(loadedIt);
- // Update model info
- auto availableIt = pImpl->availableModels.find(name);
- if (availableIt != pImpl->availableModels.end()) {
- availableIt->second.isLoaded = false;
- }
- return true;
- }
- StableDiffusionWrapper* ModelManager::getModel(const std::string& name) {
- std::shared_lock<std::shared_mutex> lock(pImpl->modelsMutex);
- auto it = pImpl->loadedModels.find(name);
- if (it == pImpl->loadedModels.end()) {
- return nullptr;
- }
- return it->second.get();
- }
- std::map<std::string, ModelManager::ModelInfo> ModelManager::getAllModels() const {
- std::shared_lock<std::shared_mutex> lock(pImpl->modelsMutex);
- return pImpl->availableModels;
- }
- std::vector<ModelManager::ModelInfo> ModelManager::getModelsByType(ModelType type) const {
- std::shared_lock<std::shared_mutex> lock(pImpl->modelsMutex);
- std::vector<ModelInfo> result;
- for (const auto& pair : pImpl->availableModels) {
- if (pair.second.type == type) {
- result.push_back(pair.second);
- }
- }
- return result;
- }
- ModelManager::ModelInfo ModelManager::getModelInfo(const std::string& name) const {
- std::unique_lock<std::shared_mutex> lock(pImpl->modelsMutex);
- auto it = pImpl->availableModels.find(name);
- if (it == pImpl->availableModels.end()) {
- return ModelInfo{}; // Return empty ModelInfo if not found
- }
- return it->second;
- }
- bool ModelManager::isModelLoaded(const std::string& name) const {
- std::unique_lock<std::shared_mutex> lock(pImpl->modelsMutex);
- auto it = pImpl->loadedModels.find(name);
- return it != pImpl->loadedModels.end();
- }
- size_t ModelManager::getLoadedModelsCount() const {
- std::shared_lock<std::shared_mutex> lock(pImpl->modelsMutex);
- return pImpl->loadedModels.size();
- }
- size_t ModelManager::getAvailableModelsCount() const {
- std::shared_lock<std::shared_mutex> lock(pImpl->modelsMutex);
- return pImpl->availableModels.size();
- }
- void ModelManager::setModelsDirectory(const std::string& path) {
- pImpl->modelsDirectory = path;
- }
- std::string ModelManager::getModelsDirectory() const {
- return pImpl->modelsDirectory;
- }
- std::string ModelManager::modelTypeToString(ModelType type) {
- switch (type) {
- case ModelType::LORA:
- return "lora";
- case ModelType::CHECKPOINT:
- return "checkpoint";
- case ModelType::VAE:
- return "vae";
- case ModelType::PRESETS:
- return "presets";
- case ModelType::PROMPTS:
- return "prompts";
- case ModelType::NEG_PROMPTS:
- return "neg_prompts";
- case ModelType::TAESD:
- return "taesd";
- case ModelType::ESRGAN:
- return "esrgan";
- case ModelType::CONTROLNET:
- return "controlnet";
- case ModelType::UPSCALER:
- return "upscaler";
- case ModelType::EMBEDDING:
- return "embedding";
- default:
- return "unknown";
- }
- }
- ModelType ModelManager::stringToModelType(const std::string& typeStr) {
- std::string lowerType = typeStr;
- std::transform(lowerType.begin(), lowerType.end(), lowerType.begin(), ::tolower);
- if (lowerType == "lora") {
- return ModelType::LORA;
- } else if (lowerType == "checkpoint" || lowerType == "stable-diffusion") {
- return ModelType::CHECKPOINT;
- } else if (lowerType == "vae") {
- return ModelType::VAE;
- } else if (lowerType == "presets") {
- return ModelType::PRESETS;
- } else if (lowerType == "prompts") {
- return ModelType::PROMPTS;
- } else if (lowerType == "neg_prompts" || lowerType == "negative_prompts") {
- return ModelType::NEG_PROMPTS;
- } else if (lowerType == "taesd") {
- return ModelType::TAESD;
- } else if (lowerType == "esrgan") {
- return ModelType::ESRGAN;
- } else if (lowerType == "controlnet") {
- return ModelType::CONTROLNET;
- } else if (lowerType == "upscaler") {
- return ModelType::UPSCALER;
- } else if (lowerType == "embedding" || lowerType == "textual-inversion") {
- return ModelType::EMBEDDING;
- }
- return ModelType::NONE;
- }
- bool ModelManager::setModelTypeDirectory(ModelType type, const std::string& path) {
- std::unique_lock<std::shared_mutex> lock(pImpl->modelsMutex);
- if (!pImpl->validateDirectory(path)) {
- return false;
- }
- pImpl->modelTypeDirectories[type] = path;
- pImpl->legacyMode = false;
- return true;
- }
- std::string ModelManager::getModelTypeDirectory(ModelType type) const {
- std::shared_lock<std::shared_mutex> lock(pImpl->modelsMutex);
- return pImpl->getModelTypeDirectory(type);
- }
- bool ModelManager::setAllModelTypeDirectories(const std::map<ModelType, std::string>& directories) {
- std::unique_lock<std::shared_mutex> lock(pImpl->modelsMutex);
- // Validate all directories first
- for (const auto& [type, path] : directories) {
- if (!path.empty() && !pImpl->validateDirectory(path)) {
- return false;
- }
- }
- // Set all directories
- pImpl->modelTypeDirectories = directories;
- pImpl->legacyMode = false;
- return true;
- }
- std::map<ModelType, std::string> ModelManager::getAllModelTypeDirectories() const {
- std::shared_lock<std::shared_mutex> lock(pImpl->modelsMutex);
- return pImpl->modelTypeDirectories;
- }
- void ModelManager::resetToLegacyDirectories() {
- // Note: This method should be called with modelsMutex already locked
- pImpl->modelTypeDirectories.clear();
- pImpl->legacyMode = true;
- }
- bool ModelManager::configureFromServerConfig(const ServerConfig& config) {
- std::unique_lock<std::shared_mutex> lock(pImpl->modelsMutex);
- // Set the base models directory
- pImpl->modelsDirectory = config.modelsDir;
- if (config.legacyMode) {
- // Legacy mode: use single models directory
- resetToLegacyDirectories();
- return true;
- } else {
- // Explicit mode: set per-type directories
- std::map<ModelType, std::string> directories;
- if (!config.checkpoints.empty()) {
- directories[ModelType::CHECKPOINT] = config.checkpoints;
- }
- if (!config.controlnetDir.empty()) {
- directories[ModelType::CONTROLNET] = config.controlnetDir;
- }
- if (!config.embeddingsDir.empty()) {
- directories[ModelType::EMBEDDING] = config.embeddingsDir;
- }
- if (!config.esrganDir.empty()) {
- directories[ModelType::ESRGAN] = config.esrganDir;
- }
- if (!config.loraDir.empty()) {
- directories[ModelType::LORA] = config.loraDir;
- }
- if (!config.taesdDir.empty()) {
- directories[ModelType::TAESD] = config.taesdDir;
- }
- if (!config.vaeDir.empty()) {
- directories[ModelType::VAE] = config.vaeDir;
- }
- // Validate all directories first
- for (const auto& [type, path] : directories) {
- if (!path.empty() && !pImpl->validateDirectory(path)) {
- return false;
- }
- }
- // Set all directories (inlined to avoid deadlock from calling setAllModelTypeDirectories)
- pImpl->modelTypeDirectories = directories;
- pImpl->legacyMode = false;
- return true;
- }
- }
- void ModelManager::cancelScan() {
- pImpl->scanCancelled.store(true);
- }
- // SHA256 Hashing Implementation
- std::string ModelManager::computeModelHash(const std::string& modelName) {
- std::shared_lock<std::shared_mutex> lock(pImpl->modelsMutex);
-
- auto it = pImpl->availableModels.find(modelName);
- if (it == pImpl->availableModels.end()) {
- std::cerr << "Model not found: " << modelName << std::endl;
- return "";
- }
-
- std::string filePath = it->second.fullPath;
- lock.unlock();
-
- std::ifstream file(filePath, std::ios::binary);
- if (!file.is_open()) {
- std::cerr << "Failed to open file for hashing: " << filePath << std::endl;
- return "";
- }
- // Create and initialize EVP context for SHA256
- EVP_MD_CTX* mdctx = EVP_MD_CTX_new();
- if (mdctx == nullptr) {
- std::cerr << "Failed to create EVP context" << std::endl;
- return "";
- }
- if (EVP_DigestInit_ex(mdctx, EVP_sha256(), nullptr) != 1) {
- std::cerr << "Failed to initialize SHA256 digest" << std::endl;
- EVP_MD_CTX_free(mdctx);
- return "";
- }
- const size_t bufferSize = 8192;
- char buffer[bufferSize];
- std::cout << "Computing SHA256 for: " << modelName << std::endl;
- size_t totalRead = 0;
- size_t lastReportedMB = 0;
- while (file.read(buffer, bufferSize) || file.gcount() > 0) {
- size_t bytesRead = file.gcount();
- if (EVP_DigestUpdate(mdctx, buffer, bytesRead) != 1) {
- std::cerr << "Failed to update digest" << std::endl;
- EVP_MD_CTX_free(mdctx);
- return "";
- }
- totalRead += bytesRead;
- // Progress reporting every 100MB
- size_t currentMB = totalRead / (1024 * 1024);
- if (currentMB >= lastReportedMB + 100) {
- std::cout << " Hashed " << currentMB << " MB..." << std::endl;
- lastReportedMB = currentMB;
- }
- }
- file.close();
- unsigned char hash[EVP_MAX_MD_SIZE];
- unsigned int hashLen = 0;
- if (EVP_DigestFinal_ex(mdctx, hash, &hashLen) != 1) {
- std::cerr << "Failed to finalize digest" << std::endl;
- EVP_MD_CTX_free(mdctx);
- return "";
- }
- EVP_MD_CTX_free(mdctx);
- // Convert to hex string
- std::ostringstream oss;
- for (unsigned int i = 0; i < hashLen; i++) {
- oss << std::hex << std::setw(2) << std::setfill('0') << static_cast<int>(hash[i]);
- }
-
- std::string hashStr = oss.str();
- std::cout << "Hash computed: " << hashStr.substr(0, 16) << "..." << std::endl;
-
- return hashStr;
- }
- std::string ModelManager::loadModelHashFromFile(const std::string& modelName) {
- std::shared_lock<std::shared_mutex> lock(pImpl->modelsMutex);
-
- auto it = pImpl->availableModels.find(modelName);
- if (it == pImpl->availableModels.end()) {
- return "";
- }
-
- std::string jsonPath = it->second.fullPath + ".json";
- lock.unlock();
-
- if (!fs::exists(jsonPath)) {
- return "";
- }
-
- try {
- std::ifstream jsonFile(jsonPath);
- if (!jsonFile.is_open()) {
- return "";
- }
-
- nlohmann::json j;
- jsonFile >> j;
- jsonFile.close();
-
- if (j.contains("sha256") && j["sha256"].is_string()) {
- return j["sha256"].get<std::string>();
- }
- } catch (const std::exception& e) {
- std::cerr << "Error loading hash from JSON: " << e.what() << std::endl;
- }
-
- return "";
- }
- bool ModelManager::saveModelHashToFile(const std::string& modelName, const std::string& hash) {
- std::shared_lock<std::shared_mutex> lock(pImpl->modelsMutex);
-
- auto it = pImpl->availableModels.find(modelName);
- if (it == pImpl->availableModels.end()) {
- return false;
- }
-
- std::string jsonPath = it->second.fullPath + ".json";
- size_t fileSize = it->second.fileSize;
- lock.unlock();
-
- try {
- nlohmann::json j;
- j["sha256"] = hash;
- j["file_size"] = fileSize;
- j["computed_at"] = std::chrono::system_clock::now().time_since_epoch().count();
-
- std::ofstream jsonFile(jsonPath);
- if (!jsonFile.is_open()) {
- std::cerr << "Failed to open file for writing: " << jsonPath << std::endl;
- return false;
- }
-
- jsonFile << j.dump(2);
- jsonFile.close();
-
- std::cout << "Saved hash to: " << jsonPath << std::endl;
- return true;
- } catch (const std::exception& e) {
- std::cerr << "Error saving hash to JSON: " << e.what() << std::endl;
- return false;
- }
- }
- std::string ModelManager::findModelByHash(const std::string& hash) {
- if (hash.length() < 10) {
- std::cerr << "Hash must be at least 10 characters" << std::endl;
- return "";
- }
-
- std::shared_lock<std::shared_mutex> lock(pImpl->modelsMutex);
-
- for (const auto& [name, info] : pImpl->availableModels) {
- if (info.sha256.empty()) {
- continue;
- }
-
- // Support full or partial match (minimum 10 chars)
- if (info.sha256 == hash || info.sha256.substr(0, hash.length()) == hash) {
- return name;
- }
- }
-
- return "";
- }
- std::string ModelManager::ensureModelHash(const std::string& modelName, bool forceCompute) {
- // Try to load existing hash if not forcing recompute
- if (!forceCompute) {
- std::string existingHash = loadModelHashFromFile(modelName);
- if (!existingHash.empty()) {
- // Update in-memory model info
- std::unique_lock<std::shared_mutex> lock(pImpl->modelsMutex);
- auto it = pImpl->availableModels.find(modelName);
- if (it != pImpl->availableModels.end()) {
- it->second.sha256 = existingHash;
- }
- return existingHash;
- }
- }
-
- // Compute new hash
- std::string hash = computeModelHash(modelName);
- if (hash.empty()) {
- return "";
- }
-
- // Save to file
- saveModelHashToFile(modelName, hash);
-
- // Update in-memory model info
- std::unique_lock<std::shared_mutex> lock(pImpl->modelsMutex);
- auto it = pImpl->availableModels.find(modelName);
- if (it != pImpl->availableModels.end()) {
- it->second.sha256 = hash;
- }
-
- return hash;
- }
|