#include "model_manager.h"
#include "model_detector.h"
#include "stable_diffusion_wrapper.h"

#include <algorithm>
#include <atomic>
#include <cctype>
#include <chrono>
#include <cstdint>
#include <filesystem>
#include <fstream>
#include <future>
#include <iomanip>
#include <iostream>
#include <map>
#include <memory>
#include <mutex>
#include <set>
#include <shared_mutex>
#include <sstream>
#include <string>
#include <system_error>
#include <thread>
#include <utility>
#include <vector>

#include <nlohmann/json.hpp>
#include <openssl/evp.h>

namespace fs = std::filesystem;

// File extension mappings for each model type (lower-case, without the leading dot)
const std::vector<std::string> CHECKPOINT_FILE_EXTENSIONS = {"safetensors", "ckpt", "gguf"};
const std::vector<std::string> EMBEDDING_FILE_EXTENSIONS = {"safetensors", "pt"};
const std::vector<std::string> LORA_FILE_EXTENSIONS = {"safetensors", "ckpt"};
const std::vector<std::string> VAE_FILE_EXTENSIONS = {"safetensors", "pt", "ckpt", "gguf"};
const std::vector<std::string> TAESD_FILE_EXTENSIONS = {"safetensors", "pth", "gguf"};
const std::vector<std::string> ESRGAN_FILE_EXTENSIONS = {"pth", "pt"};
const std::vector<std::string> CONTROLNET_FILE_EXTENSIONS = {"safetensors", "pth"};

namespace {

/**
 * @brief Return an ASCII-lower-cased copy of @p text.
 *
 * Characters are routed through unsigned char because passing a negative
 * char (e.g. a UTF-8 byte in a file name) to std::tolower is undefined behavior.
 */
std::string asciiLower(std::string text) {
    std::transform(text.begin(), text.end(), text.begin(),
                   [](unsigned char c) { return static_cast<char>(std::tolower(c)); });
    return text;
}

/**
 * @brief Case-insensitive (ASCII) "starts with" test.
 *
 * @return true if @p text begins with @p prefix, ignoring letter case.
 */
bool startsWithIgnoringCase(const std::string& text, const std::string& prefix) {
    if (prefix.size() > text.size()) {
        return false;
    }
    return std::equal(prefix.begin(), prefix.end(), text.begin(),
                      [](unsigned char a, unsigned char b) { return std::tolower(a) == std::tolower(b); });
}

/**
 * @brief Check whether @p filePath lies inside @p directory (at any depth).
 *
 * Both paths are made absolute and lexically normalized (symlinks are not
 * resolved). The FIRST component of the relative path is compared with "..",
 * so a file whose name merely begins with two dots is not mistaken for a path
 * that escapes the directory.
 */
bool isPathLexicallyUnder(const fs::path& filePath, const fs::path& directory) {
    const fs::path relative = fs::absolute(filePath).lexically_normal().lexically_relative(
        fs::absolute(directory).lexically_normal());
    if (relative.empty() || relative.is_absolute()) {
        return false;
    }
    return *relative.begin() != fs::path("..");
}

/**
 * @brief Map a conventional (lower-cased) folder name to the model type it holds.
 *
 * @param folderName Lower-cased directory name, e.g. "lora" or "embeddings".
 * @return The matching ModelType, or ModelType::NONE for unrecognized names.
 */
ModelType defaultFolderModelType(const std::string& folderName) {
    if (folderName == "checkpoints" || folderName == "stable-diffusion") {
        return ModelType::CHECKPOINT;
    }
    if (folderName == "controlnet") {
        return ModelType::CONTROLNET;
    }
    if (folderName == "lora") {
        return ModelType::LORA;
    }
    if (folderName == "vae") {
        return ModelType::VAE;
    }
    if (folderName == "taesd") {
        return ModelType::TAESD;
    }
    if (folderName == "esrgan" || folderName == "upscaler") {
        return ModelType::ESRGAN;
    }
    if (folderName == "embeddings" || folderName == "textual-inversion") {
        return ModelType::EMBEDDING;
    }
    if (folderName == "diffusion_models" || folderName == "diffusion") {
        return ModelType::DIFFUSION_MODELS;
    }
    return ModelType::NONE;
}

}  // namespace

class ModelManager::Impl {
public:
    std::string modelsDirectory = "./models";
    std::map<ModelType, std::string> modelTypeDirectories;
    std::map<std::string, ModelInfo> availableModels;
    std::map<std::string, std::unique_ptr<StableDiffusionWrapper>> loadedModels;
    mutable std::shared_mutex modelsMutex;
    std::atomic<bool> scanCancelled{false};

    /**
     * @brief Validate a directory path
     *
     * Uses the non-throwing std::filesystem overloads, so an inaccessible path
     * is reported as invalid instead of escaping as an exception.
     *
     * @param path The directory path to validate
     * @return true if the directory exists and is a directory, false otherwise
     */
    bool validateDirectory(const std::string& path) const {
        if (path.empty()) {
            return false;
        }

        const fs::path dirPath(path);
        std::error_code ec;
        if (!fs::exists(dirPath, ec)) {
            std::cerr << "Directory does not exist: " << path << std::endl;
            return false;
        }
        if (!fs::is_directory(dirPath, ec)) {
            std::cerr << "Path is not a directory: " << path << std::endl;
            return false;
        }
        return true;
    }

    /**
     * @brief Get default directory name for a model type
     *
     * @param type The model type
     * @return std::string Default directory name, empty for types without one
     */
    std::string getDefaultDirectoryName(ModelType type) const {
        switch (type) {
            case ModelType::CHECKPOINT:
                return "checkpoints";
            case ModelType::CONTROLNET:
                return "controlnet";
            case ModelType::EMBEDDING:
                return "embeddings";
            case ModelType::ESRGAN:
                return "esrgan";
            case ModelType::LORA:
                return "lora";
            case ModelType::TAESD:
                return "taesd";
            case ModelType::VAE:
                return "vae";
            case ModelType::DIFFUSION_MODELS:
                return "diffusion_models";
            default:
                return "";
        }
    }

    /**
     * @brief Get directory path for a model type
     *
     * Returns the explicitly configured directory when there is one, otherwise
     * `<modelsDirectory>/<default directory name>`.
     *
     * @param type The model type
     * @return std::string Directory path, empty if the type has no default either
     */
    std::string getModelTypeDirectory(ModelType type) const {
        auto it = modelTypeDirectories.find(type);
        if (it != modelTypeDirectories.end()) {
            return it->second;
        }

        const std::string defaultDir = getDefaultDirectoryName(type);
        if (!defaultDir.empty()) {
            return modelsDirectory + "/" + defaultDir;
        }
        return "";
    }

    /**
     * @brief Get file extensions for a specific model type
     *
     * @param type The model type
     * @return Reference to the lower-case extension list (empty for unknown types)
     */
    const std::vector<std::string>& getFileExtensions(ModelType type) const {
        switch (type) {
            case ModelType::CHECKPOINT:
            case ModelType::DIFFUSION_MODELS:
                return CHECKPOINT_FILE_EXTENSIONS;
            case ModelType::EMBEDDING:
                return EMBEDDING_FILE_EXTENSIONS;
            case ModelType::LORA:
                return LORA_FILE_EXTENSIONS;
            case ModelType::VAE:
                return VAE_FILE_EXTENSIONS;
            case ModelType::TAESD:
                return TAESD_FILE_EXTENSIONS;
            case ModelType::ESRGAN:
                return ESRGAN_FILE_EXTENSIONS;
            case ModelType::CONTROLNET:
                return CONTROLNET_FILE_EXTENSIONS;
            default:
                static const std::vector<std::string> empty;
                return empty;
        }
    }

    /**
     * @brief Check if a file extension matches a model type
     *
     * @param extension The lower-case file extension without the leading dot
     * @param type The model type
     * @return true if the extension is accepted for the model type
     */
    bool isExtensionMatch(const std::string& extension, ModelType type) const {
        const auto& extensions = getFileExtensions(type);
        return std::find(extensions.begin(), extensions.end(), extension) != extensions.end();
    }

    /**
     * @brief Determine a model type from where a file lives and its extension
     *
     * There is deliberately no extension-only fallback: a file is recognized
     * only when it sits under a directory associated with a model type AND its
     * extension is valid for that type.
     *
     * 1. Every configured directory in modelTypeDirectories is checked (in map order).
     * 2. Only when no directories are configured, each ancestor folder name is
     *    compared with the conventional names ("lora", "embeddings", ...),
     *    walking upward from the file's parent.
     *
     * @param filePath The file path
     * @return ModelType The determined model type, or ModelType::NONE
     */
    ModelType determineModelType(const fs::path& filePath) const {
        std::string extension = filePath.extension().string();
        if (extension.empty()) {
            return ModelType::NONE;
        }
        if (extension[0] == '.') {
            extension.erase(0, 1);
        }
        extension = asciiLower(std::move(extension));

        // Explicitly configured directories (file may be nested in subdirectories)
        for (const auto& [type, directory] : modelTypeDirectories) {
            if (!directory.empty() && isPathLexicallyUnder(filePath, directory) &&
                isExtensionMatch(extension, type)) {
                return type;
            }
        }

        // Conventional folder names, only when nothing is configured explicitly
        if (modelTypeDirectories.empty()) {
            for (fs::path current = filePath.parent_path(); current.has_filename();
                 current = current.parent_path()) {
                const ModelType folderType = defaultFolderModelType(asciiLower(current.filename().string()));
                if (folderType != ModelType::NONE && isExtensionMatch(extension, folderType)) {
                    return folderType;
                }
            }
        }

        return ModelType::NONE;
    }

    /**
     * @brief Get file size and modification time, giving up after a timeout
     *
     * The stat calls run on a detached worker thread that owns copies of all
     * data it touches. (A std::async future would block in its destructor until
     * the task finished, silently defeating the timeout on a hung mount.)
     * If the worker hits a filesystem error it reports {0, file_time_type{}}
     * and the call still counts as successful.
     *
     * @param filePath The file path to get info for
     * @param timeoutMs Timeout in milliseconds
     * @return {success, {fileSize, modifiedAt}}; success is false on timeout
     */
    std::pair<bool, std::pair<std::uintmax_t, fs::file_time_type>> getFileInfoWithTimeout(
        const fs::path& filePath, int timeoutMs = 5000) {
        using FileStats = std::pair<std::uintmax_t, fs::file_time_type>;

        auto promise = std::make_shared<std::promise<FileStats>>();
        std::future<FileStats> future = promise->get_future();

        std::thread([promise, path = filePath]() {
            std::error_code sizeEc;
            std::error_code timeEc;
            const std::uintmax_t fileSize = fs::file_size(path, sizeEc);
            const fs::file_time_type modifiedAt = fs::last_write_time(path, timeEc);
            if (sizeEc || timeEc) {
                promise->set_value({0, fs::file_time_type{}});
            } else {
                promise->set_value({fileSize, modifiedAt});
            }
        }).detach();

        if (future.wait_for(std::chrono::milliseconds(timeoutMs)) == std::future_status::timeout) {
            std::cerr << "Timeout getting file info for " << filePath << std::endl;
            return {false, {0, fs::file_time_type{}}};
        }
        return {true, future.get()};
    }

    /**
     * @brief Build the display name for a model file
     *
     * The name is the path relative to the model type's directory (falling back
     * to modelsDirectory, then to the scan root), using forward slashes.
     * If the relative path cannot be computed, the bare filename is used.
     */
    std::string makeDisplayName(const fs::path& filePath, ModelType type, const fs::path& scanRoot) const {
        fs::path relativePath;
        try {
            const std::string modelTypeDir = getModelTypeDirectory(type);
            if (!modelTypeDir.empty()) {
                relativePath = fs::relative(filePath, fs::path(modelTypeDir));
            } else if (!modelsDirectory.empty()) {
                relativePath = fs::relative(filePath, fs::path(modelsDirectory));
            } else {
                relativePath = fs::relative(filePath, scanRoot);
            }
        } catch (const fs::filesystem_error&) {
            relativePath.clear();
        }
        // fs::relative returns an empty path when no relative path exists (e.g. different roots)
        if (relativePath.empty()) {
            relativePath = filePath.filename();
        }

        std::string modelName = relativePath.string();
        std::replace(modelName.begin(), modelName.end(), '\\', '/');
        return modelName;
    }

    /**
     * @brief Read a cached SHA256 from `<fullPath>.json`
     *
     * @return The "sha256" string value, or "" when the file is absent,
     *         unreadable, or has no string "sha256" field.
     */
    static std::string loadCachedSha256(const std::string& fullPath) {
        const std::string hashFile = fullPath + ".json";
        std::error_code ec;
        if (!fs::exists(hashFile, ec)) {
            return "";
        }
        try {
            std::ifstream file(hashFile);
            const nlohmann::json hashData = nlohmann::json::parse(file);
            const auto it = hashData.find("sha256");
            if (it != hashData.end() && it->is_string()) {
                return it->get<std::string>();
            }
        } catch (const std::exception& e) {
            // Best effort: a corrupt cache file just means "no cached hash"
            std::cerr << "Ignoring unreadable hash cache " << hashFile << ": " << e.what() << std::endl;
        }
        return "";
    }

    /// Fill in the SD1.5 defaults used whenever architecture detection cannot decide.
    static void applySd15Defaults(ModelInfo& info) {
        info.architecture = "Stable Diffusion 1.5 (assumed)";
        info.recommendedVAE = "vae-ft-mse-840000-ema-pruned.safetensors";
        info.recommendedWidth = 512;
        info.recommendedHeight = 512;
        info.recommendedSteps = 20;
        info.recommendedSampler = "euler_a";
    }

    /**
     * @brief Populate info.requiredModels / info.missingModels from a detection result
     *
     * VAE is optional for SD1x and SDXL models, so it is recommended but never
     * listed as required.
     */
    static void collectRequiredModels(ModelInfo& info, const ModelDetectionResult& detection) {
        static const std::pair<const char*, const char*> kRequirementKeys[] = {
            {"clip_l_required", "CLIP-L: "},
            {"clip_g_required", "CLIP-G: "},
            {"t5xxl_required", "T5XXL: "},
            {"qwen2vl_required", "Qwen2-VL: "},
            {"qwen2vl_vision_required", "Qwen2-VL-Vision: "},
        };
        for (const auto& [key, label] : kRequirementKeys) {
            const auto it = detection.suggestedParams.find(key);
            if (it != detection.suggestedParams.end()) {
                info.requiredModels.push_back(label + it->second);
            }
        }

        if (info.requiredModels.empty()) {
            return;
        }

        // NOTE(review): a freshly constructed ModelManager only knows its default
        // "./models" layout, not the directories configured on the scanning
        // instance — confirm this is intended.
        ModelManager tempManager;
        const auto modelDetails = tempManager.checkRequiredModelsExistence(info.requiredModels);

        info.missingModels.clear();
        for (const auto& detail : modelDetails) {
            if (!detail.exists) {
                info.missingModels.push_back(detail.type + ": " + detail.name);
            }
        }

        std::cout << "Model " << info.name << " requires " << info.requiredModels.size() << " models, "
                  << info.missingModels.size() << " are missing" << std::endl;
    }

    /**
     * @brief Fill architecture / recommendation fields for a checkpoint or diffusion model
     *
     * Uses the ModelDetectionCache entry when it is valid for info.modifiedAt.
     * Otherwise it runs detection and caches the result: folder-based for files
     * under the diffusion_models path, full architecture detection for the rest.
     * Any exception falls back to SD1.5 defaults.
     */
    void populateArchitecture(ModelInfo& info, const fs::path& filePath) {
        const ModelDetectionCache::CacheEntry cachedEntry =
            ModelDetectionCache::getCachedResult(info.fullPath, info.modifiedAt);

        if (cachedEntry.isValid) {
            info.architecture = cachedEntry.architecture;
            info.recommendedVAE = cachedEntry.recommendedVAE;
            info.recommendedWidth = cachedEntry.recommendedWidth;
            info.recommendedHeight = cachedEntry.recommendedHeight;
            info.recommendedSteps = cachedEntry.recommendedSteps;
            info.recommendedSampler = cachedEntry.recommendedSampler;
            info.requiredModels = cachedEntry.requiredModels;
            info.missingModels = cachedEntry.missingModels;
            info.cacheValid = true;
            info.cacheModifiedAt = cachedEntry.cachedAt;
            info.cachePathType = cachedEntry.pathType;
            info.useFolderBasedDetection = (cachedEntry.detectionSource == "folder");
            info.detectionSource = cachedEntry.detectionSource;

            std::cout << "Using cached detection for " << info.name
                      << " (source: " << cachedEntry.detectionSource << ")" << std::endl;
            return;
        }

        try {
            const std::string checkpointsDir = getModelTypeDirectory(ModelType::CHECKPOINT);
            const std::string diffusionModelsDir = getModelTypeDirectory(ModelType::DIFFUSION_MODELS);
            const std::string pathType =
                ModelPathSelector::selectPathType(info.fullPath, checkpointsDir, diffusionModelsDir);
            const bool useFolderBasedDetection = (pathType == "diffusion_model_path");

            ModelDetectionResult detection;
            std::string detectionSource;

            if (useFolderBasedDetection) {
                // Models in the diffusion_models directory skip full detection
                detectionSource = "folder";
                info.architecture = "Modern Architecture (Flux/SD3)";
                info.recommendedVAE = "ae.safetensors";
                info.recommendedWidth = 1024;
                info.recommendedHeight = 1024;
                info.recommendedSteps = 20;
                info.recommendedSampler = "euler";

                // Minimal detection result so the cache has something to store
                detection.architecture = ModelArchitecture::FLUX_DEV;  // Default modern
                detection.architectureName = info.architecture;
                detection.recommendedVAE = info.recommendedVAE;
                detection.suggestedParams["width"] = std::to_string(info.recommendedWidth);
                detection.suggestedParams["height"] = std::to_string(info.recommendedHeight);
                detection.suggestedParams["steps"] = std::to_string(info.recommendedSteps);
                detection.suggestedParams["sampler"] = info.recommendedSampler;

                std::cout << "Using folder-based detection for " << info.name << " in " << pathType
                          << std::endl;
            } else {
                detectionSource = "architecture";
                detection = ModelDetector::detectModel(info.fullPath);

                // Undetectable .ckpt/.pt files are assumed to be SD1.5 (case-insensitive,
                // matching how determineModelType treats extensions)
                const std::string extension = asciiLower(filePath.extension().string());
                if (detection.architecture == ModelArchitecture::UNKNOWN &&
                    (extension == ".ckpt" || extension == ".pt")) {
                    applySd15Defaults(info);
                    detectionSource = "fallback";
                } else {
                    info.architecture = detection.architectureName;
                    info.recommendedVAE = detection.recommendedVAE;

                    // std::stoi may throw; that is handled by the SD1.5 fallback below
                    if (detection.suggestedParams.count("width")) {
                        info.recommendedWidth = std::stoi(detection.suggestedParams.at("width"));
                    }
                    if (detection.suggestedParams.count("height")) {
                        info.recommendedHeight = std::stoi(detection.suggestedParams.at("height"));
                    }
                    if (detection.suggestedParams.count("steps")) {
                        info.recommendedSteps = std::stoi(detection.suggestedParams.at("steps"));
                    }
                    if (detection.suggestedParams.count("sampler")) {
                        info.recommendedSampler = detection.suggestedParams.at("sampler");
                    }
                }

                std::cout << "Using architecture-based detection for " << info.name << std::endl;
            }

            collectRequiredModels(info, detection);

            ModelDetectionCache::cacheDetectionResult(
                info.fullPath, detection, pathType, detectionSource, info.modifiedAt);

            info.cacheValid = true;
            info.cacheModifiedAt = std::filesystem::file_time_type::clock::now();
            info.cachePathType = pathType;
            info.useFolderBasedDetection = useFolderBasedDetection;
            info.detectionSource = detectionSource;
        } catch (const std::exception& e) {
            applySd15Defaults(info);
            info.detectionSource = "fallback";
            std::cerr << "Detection failed for " << info.name << ": " << e.what()
                      << ", using SD1.5 defaults" << std::endl;
        }
    }

    /**
     * @brief Classify one regular file and, if it is a model, add it to @p modelsMap
     *
     * Files with no detectable type are ignored. When a name is already present,
     * the first entry wins.
     */
    void addModelFile(const fs::path& filePath, const fs::path& scanRoot,
                      std::map<std::string, ModelInfo>& modelsMap) {
        const ModelType detectedType = determineModelType(filePath);
        if (detectedType == ModelType::NONE) {
            return;
        }

        const std::string modelName = makeDisplayName(filePath, detectedType, scanRoot);
        if (modelsMap.find(modelName) != modelsMap.end()) {
            return;
        }

        ModelInfo info;
        info.name = modelName;
        info.path = filePath.string();
        info.fullPath = fs::absolute(filePath).string();
        info.type = detectedType;
        info.isLoaded = false;
        info.description = "";
        info.metadata = {};

        const auto [success, fileInfo] = getFileInfoWithTimeout(filePath);
        if (success) {
            info.fileSize = fileInfo.first;
            info.modifiedAt = fileInfo.second;
            info.createdAt = fileInfo.second;  // Use modified time as creation time for now
        } else {
            info.fileSize = 0;
            info.modifiedAt = fs::file_time_type{};
            info.createdAt = fs::file_time_type{};
        }

        info.sha256 = loadCachedSha256(info.fullPath);

        if (detectedType == ModelType::CHECKPOINT || detectedType == ModelType::DIFFUSION_MODELS) {
            populateArchitecture(info, filePath);
        }

        modelsMap[modelName] = std::move(info);
    }

    /**
     * @brief Recursively scan a directory for models of all types (without holding the mutex)
     *
     * Each model's display name is its path relative to its model type's
     * directory, with forward slashes. Subdirectories the process may not read
     * are skipped rather than aborting the scan. Other filesystem errors stop
     * this directory's scan and are logged.
     *
     * @param directory The directory to scan
     * @param modelsMap Map receiving the discovered models (keyed by display name)
     * @return true if scanning completed without cancellation
     */
    bool scanDirectory(const fs::path& directory, std::map<std::string, ModelInfo>& modelsMap) {
        if (scanCancelled.load()) {
            return false;
        }

        std::error_code ec;
        if (!fs::is_directory(directory, ec)) {
            return true;  // Missing or inaccessible: nothing to scan
        }

        try {
            fs::recursive_directory_iterator it(directory, fs::directory_options::skip_permission_denied, ec);
            for (const fs::recursive_directory_iterator end{}; !ec && it != end; it.increment(ec)) {
                if (scanCancelled.load()) {
                    return false;
                }
                std::error_code entryEc;
                if (it->is_regular_file(entryEc)) {
                    addModelFile(it->path(), directory, modelsMap);
                }
            }
            if (ec) {
                std::cerr << "Error while scanning " << directory << ": " << ec.message() << std::endl;
            }
        } catch (const fs::filesystem_error& e) {
            std::cerr << "Filesystem error while scanning " << directory << ": " << e.what() << std::endl;
        }

        return !scanCancelled.load();
    }
};

ModelManager::ModelManager() : pImpl(std::make_unique<Impl>()) {
}

ModelManager::~ModelManager() = default;

bool ModelManager::scanModelsDirectory() {
    pImpl->scanCancelled.store(false);

    // Results are collected outside the lock and swapped in at the end
    std::map<std::string, ModelInfo> tempModels;

    const std::vector<ModelType> allTypes = {
        ModelType::CHECKPOINT, ModelType::CONTROLNET, ModelType::LORA,      ModelType::VAE,
        ModelType::TAESD,      ModelType::ESRGAN,     ModelType::EMBEDDING, ModelType::DIFFUSION_MODELS};

    // Collect each distinct directory once; every scan detects all model types
    std::set<std::string> scannedDirectories;
    std::vector<std::string> directoriesToScan;
    for (const ModelType type : allTypes) {
        std::string dirPath = pImpl->getModelTypeDirectory(type);
        if (!dirPath.empty() && scannedDirectories.insert(dirPath).second) {
            directoriesToScan.push_back(std::move(dirPath));
        }
    }

    // Also scan the base models directory if it exists and isn't already covered
    std::error_code ec;
    const std::string& baseDir = pImpl->modelsDirectory;
    if (!baseDir.empty() && fs::exists(baseDir, ec) && scannedDirectories.count(baseDir) == 0) {
        directoriesToScan.push_back(baseDir);
    }

    for (const auto& dirPath : directoriesToScan) {
        if (!pImpl->scanDirectory(dirPath, tempModels)) {
            return false;
        }
    }

    {
        std::unique_lock<std::shared_mutex> lock(pImpl->modelsMutex);
        // Freshly scanned entries default to isLoaded=false; keep the flag in
        // sync with models that are still loaded.
        for (auto& [name, info] : tempModels) {
            info.isLoaded = pImpl->loadedModels.count(name) != 0;
        }
        pImpl->availableModels.swap(tempModels);
    }

    return true;
}

bool ModelManager::loadModel(const std::string& name, const std::string& path, ModelType type) {
    // The exclusive lock is held for the whole load, so readers wait until it completes
    std::unique_lock<std::shared_mutex> lock(pImpl->modelsMutex);

    if (pImpl->loadedModels.find(name) != pImpl->loadedModels.end()) {
        return true;
    }

    if (!fs::exists(path)) {
        std::cerr << "Model file does not exist: " << path << std::endl;
        return false;
    }

    auto wrapper = std::make_unique<StableDiffusionWrapper>();

    StableDiffusionWrapper::GenerationParams loadParams;
    loadParams.modelPath = path;
    loadParams.modelType = "f16";  // Default to f16 for better performance

    // Try to detect model type automatically for checkpoint and diffusion models
    if (type == ModelType::CHECKPOINT || type == ModelType::DIFFUSION_MODELS) {
        try {
            ModelDetectionResult detection = ModelDetector::detectModel(path);

            if (detection.architecture != ModelArchitecture::UNKNOWN) {
                std::cout << "Detected model architecture: " << detection.architectureName << " for " << name
                          << std::endl;

                if (detection.suggestedParams.count("model_type")) {
                    loadParams.modelType = detection.suggestedParams.at("model_type");
                }

                // VAE is optional for SD1x and SDXL models, but set it when one is recommended
                if (!detection.recommendedVAE.empty()) {
                    loadParams.vaePath = detection.recommendedVAE;
                }

                // Only fields that exist in GenerationParams are applied; t5xxl_path and
                // qwen2vl_path have no corresponding field there.
                for (const auto& [param, value] : detection.suggestedParams) {
                    if (param == "clip_l_path") {
                        loadParams.clipLPath = value;
                    } else if (param == "clip_g_path") {
                        loadParams.clipGPath = value;
                    }
                }
            } else {
                std::cout << "Could not detect model architecture for " << name << ", using defaults"
                          << std::endl;
            }
        } catch (const std::exception& e) {
            std::cerr << "Model detection failed for " << name << ": " << e.what() << " - using defaults"
                      << std::endl;
        }
    }

    if (!wrapper->loadModel(path, loadParams)) {
        std::cerr << "Failed to load model '" << name << "': " << wrapper->getLastError() << std::endl;
        return false;
    }

    pImpl->loadedModels[name] = std::move(wrapper);

    auto availableIt = pImpl->availableModels.find(name);
    if (availableIt != pImpl->availableModels.end()) {
        availableIt->second.isLoaded = true;
        return true;
    }

    // Model was loaded by explicit path: create a new model info entry
    ModelInfo info;
    info.name = name;
    info.path = path;
    info.fullPath = fs::absolute(path).string();
    info.type = type;
    info.isLoaded = true;
    info.sha256 = "";
    info.description = "";
    info.metadata = {};

    try {
        info.fileSize = fs::file_size(path);
        info.modifiedAt = fs::last_write_time(path);
        info.createdAt = info.modifiedAt;  // Use modified time as creation time for now
    } catch (const fs::filesystem_error& e) {
        std::cerr << "Error getting file info for " << path << ": " << e.what() << std::endl;
        info.fileSize = 0;
        info.modifiedAt = fs::file_time_type{};
        info.createdAt = fs::file_time_type{};
    }

    pImpl->availableModels[name] = std::move(info);
    return true;
}

bool ModelManager::loadModel(const std::string& name) {
    std::string path;
    ModelType type = ModelType::NONE;

    {
        // Read-only lookup: a shared lock is sufficient
        std::shared_lock<std::shared_mutex> lock(pImpl->modelsMutex);

        auto it = pImpl->availableModels.find(name);
        if (it == pImpl->availableModels.end()) {
            std::cerr << "Model '" << name << "' not found in available models" << std::endl;
            return false;
        }

        if (pImpl->loadedModels.find(name) != pImpl->loadedModels.end()) {
            return true;
        }

        path = it->second.path;
        type = it->second.type;
    }

    // The overload re-acquires the lock exclusively and re-checks whether the
    // model was loaded in the meantime.
    return loadModel(name, path, type);
}

bool ModelManager::unloadModel(const std::string& name) {
    std::unique_lock<std::shared_mutex> lock(pImpl->modelsMutex);

    auto loadedIt = pImpl->loadedModels.find(name);
    if (loadedIt == pImpl->loadedModels.end()) {
        return false;
    }

    if (loadedIt->second) {
        loadedIt->second->unloadModel();
    }
    pImpl->loadedModels.erase(loadedIt);

    auto availableIt = pImpl->availableModels.find(name);
    if (availableIt != pImpl->availableModels.end()) {
        availableIt->second.isLoaded = false;
    }

    return true;
}

/**
 * @brief Non-owning pointer to a loaded model, or nullptr if not loaded.
 *
 * The wrapper is owned by the manager; unloadModel() destroys it, so the
 * pointer must not be used after the model is unloaded.
 */
StableDiffusionWrapper* ModelManager::getModel(const std::string& name) {
    std::shared_lock<std::shared_mutex> lock(pImpl->modelsMutex);
    auto it = pImpl->loadedModels.find(name);
    if (it == pImpl->loadedModels.end()) {
        return nullptr;
    }
    return it->second.get();
}

std::map<std::string, ModelManager::ModelInfo> ModelManager::getAllModels() const {
    std::shared_lock<std::shared_mutex> lock(pImpl->modelsMutex);
    return pImpl->availableModels;
}

std::vector<ModelManager::ModelInfo> ModelManager::getModelsByType(ModelType type) const {
    std::shared_lock<std::shared_mutex> lock(pImpl->modelsMutex);
    std::vector<ModelInfo> result;
    for (const auto& [name, info] : pImpl->availableModels) {
        if (info.type == type) {
            result.push_back(info);
        }
    }
    return result;
}

ModelManager::ModelInfo ModelManager::getModelInfo(const std::string& name) const {
    std::shared_lock<std::shared_mutex> lock(pImpl->modelsMutex);
    auto it = pImpl->availableModels.find(name);
    if (it == pImpl->availableModels.end()) {
        return ModelInfo{};  // Empty ModelInfo if not found
    }
    return it->second;
}

bool ModelManager::isModelLoaded(const std::string& name) const {
    std::shared_lock<std::shared_mutex> lock(pImpl->modelsMutex);
    return pImpl->loadedModels.find(name) != pImpl->loadedModels.end();
}

size_t ModelManager::getLoadedModelsCount() const {
    std::shared_lock<std::shared_mutex> lock(pImpl->modelsMutex);
    return pImpl->loadedModels.size();
}

size_t ModelManager::getAvailableModelsCount() const {
    std::shared_lock<std::shared_mutex> lock(pImpl->modelsMutex);
    return pImpl->availableModels.size();
}

void ModelManager::setModelsDirectory(const std::string& path) {
    pImpl->modelsDirectory = path;
}

std::string ModelManager::getModelsDirectory() const {
    return pImpl->modelsDirectory;
}

std::string ModelManager::modelTypeToString(ModelType type) {
    switch (type) {
        case ModelType::LORA:
            return "lora";
        case ModelType::CHECKPOINT:
            return "checkpoint";
        case ModelType::VAE:
            return "vae";
        case ModelType::PRESETS:
            return "presets";
        case ModelType::PROMPTS:
            return "prompts";
        case ModelType::NEG_PROMPTS:
            return "neg_prompts";
        case ModelType::TAESD:
            return "taesd";
        case ModelType::ESRGAN:
            return "esrgan";
        case ModelType::CONTROLNET:
            return "controlnet";
        case ModelType::UPSCALER:
            return "upscaler";
        case ModelType::EMBEDDING:
            return "embedding";
        case ModelType::DIFFUSION_MODELS:
            return "checkpoint";  // Diffusion models are exposed as checkpoints
        default:
            return "unknown";
    }
}

ModelType ModelManager::stringToModelType(const std::string& typeStr) {
    const std::string lowerType = asciiLower(typeStr);

    if (lowerType == "lora") {
        return ModelType::LORA;
    } else if (lowerType == "checkpoint" || lowerType == "stable-diffusion") {
        return ModelType::CHECKPOINT;
    } else if (lowerType == "vae") {
        return ModelType::VAE;
    } else if (lowerType == "presets") {
        return ModelType::PRESETS;
    } else if (lowerType == "prompts") {
        return ModelType::PROMPTS;
    } else if (lowerType == "neg_prompts" || lowerType == "negative_prompts") {
        return ModelType::NEG_PROMPTS;
    } else if (lowerType == "taesd") {
        return ModelType::TAESD;
    } else if (lowerType == "esrgan") {
        return ModelType::ESRGAN;
    } else if (lowerType == "controlnet") {
        return ModelType::CONTROLNET;
    } else if (lowerType == "upscaler") {
        return ModelType::UPSCALER;
    } else if (lowerType == "embedding" || lowerType == "textual-inversion") {
        return ModelType::EMBEDDING;
    }

    return ModelType::NONE;
}

bool ModelManager::setModelTypeDirectory(ModelType type, const std::string& path) {
    std::unique_lock<std::shared_mutex> lock(pImpl->modelsMutex);

    if (!pImpl->validateDirectory(path)) {
        return false;
    }

    pImpl->modelTypeDirectories[type] = path;
    return true;
}

std::string ModelManager::getModelTypeDirectory(ModelType type) const {
    std::shared_lock<std::shared_mutex> lock(pImpl->modelsMutex);
    return pImpl->getModelTypeDirectory(type);
}

bool ModelManager::setAllModelTypeDirectories(const std::map<ModelType, std::string>& directories) {
    std::unique_lock<std::shared_mutex> lock(pImpl->modelsMutex);

    // Validate everything first so a failure leaves the configuration untouched
    for (const auto& [type, path] : directories) {
        if (!path.empty() && !pImpl->validateDirectory(path)) {
            return false;
        }
    }

    pImpl->modelTypeDirectories = directories;
    return true;
}

std::map<ModelType, std::string> ModelManager::getAllModelTypeDirectories() const {
    std::shared_lock<std::shared_mutex> lock(pImpl->modelsMutex);
    return pImpl->modelTypeDirectories;
}

// Legacy resetToLegacyDirectories method removed
// Using explicit directory configuration only

/**
 * @brief Apply the base models directory and per-type directories from a server config
 *
 * All configured directories are validated before anything is changed, so a
 * failed call leaves both modelsDirectory and the per-type directories as they were.
 *
 * @return false if any configured directory is invalid
 */
bool ModelManager::configureFromServerConfig(const ServerConfig& config) {
    std::map<ModelType, std::string> directories;

    if (!config.checkpoints.empty()) {
        directories[ModelType::CHECKPOINT] = config.checkpoints;
    }
    if (!config.controlnetDir.empty()) {
        directories[ModelType::CONTROLNET] = config.controlnetDir;
    }
    if (!config.embeddingsDir.empty()) {
        directories[ModelType::EMBEDDING] = config.embeddingsDir;
    }
    if (!config.esrganDir.empty()) {
        directories[ModelType::ESRGAN] = config.esrganDir;
    }
    if (!config.loraDir.empty()) {
        directories[ModelType::LORA] = config.loraDir;
    }
    if (!config.taesdDir.empty()) {
        directories[ModelType::TAESD] = config.taesdDir;
    }
    if (!config.vaeDir.empty()) {
        directories[ModelType::VAE] = config.vaeDir;
    }
    if (!config.diffusionModelsDir.empty()) {
        directories[ModelType::DIFFUSION_MODELS] = config.diffusionModelsDir;
    }

    std::unique_lock<std::shared_mutex> lock(pImpl->modelsMutex);

    for (const auto& [type, path] : directories) {
        if (!path.empty() && !pImpl->validateDirectory(path)) {
            return false;
        }
    }

    // Commit only after validation; inlined to avoid re-locking via setAllModelTypeDirectories
    pImpl->modelsDirectory = config.modelsDir;
    pImpl->modelTypeDirectories = std::move(directories);
    return true;
}

void ModelManager::cancelScan() {
    pImpl->scanCancelled.store(true);
}

// SHA256 Hashing Implementation

/**
 * @brief Compute the SHA256 of a model file as a lower-case hex string
 *
 * The models lock is held only while looking up the path; hashing happens
 * unlocked. Returns "" if the model is unknown, the file cannot be opened or
 * read completely, or any OpenSSL call fails.
 */
std::string ModelManager::computeModelHash(const std::string& modelName) {
    std::string filePath;
    {
        std::shared_lock<std::shared_mutex> lock(pImpl->modelsMutex);
        auto it = pImpl->availableModels.find(modelName);
        if (it == pImpl->availableModels.end()) {
            std::cerr << "Model not found: " << modelName << std::endl;
            return "";
        }
        filePath = it->second.fullPath;
    }

    std::ifstream file(filePath, std::ios::binary);
    if (!file.is_open()) {
        std::cerr << "Failed to open file for hashing: " << filePath << std::endl;
        return "";
    }

    // RAII: the context is freed on every return path
    std::unique_ptr<EVP_MD_CTX, decltype(&EVP_MD_CTX_free)> mdctx(EVP_MD_CTX_new(), &EVP_MD_CTX_free);
    if (!mdctx) {
        std::cerr << "Failed to create EVP context" << std::endl;
        return "";
    }

    if (EVP_DigestInit_ex(mdctx.get(), EVP_sha256(), nullptr) != 1) {
        std::cerr << "Failed to initialize SHA256 digest" << std::endl;
        return "";
    }

    constexpr size_t bufferSize = 8192;
    char buffer[bufferSize];

    std::cout << "Computing SHA256 for: " << modelName << std::endl;

    size_t totalRead = 0;
    size_t lastReportedMB = 0;

    while (file.read(buffer, static_cast<std::streamsize>(bufferSize)) || file.gcount() > 0) {
        const auto bytesRead = static_cast<size_t>(file.gcount());
        if (EVP_DigestUpdate(mdctx.get(), buffer, bytesRead) != 1) {
            std::cerr << "Failed to update digest" << std::endl;
            return "";
        }

        totalRead += bytesRead;

        // Progress reporting every 100MB
        const size_t currentMB = totalRead / (1024 * 1024);
        if (currentMB >= lastReportedMB + 100) {
            std::cout << "  Hashed " << currentMB << " MB..." << std::endl;
            lastReportedMB = currentMB;
        }
    }

    // EOF sets failbit/eofbit; badbit means a real I/O error and a partial hash
    if (file.bad()) {
        std::cerr << "Failed to read file for hashing: " << filePath << std::endl;
        return "";
    }

    unsigned char hash[EVP_MAX_MD_SIZE];
    unsigned int hashLen = 0;
    if (EVP_DigestFinal_ex(mdctx.get(), hash, &hashLen) != 1) {
        std::cerr << "Failed to finalize digest" << std::endl;
        return "";
    }

    std::ostringstream oss;
    for (unsigned int i = 0; i < hashLen; i++) {
        oss << std::hex << std::setw(2) << std::setfill('0') << static_cast<int>(hash[i]);
    }

    std::string hashStr = oss.str();
    std::cout << "Hash computed: " << hashStr.substr(0, 16) << "..." << std::endl;

    return hashStr;
}

std::string ModelManager::loadModelHashFromFile(const std::string& modelName) {
    std::shared_lock<std::shared_mutex> lock(pImpl->modelsMutex);

    auto it = pImpl->availableModels.find(modelName);
    if (it == pImpl->availableModels.end()) {
        return "";
    }

    std::string jsonPath = it->second.fullPath + ".json";
    lock.unlock();

    if (!fs::exists(jsonPath)) {
        return "";
    }

    try {
        std::ifstream jsonFile(jsonPath);
        if (!jsonFile.is_open()) {
            return "";
        }

        nlohmann::json j;
        jsonFile >> j;

        if (j.contains("sha256") && j["sha256"].is_string()) {
            return j["sha256"].get<std::string>();
        }
    } catch (const std::exception& e) {
        std::cerr << "Error loading hash from JSON: " << e.what() << std::endl;
    }

    return "";
}

bool ModelManager::saveModelHashToFile(const std::string& modelName, const std::string& hash) {
    std::shared_lock<std::shared_mutex> lock(pImpl->modelsMutex);

    auto it = pImpl->availableModels.find(modelName);
    if (it == pImpl->availableModels.end()) {
        return false;
    }

    std::string jsonPath = it->second.fullPath + ".json";
    size_t fileSize = it->second.fileSize;
    lock.unlock();

    try {
        nlohmann::json j;
        j["sha256"] = hash;
        j["file_size"] = fileSize;
        j["computed_at"] = std::chrono::system_clock::now().time_since_epoch().count();

        std::ofstream jsonFile(jsonPath);
        if (!jsonFile.is_open()) {
            std::cerr << "Failed to open file for writing: " << jsonPath << std::endl;
            return false;
        }

        jsonFile << j.dump(2);
        jsonFile.close();

        std::cout << "Saved hash to: " << jsonPath << std::endl;
        return true;
    } catch (const std::exception& e) {
        std::cerr << "Error saving hash to JSON: " << e.what() << std::endl;
        return false;
    }
}

/**
 * @brief Find a model whose known SHA256 equals or starts with @p hash
 *
 * Matching is ASCII case-insensitive, so upper-case digests (as printed by
 * some tools) are accepted. Models without a known hash are skipped.
 *
 * @param hash Full or partial hex digest, at least 10 characters
 * @return The first matching model name in map order, or ""
 */
std::string ModelManager::findModelByHash(const std::string& hash) {
    if (hash.length() < 10) {
        std::cerr << "Hash must be at least 10 characters" << std::endl;
        return "";
    }

    std::shared_lock<std::shared_mutex> lock(pImpl->modelsMutex);

    for (const auto& [name, info] : pImpl->availableModels) {
        if (info.sha256.empty()) {
            continue;
        }
        // Covers both full and partial (prefix) matches
        if (startsWithIgnoringCase(info.sha256, hash)) {
            return name;
        }
    }

    return "";
}

std::string ModelManager::ensureModelHash(const std::string& modelName, bool forceCompute) {
    if (!forceCompute) {
        std::string existingHash = loadModelHashFromFile(modelName);
        if (!existingHash.empty()) {
            std::unique_lock<std::shared_mutex> lock(pImpl->modelsMutex);
            auto it = pImpl->availableModels.find(modelName);
            if (it != pImpl->availableModels.end()) {
                it->second.sha256 = existingHash;
            }
            return existingHash;
        }
    }

    std::string hash = computeModelHash(modelName);
    if (hash.empty()) {
        return "";
    }

    saveModelHashToFile(modelName, hash);

    std::unique_lock<std::shared_mutex> lock(pImpl->modelsMutex);
    auto it = pImpl->availableModels.find(modelName);
    if (it != pImpl->availableModels.end()) {
        it->second.sha256 = hash;
    }

    return hash;
}

// ModelPathSelector Implementation
std::string ModelManager::ModelPathSelector::selectPathType(
    const std::string& modelPath,
    const std::string& checkpointsDir,
    const std::string& diffusionModelsDir) {

    std::cout << "Selecting path type for model: " << modelPath << std::endl;
    std::cout << "Checkpoints directory: " << checkpointsDir << std::endl;
    std::cout << "Diffusion models directory: " << diffusionModelsDir << std::endl;

    // diffusion_models directory takes priority
    if (!diffusionModelsDir.empty() && isModelInDirectory(modelPath, diffusionModelsDir)) {
        std::cout << "Model is in diffusion_models directory, using diffusion_model_path" << std::endl;
        return "diffusion_model_path";
    }

    if (!checkpointsDir.empty() && isModelInDirectory(modelPath, checkpointsDir)) {
        std::cout << "Model is in checkpoints directory, using model_path" << std::endl;
        return "model_path";
    }

    // Fallback: detect from the immediate parent directory name
    std::filesystem::path modelFilePath(modelPath);
    std::filesystem::path parentDir = modelFilePath.parent_path();

    if (parentDir.filename().string() == "diffusion_models") {
        std::cout << "Model is in diffusion_models directory (detected from path), using diffusion_model_path" << std::endl;
        return "diffusion_model_path";
    } else if (parentDir.filename().string() == "checkpoints") {
        std::cout << "Model is in checkpoints directory (detected from path), using model_path" << std::endl;
        return "model_path";
    }

    std::cout << "Model location unknown, defaulting to model_path for backward compatibility" << std::endl;
    return "model_path";
}

bool ModelManager::ModelPathSelector::isModelInDirectory(const std::string& modelPath, const std::string& directory) {
    if (modelPath.empty() || directory.empty()) {
        return false;
    }

    try {
        std::filesystem::path absoluteModelPath = std::filesystem::absolute(modelPath).lexically_normal();
        std::filesystem::path absoluteDirPath = std::filesystem::absolute(directory).lexically_normal();

        // Get relative path from directory to model
        auto relativePath = absoluteModelPath.lexically_relative(absoluteDirPath);
        std::string relPathStr = relativePath.string();

        // Check if the relative path doesn't start with ".." and is not empty
        bool isUnderDirectory = !relPathStr.empty() && relPathStr.substr(0, 2) != ".."
&& relPathStr[0] != '/'; return isUnderDirectory; } catch (const std::filesystem::filesystem_error& e) { std::cerr << "Error checking if model is in directory: " << e.what() << std::endl; return false; } } // ModelDetectionCache Implementation std::map ModelManager::ModelDetectionCache::cache_; std::mutex ModelManager::ModelDetectionCache::cacheMutex_; ModelManager::ModelDetectionCache::CacheEntry ModelManager::ModelDetectionCache::getCachedResult( const std::string& modelPath, const std::filesystem::file_time_type& currentModifiedTime) { std::lock_guard lock(cacheMutex_); auto it = cache_.find(modelPath); if (it == cache_.end()) { return CacheEntry{}; // Return invalid entry if not found } const CacheEntry& entry = it->second; // Check if cache is still valid (file hasn't been modified) if (entry.fileModifiedAt == currentModifiedTime && entry.isValid) { std::cout << "Using cached detection result for: " << modelPath << std::endl; return entry; } // Cache is stale, remove it cache_.erase(it); std::cout << "Cache entry expired for: " << modelPath << std::endl; return CacheEntry{}; // Return invalid entry } void ModelManager::ModelDetectionCache::cacheDetectionResult( const std::string& modelPath, const ModelDetectionResult& detection, const std::string& pathType, const std::string& detectionSource, const std::filesystem::file_time_type& fileModifiedTime) { std::lock_guard lock(cacheMutex_); CacheEntry entry; entry.architecture = detection.architectureName; entry.recommendedVAE = detection.recommendedVAE; entry.recommendedWidth = 0; entry.recommendedHeight = 0; entry.recommendedSteps = 0; entry.recommendedSampler = ""; entry.pathType = pathType; entry.detectionSource = detectionSource; entry.cachedAt = std::filesystem::file_time_type::clock::now(); entry.fileModifiedAt = fileModifiedTime; entry.isValid = true; // Parse recommended parameters for (const auto& [param, value] : detection.suggestedParams) { if (param == "width") { entry.recommendedWidth = 
std::stoi(value); } else if (param == "height") { entry.recommendedHeight = std::stoi(value); } else if (param == "steps") { entry.recommendedSteps = std::stoi(value); } else if (param == "sampler") { entry.recommendedSampler = value; } } // Build list of required models // Note: VAE is now optional for SD1x and SDXL models, so we don't add it to requiredModels // The VAE will still be recommended but not required if (detection.suggestedParams.count("clip_l_required")) { entry.requiredModels.push_back("CLIP-L: " + detection.suggestedParams.at("clip_l_required")); } if (detection.suggestedParams.count("clip_g_required")) { entry.requiredModels.push_back("CLIP-G: " + detection.suggestedParams.at("clip_g_required")); } if (detection.suggestedParams.count("t5xxl_required")) { entry.requiredModels.push_back("T5XXL: " + detection.suggestedParams.at("t5xxl_required")); } if (detection.suggestedParams.count("qwen2vl_required")) { entry.requiredModels.push_back("Qwen2-VL: " + detection.suggestedParams.at("qwen2vl_required")); } if (detection.suggestedParams.count("qwen2vl_vision_required")) { entry.requiredModels.push_back("Qwen2-VL-Vision: " + detection.suggestedParams.at("qwen2vl_vision_required")); } // Check for missing models and store in cache if (!entry.requiredModels.empty()) { // Create a temporary ModelManager instance to check existence // Note: This is a simplified approach - in a production environment, // we might want to pass the models directory or use a different approach std::string baseModelsDir = "/data/SD_MODELS"; for (const auto& requiredModel : entry.requiredModels) { size_t colonPos = requiredModel.find(':'); if (colonPos == std::string::npos) continue; std::string modelType = requiredModel.substr(0, colonPos); std::string modelName = requiredModel.substr(colonPos + 1); // Trim whitespace modelType.erase(0, modelType.find_first_not_of(" \t")); modelType.erase(modelType.find_last_not_of(" \t") + 1); modelName.erase(0, modelName.find_first_not_of(" 
\t")); modelName.erase(modelName.find_last_not_of(" \t") + 1); // Determine the appropriate subdirectory std::string subdirectory; if (modelType == "VAE") { subdirectory = "vae"; } else if (modelType == "CLIP-L" || modelType == "CLIP-G") { subdirectory = "clip"; } else if (modelType == "T5XXL") { subdirectory = "t5xxl"; } else if (modelType == "CLIP-Vision") { subdirectory = "clip"; } else if (modelType == "Qwen2-VL" || modelType == "Qwen2-VL-Vision") { subdirectory = "qwen2vl"; } // Check if model exists std::string fullPath; if (!subdirectory.empty()) { fullPath = baseModelsDir + "/" + subdirectory + "/" + modelName; } else { fullPath = baseModelsDir + "/" + modelName; } try { if (!fs::exists(fullPath) || !fs::is_regular_file(fullPath)) { entry.missingModels.push_back(requiredModel); } } catch (const fs::filesystem_error&) { // If we can't check, assume it's missing entry.missingModels.push_back(requiredModel); } } } cache_[modelPath] = entry; std::cout << "Cached detection result for: " << modelPath << " (source: " << detectionSource << ", path type: " << pathType << ")" << std::endl; } void ModelManager::ModelDetectionCache::invalidateCache(const std::string& modelPath) { std::lock_guard lock(cacheMutex_); auto it = cache_.find(modelPath); if (it != cache_.end()) { cache_.erase(it); std::cout << "Invalidated cache for: " << modelPath << std::endl; } } void ModelManager::ModelDetectionCache::clearAllCache() { std::lock_guard lock(cacheMutex_); size_t count = cache_.size(); cache_.clear(); std::cout << "Cleared " << count << " cache entries" << std::endl; } std::vector ModelManager::checkRequiredModelsExistence(const std::vector& requiredModels) { std::vector modelDetails; // Base models directory according to project guidelines std::string baseModelsDir = "/data/SD_MODELS"; for (const auto& requiredModel : requiredModels) { ModelDetails details; // Parse the required model string (format: "TYPE: filename") size_t colonPos = requiredModel.find(':'); if (colonPos 
== std::string::npos) { // Invalid format, skip continue; } std::string modelType = requiredModel.substr(0, colonPos); std::string modelName = requiredModel.substr(colonPos + 1); // Trim whitespace modelType.erase(0, modelType.find_first_not_of(" \t")); modelType.erase(modelType.find_last_not_of(" \t") + 1); modelName.erase(0, modelName.find_first_not_of(" \t")); modelName.erase(modelName.find_last_not_of(" \t") + 1); details.name = modelName; details.type = modelType; details.is_required = true; details.is_recommended = false; details.exists = false; details.file_size = 0; details.path = ""; details.sha256 = ""; // Determine the appropriate subdirectory based on model type std::string subdirectory; if (modelType == "VAE") { subdirectory = "vae"; } else if (modelType == "CLIP-L" || modelType == "CLIP-G") { subdirectory = "clip"; } else if (modelType == "T5XXL") { subdirectory = "t5xxl"; } else if (modelType == "CLIP-Vision") { subdirectory = "clip"; } else if (modelType == "Qwen2-VL" || modelType == "Qwen2-VL-Vision") { subdirectory = "qwen2vl"; } else { // For unknown types, check in root directory subdirectory = ""; } // Construct the full path to check std::string fullPath; if (!subdirectory.empty()) { fullPath = baseModelsDir + "/" + subdirectory + "/" + modelName; } else { fullPath = baseModelsDir + "/" + modelName; } // Check if the file exists try { if (fs::exists(fullPath) && fs::is_regular_file(fullPath)) { details.exists = true; details.path = fs::absolute(fullPath).string(); details.file_size = fs::file_size(fullPath); // Try to get cached hash std::string jsonPath = fullPath + ".json"; if (fs::exists(jsonPath)) { try { std::ifstream jsonFile(jsonPath); if (jsonFile.is_open()) { nlohmann::json j; jsonFile >> j; jsonFile.close(); if (j.contains("sha256") && j["sha256"].is_string()) { details.sha256 = j["sha256"].get(); } } } catch (const std::exception& e) { std::cerr << "Error loading hash for " << fullPath << ": " << e.what() << std::endl; } } std::cout 
<< "Found required model: " << modelType << " at " << details.path << std::endl; } else { std::cout << "Missing required model: " << modelType << " - expected at " << fullPath << std::endl; } } catch (const fs::filesystem_error& e) { std::cerr << "Error checking model existence for " << fullPath << ": " << e.what() << std::endl; } modelDetails.push_back(details); } return modelDetails; }