#include "model_manager.h"
#include "model_detector.h"
#include "stable_diffusion_wrapper.h"

// NOTE(review): the angle-bracket include targets were not recoverable from the reviewed
// text; the list below is reconstructed from the symbols this file uses - confirm it
// against version control before merging.
#include <algorithm>
#include <array>
#include <atomic>
#include <cctype>
#include <chrono>
#include <cstdint>
#include <filesystem>
#include <fstream>
#include <future>
#include <iomanip>
#include <iostream>
#include <map>
#include <memory>
#include <shared_mutex>
#include <sstream>
#include <string>
#include <system_error>
#include <thread>
#include <utility>
#include <vector>

#include <nlohmann/json.hpp>
#include <openssl/evp.h>

namespace fs = std::filesystem;

// File extension mappings for each model type (lower-case, without the leading dot)
const std::vector<std::string> CHECKPOINT_FILE_EXTENSIONS = {"safetensors", "ckpt", "gguf"};
const std::vector<std::string> EMBEDDING_FILE_EXTENSIONS = {"safetensors", "pt"};
const std::vector<std::string> LORA_FILE_EXTENSIONS = {"safetensors", "ckpt"};
const std::vector<std::string> VAE_FILE_EXTENSIONS = {"safetensors", "pt", "ckpt", "gguf"};
const std::vector<std::string> TAESD_FILE_EXTENSIONS = {"safetensors", "pth", "gguf"};
const std::vector<std::string> ESRGAN_FILE_EXTENSIONS = {"pth", "pt"};
const std::vector<std::string> CONTROLNET_FILE_EXTENSIONS = {"safetensors", "pth"};

namespace {

/**
 * @brief Lower-case a string byte by byte
 *
 * std::tolower has undefined behaviour for negative char values, so every byte is
 * routed through unsigned char first.
 *
 * @param text The text to convert (taken by value and modified in place)
 * @return std::string The lower-cased text
 */
std::string toLowerAscii(std::string text) {
    std::transform(text.begin(), text.end(), text.begin(),
                   [](unsigned char c) { return static_cast<char>(std::tolower(c)); });
    return text;
}

/// Deleter that lets std::unique_ptr release an OpenSSL digest context on every exit path.
struct EvpMdCtxDeleter {
    void operator()(EVP_MD_CTX* ctx) const noexcept { EVP_MD_CTX_free(ctx); }
};
using EvpMdCtxPtr = std::unique_ptr<EVP_MD_CTX, EvpMdCtxDeleter>;

}  // namespace

class ModelManager::Impl {
public:
    std::string modelsDirectory = "./models";
    std::map<ModelType, std::string> modelTypeDirectories;
    std::map<std::string, ModelInfo> availableModels;
    std::map<std::string, std::unique_ptr<StableDiffusionWrapper>> loadedModels;
    mutable std::shared_mutex modelsMutex;
    std::atomic<bool> scanCancelled{false};
    bool legacyMode = true;

    /**
     * @brief Validate a directory path
     *
     * Uses the non-throwing filesystem overloads, so an inaccessible path is reported
     * as invalid instead of raising an exception.
     *
     * @param path The directory path to validate
     * @return true if the directory exists and is valid, false otherwise
     */
    bool validateDirectory(const std::string& path) const {
        if (path.empty()) {
            return false;
        }

        const fs::path dirPath(path);
        std::error_code ec;
        if (!fs::exists(dirPath, ec)) {
            std::cerr << "Directory does not exist: " << path << std::endl;
            return false;
        }
        if (!fs::is_directory(dirPath, ec)) {
            std::cerr << "Path is not a directory: " << path << std::endl;
            return false;
        }
        return true;
    }

    /**
     * @brief Get default directory name for a model type
     *
     * @param type The model type
     * @return std::string Default directory name, empty for types without one
     */
    std::string getDefaultDirectoryName(ModelType type) const {
        switch (type) {
            case ModelType::CHECKPOINT: return "checkpoints";
            case ModelType::CONTROLNET: return "controlnet";
            case ModelType::EMBEDDING:  return "embeddings";
            case ModelType::ESRGAN:     return "esrgan";
            case ModelType::LORA:       return "lora";
            case ModelType::TAESD:      return "taesd";
            case ModelType::VAE:        return "vae";
            default:                    return "";
        }
    }

    /**
     * @brief Get directory path for a model type
     *
     * An explicitly configured directory wins; otherwise, in legacy mode, the default
     * sub-directory of modelsDirectory is returned.
     *
     * @param type The model type
     * @return std::string Directory path, empty if not set
     */
    std::string getModelTypeDirectory(ModelType type) const {
        const auto configured = modelTypeDirectories.find(type);
        if (configured != modelTypeDirectories.end()) {
            return configured->second;
        }

        if (legacyMode) {
            const std::string defaultDir = getDefaultDirectoryName(type);
            if (!defaultDir.empty()) {
                return modelsDirectory + "/" + defaultDir;
            }
        }
        return "";
    }

    /**
     * @brief Get file extensions for a specific model type
     *
     * @param type The model type
     * @return const std::vector<std::string>& Vector of file extensions (empty for unknown types)
     */
    const std::vector<std::string>& getFileExtensions(ModelType type) const {
        switch (type) {
            case ModelType::CHECKPOINT: return CHECKPOINT_FILE_EXTENSIONS;
            case ModelType::EMBEDDING:  return EMBEDDING_FILE_EXTENSIONS;
            case ModelType::LORA:       return LORA_FILE_EXTENSIONS;
            case ModelType::VAE:        return VAE_FILE_EXTENSIONS;
            case ModelType::TAESD:      return TAESD_FILE_EXTENSIONS;
            case ModelType::ESRGAN:     return ESRGAN_FILE_EXTENSIONS;
            case ModelType::CONTROLNET: return CONTROLNET_FILE_EXTENSIONS;
            default: {
                static const std::vector<std::string> empty;
                return empty;
            }
        }
    }

    /**
     * @brief Check if a file extension matches a model type
     *
     * @param extension The file extension (lower-case, no leading dot)
     * @param type The model type
     * @return true if the extension matches the model type
     */
    bool isExtensionMatch(const std::string& extension, ModelType type) const {
        const auto& extensions = getFileExtensions(type);
        return std::find(extensions.begin(), extensions.end(), extension) != extensions.end();
    }

    /**
     * @brief Determine model type based on file path and extension
     *
     * Resolution order:
     *   1. the file lies under an explicitly configured type directory,
     *   2. (legacy mode, or nothing configured) the parent folder has a well-known name,
     *   3. the extension alone, checked in a fixed priority order.
     *
     * NOTE(review): reads modelTypeDirectories and legacyMode without taking modelsMutex;
     * the scan path calls this unlocked, so reconfiguring during a scan is not synchronised.
     *
     * @param filePath The file path
     * @return ModelType The determined model type, ModelType::NONE if nothing matches
     */
    ModelType determineModelType(const fs::path& filePath) const {
        std::string extension = filePath.extension().string();
        if (extension.empty()) {
            return ModelType::NONE;
        }
        if (extension[0] == '.') {
            extension.erase(0, 1);
        }
        extension = toLowerAscii(std::move(extension));

        // 1) Explicitly configured directories. The file path is loop-invariant, so it is
        //    normalised once up front.
        const fs::path normalizedFilePath = fs::absolute(filePath).lexically_normal();
        for (const auto& [type, directory] : modelTypeDirectories) {
            if (directory.empty()) {
                continue;
            }
            const fs::path absoluteDirPath = fs::absolute(directory).lexically_normal();
            const fs::path relativePath = normalizedFilePath.lexically_relative(absoluteDirPath);

            // The file is inside the directory when the relative path exists and does not
            // climb out through a leading ".." COMPONENT. Comparing the component rather
            // than the first two characters keeps folders such as "..cache" from being
            // mistaken for a parent reference.
            const bool isUnderDirectory = !relativePath.empty() && relativePath.is_relative() &&
                                          *relativePath.begin() != fs::path("..");
            if (isUnderDirectory && isExtensionMatch(extension, type)) {
                return type;
            }
        }

        // 2) Well-known parent folder names.
        if (legacyMode || modelTypeDirectories.empty()) {
            static const std::pair<const char*, ModelType> kLegacyFolders[] = {
                {"checkpoints", ModelType::CHECKPOINT},
                {"stable-diffusion", ModelType::CHECKPOINT},
                {"controlnet", ModelType::CONTROLNET},
                {"lora", ModelType::LORA},
                {"vae", ModelType::VAE},
                {"taesd", ModelType::TAESD},
                {"esrgan", ModelType::ESRGAN},
                {"upscaler", ModelType::ESRGAN},
                {"embeddings", ModelType::EMBEDDING},
                {"textual-inversion", ModelType::EMBEDDING},
            };

            const std::string parentName = toLowerAscii(filePath.parent_path().filename().string());
            for (const auto& [folderName, folderType] : kLegacyFolders) {
                if (parentName != folderName) {
                    continue;
                }
                if (isExtensionMatch(extension, folderType)) {
                    return folderType;
                }
                break;  // Folder recognised but extension unsuitable: fall through to step 3
            }
        }

        // 3) Extension-only fallback; the order decides ties between shared extensions.
        static const ModelType kFallbackOrder[] = {
            ModelType::CHECKPOINT, ModelType::LORA,       ModelType::VAE,       ModelType::TAESD,
            ModelType::ESRGAN,     ModelType::CONTROLNET, ModelType::EMBEDDING,
        };
        for (const ModelType candidate : kFallbackOrder) {
            if (isExtensionMatch(extension, candidate)) {
                return candidate;
            }
        }
        return ModelType::NONE;
    }

    /**
     * @brief Get file information with timeout
     *
     * The lookup runs on a detached worker that owns a copy of the path and shares the
     * promise with this function. A std::async future would block in its destructor until
     * the lookup finished, which would defeat the timeout on a hung filesystem.
     *
     * @param filePath The file path to get info for
     * @param timeoutMs Timeout in milliseconds
     * @return Success flag (false only on timeout) and {size, modification time};
     *         a lookup that fails without timing out yields success with {0, default time}
     */
    std::pair<bool, std::pair<uintmax_t, fs::file_time_type>> getFileInfoWithTimeout(
        const fs::path& filePath, int timeoutMs = 5000) {
        using FileStat = std::pair<uintmax_t, fs::file_time_type>;

        auto promise = std::make_shared<std::promise<FileStat>>();
        std::future<FileStat> future = promise->get_future();

        std::thread([promise, path = filePath]() {
            FileStat stat{0, fs::file_time_type{}};
            try {
                std::error_code ec;
                const uintmax_t fileSize = fs::file_size(path, ec);
                if (!ec) {
                    const fs::file_time_type modifiedAt = fs::last_write_time(path, ec);
                    if (!ec) {
                        stat = FileStat{fileSize, modifiedAt};
                    }
                }
            } catch (...) {
                // Keep the zeroed result; nothing may escape a detached thread.
            }
            promise->set_value(stat);
        }).detach();

        if (future.wait_for(std::chrono::milliseconds(timeoutMs)) == std::future_status::timeout) {
            std::cerr << "Timeout getting file info for " << filePath << std::endl;
            return {false, FileStat{0, fs::file_time_type{}}};
        }
        return {true, future.get()};
    }

    /**
     * @brief Read a cached SHA256 from the "<model>.json" sidecar file
     *
     * @param modelFullPath Absolute path of the model file
     * @return std::string The cached hash, empty if the sidecar is missing, unreadable or malformed
     */
    static std::string loadCachedHash(const std::string& modelFullPath) {
        const std::string hashFile = modelFullPath + ".json";
        std::error_code ec;
        if (!fs::exists(hashFile, ec)) {
            return "";
        }
        try {
            std::ifstream file(hashFile);
            const nlohmann::json hashData = nlohmann::json::parse(file);
            if (hashData.contains("sha256") && hashData["sha256"].is_string()) {
                return hashData["sha256"].get<std::string>();
            }
        } catch (...) {
            // A broken sidecar only means "no cached hash".
        }
        return "";
    }

    /**
     * @brief Fill in the settings assumed for a checkpoint whose architecture is unknown
     *
     * @param info The model info to update
     */
    static void applyAssumedSd15Defaults(ModelInfo& info) {
        info.architecture = "Stable Diffusion 1.5 (assumed)";
        info.recommendedVAE = "vae-ft-mse-840000-ema-pruned.safetensors";
        info.recommendedWidth = 512;
        info.recommendedHeight = 512;
        info.recommendedSteps = 20;
        info.recommendedSampler = "euler_a";
    }

    /**
     * @brief Run architecture detection on a checkpoint and record the results
     *
     * Any exception from detection or from parsing its numeric parameters falls back to
     * the assumed SD1.5 settings.
     *
     * @param filePath Path of the checkpoint as found on disk
     * @param info The model info to update (info.fullPath must already be set)
     */
    static void describeCheckpoint(const fs::path& filePath, ModelInfo& info) {
        try {
            const ModelDetectionResult detection = ModelDetector::detectModel(info.fullPath);
            const auto& params = detection.suggestedParams;

            // Compare case-insensitively, matching how the file was classified in the first place.
            const std::string dottedExtension = toLowerAscii(filePath.extension().string());
            const bool legacyFormat = dottedExtension == ".ckpt" || dottedExtension == ".pt";

            if (detection.architecture == ModelArchitecture::UNKNOWN && legacyFormat) {
                applyAssumedSd15Defaults(info);
            } else {
                info.architecture = detection.architectureName;
                info.recommendedVAE = detection.recommendedVAE;

                if (const auto it = params.find("width"); it != params.end()) {
                    info.recommendedWidth = std::stoi(it->second);
                }
                if (const auto it = params.find("height"); it != params.end()) {
                    info.recommendedHeight = std::stoi(it->second);
                }
                if (const auto it = params.find("steps"); it != params.end()) {
                    info.recommendedSteps = std::stoi(it->second);
                }
                if (const auto it = params.find("sampler"); it != params.end()) {
                    info.recommendedSampler = it->second;
                }
            }

            // Companion models the architecture depends on.
            if (detection.needsVAE && !detection.recommendedVAE.empty()) {
                info.requiredModels.push_back("VAE: " + detection.recommendedVAE);
            }
            static const std::pair<const char*, const char*> kCompanionModels[] = {
                {"clip_l_required", "CLIP-L: "},
                {"clip_g_required", "CLIP-G: "},
                {"t5xxl_required", "T5XXL: "},
                {"qwen2vl_required", "Qwen2-VL: "},
                {"qwen2vl_vision_required", "Qwen2-VL-Vision: "},
            };
            for (const auto& [paramKey, label] : kCompanionModels) {
                if (const auto it = params.find(paramKey); it != params.end()) {
                    info.requiredModels.push_back(std::string(label) + it->second);
                }
            }
        } catch (const std::exception&) {
            // If detection fails completely, default to SD1.5
            applyAssumedSd15Defaults(info);
        }
    }

    /**
     * @brief Scan a directory for models of a specific type (without holding mutex)
     *
     * Walks the directory recursively. Model names are paths relative to @p directory;
     * a name already present in @p modelsMap is left untouched. Unreadable
     * sub-directories and entries are skipped instead of ending the scan.
     *
     * @param directory The directory to scan
     * @param type The model type to look for (ModelType::NONE accepts every detected type)
     * @param modelsMap Reference to the map to store results
     * @return bool True if scanning completed without cancellation
     */
    bool scanDirectory(const fs::path& directory, ModelType type,
                       std::map<std::string, ModelInfo>& modelsMap) {
        if (scanCancelled.load()) {
            return false;
        }

        std::error_code ec;
        if (!fs::is_directory(directory, ec)) {
            return true;  // Missing or unreadable directory: nothing to scan, not a failure
        }

        try {
            const fs::recursive_directory_iterator walker(
                directory, fs::directory_options::skip_permission_denied);
            for (const auto& entry : walker) {
                if (scanCancelled.load()) {
                    return false;
                }

                std::error_code entryError;
                if (!entry.is_regular_file(entryError)) {
                    continue;
                }

                const fs::path filePath = entry.path();
                const ModelType detectedType = determineModelType(filePath);
                const bool wanted = (type == ModelType::NONE || detectedType == type);
                if (detectedType == ModelType::NONE || !wanted) {
                    continue;
                }

                // Name models relative to the scanned directory (not the base models directory)
                std::string modelName = fs::relative(filePath, directory).string();
                if (modelsMap.find(modelName) != modelsMap.end()) {
                    continue;  // Already registered by an earlier scan pass
                }

                ModelInfo info;
                info.name = modelName;
                info.path = filePath.string();
                info.fullPath = fs::absolute(filePath).string();
                info.type = detectedType;
                info.isLoaded = false;
                info.description = "";
                info.metadata = {};

                const auto [gotInfo, fileInfo] = getFileInfoWithTimeout(filePath);
                if (gotInfo) {
                    info.fileSize = fileInfo.first;
                    info.modifiedAt = fileInfo.second;
                    info.createdAt = fileInfo.second;  // Use modified time as creation time for now
                } else {
                    info.fileSize = 0;
                    info.modifiedAt = fs::file_time_type{};
                    info.createdAt = fs::file_time_type{};
                }

                info.sha256 = loadCachedHash(info.fullPath);

                if (detectedType == ModelType::CHECKPOINT) {
                    describeCheckpoint(filePath, info);
                }

                modelsMap.emplace(std::move(modelName), std::move(info));
            }
        } catch (const fs::filesystem_error& e) {
            // Best effort: keep what was collected so far, but leave a trace of the failure.
            std::cerr << "Error while scanning " << directory << ": " << e.what() << std::endl;
        }

        return !scanCancelled.load();
    }

    /**
     * @brief Drop all per-type directories and return to legacy mode
     *
     * The caller must hold modelsMutex exclusively.
     */
    void resetToLegacy() {
        modelTypeDirectories.clear();
        legacyMode = true;
    }
};

ModelManager::ModelManager() : pImpl(std::make_unique<Impl>()) {
}

ModelManager::~ModelManager() = default;

/**
 * @brief Rebuild the list of available models from disk
 *
 * Scanning runs without the mutex; only the final merge and swap take the exclusive
 * lock. Models that are currently loaded keep their isLoaded flag, and loaded models
 * that the scan did not find keep their previous entry.
 *
 * @return true when the scan completed; false if it was cancelled or, in legacy mode,
 *         the models directory does not exist
 */
bool ModelManager::scanModelsDirectory() {
    // Reset cancellation flag
    pImpl->scanCancelled.store(false);

    // Take a consistent snapshot of the configuration before the (slow) scan starts.
    bool legacyMode = true;
    std::string baseDirectory;
    std::vector<std::pair<ModelType, std::string>> explicitDirectories;
    {
        std::shared_lock<std::shared_mutex> lock(pImpl->modelsMutex);
        legacyMode = pImpl->legacyMode;
        baseDirectory = pImpl->modelsDirectory;
        if (!legacyMode) {
            static const ModelType kExplicitTypes[] = {
                ModelType::CHECKPOINT, ModelType::CONTROLNET, ModelType::LORA,      ModelType::VAE,
                ModelType::TAESD,      ModelType::ESRGAN,     ModelType::EMBEDDING,
            };
            for (const ModelType type : kExplicitTypes) {
                explicitDirectories.emplace_back(type, pImpl->getModelTypeDirectory(type));
            }
        }
    }

    // Scan results are collected outside of the lock
    std::map<std::string, ModelInfo> tempModels;

    if (legacyMode) {
        // Legacy mode: scan the models directory itself and its subdirectories
        const fs::path modelsPath(baseDirectory);
        std::error_code ec;
        if (!fs::is_directory(modelsPath, ec)) {
            std::cerr << "Models directory does not exist: " << baseDirectory << std::endl;
            return false;
        }

        // First pass: everything under the base directory, whatever its type
        if (!pImpl->scanDirectory(modelsPath, ModelType::NONE, tempModels)) {
            return false;
        }

        // Second pass: known subdirectories, with names relative to each subdirectory
        const std::vector<std::pair<fs::path, ModelType>> directoriesToScan = {
            {modelsPath / "stable-diffusion", ModelType::CHECKPOINT},
            {modelsPath / "controlnet", ModelType::CONTROLNET},
            {modelsPath / "lora", ModelType::LORA},
            {modelsPath / "vae", ModelType::VAE},
            {modelsPath / "taesd", ModelType::TAESD},
            {modelsPath / "esrgan", ModelType::ESRGAN},
            {modelsPath / "upscaler", ModelType::ESRGAN},
            {modelsPath / "embeddings", ModelType::EMBEDDING},
            {modelsPath / "textual-inversion", ModelType::EMBEDDING},
            {modelsPath / "checkpoints", ModelType::CHECKPOINT},
            {modelsPath / "other", ModelType::NONE},  // Scan for any type
        };
        for (const auto& [dirPath, type] : directoriesToScan) {
            if (!pImpl->scanDirectory(dirPath, type, tempModels)) {
                return false;
            }
        }
    } else {
        // Explicit mode: scan configured directories for each model type
        for (const auto& [type, dirPath] : explicitDirectories) {
            if (dirPath.empty()) {
                continue;
            }
            if (!pImpl->scanDirectory(dirPath, type, tempModels)) {
                return false;
            }
        }
    }

    // Brief exclusive lock to merge the loaded state and swap the data
    {
        std::unique_lock<std::shared_mutex> lock(pImpl->modelsMutex);
        for (const auto& loaded : pImpl->loadedModels) {
            const std::string& name = loaded.first;
            const auto scanned = tempModels.find(name);
            if (scanned != tempModels.end()) {
                scanned->second.isLoaded = true;
                continue;
            }
            // Loaded but not found on disk by this scan (for example registered through
            // loadModel(name, path, type)): keep the entry so it stays visible.
            const auto previous = pImpl->availableModels.find(name);
            if (previous != pImpl->availableModels.end()) {
                const auto inserted = tempModels.emplace(name, previous->second).first;
                inserted->second.isLoaded = true;
            }
        }
        pImpl->availableModels.swap(tempModels);
    }

    return true;
}

/**
 * @brief Load a model from an explicit path and register it under @p name
 *
 * The exclusive lock is held for the whole load, so concurrent readers block until the
 * wrapper has finished loading.
 *
 * @param name Name to register the model under
 * @param path Path of the model file
 * @param type Model type recorded when the model is not yet in the available list
 * @return true if the model is loaded (or already was), false otherwise
 */
bool ModelManager::loadModel(const std::string& name, const std::string& path, ModelType type) {
    std::unique_lock<std::shared_mutex> lock(pImpl->modelsMutex);

    // Check if model is already loaded
    if (pImpl->loadedModels.find(name) != pImpl->loadedModels.end()) {
        return true;
    }

    // Check if file exists
    std::error_code ec;
    if (!fs::exists(path, ec)) {
        std::cerr << "Model file does not exist: " << path << std::endl;
        return false;
    }

    // Create and initialize the stable-diffusion wrapper
    auto wrapper = std::make_unique<StableDiffusionWrapper>();

    StableDiffusionWrapper::GenerationParams loadParams;
    loadParams.modelPath = path;
    loadParams.modelType = "f16";  // Default to f16 for better performance

    if (!wrapper->loadModel(path, loadParams)) {
        std::cerr << "Failed to load model '" << name << "': " << wrapper->getLastError() << std::endl;
        return false;
    }

    pImpl->loadedModels[name] = std::move(wrapper);

    // Update model info
    const auto known = pImpl->availableModels.find(name);
    if (known != pImpl->availableModels.end()) {
        known->second.isLoaded = true;
        return true;
    }

    // Not seen by a scan yet: create a new model info entry
    ModelInfo info;
    info.name = name;
    info.path = path;
    info.fullPath = fs::absolute(path).string();
    info.type = type;
    info.isLoaded = true;
    info.sha256 = "";
    info.description = "";
    info.metadata = {};

    try {
        info.fileSize = fs::file_size(path);
        info.modifiedAt = fs::last_write_time(path);
        info.createdAt = info.modifiedAt;  // Use modified time as creation time for now
    } catch (const fs::filesystem_error& e) {
        std::cerr << "Error getting file info for " << path << ": " << e.what() << std::endl;
        info.fileSize = 0;
        info.modifiedAt = fs::file_time_type{};
        info.createdAt = fs::file_time_type{};
    }

    pImpl->availableModels[name] = std::move(info);
    return true;
}

/**
 * @brief Load a model that is already listed as available
 *
 * @param name Name of the model in the available list
 * @return true if the model is loaded (or already was), false if unknown or loading failed
 */
bool ModelManager::loadModel(const std::string& name) {
    std::string path;
    ModelType type = ModelType::NONE;

    {
        // Read-only lookup: a shared lock is sufficient
        std::shared_lock<std::shared_mutex> lock(pImpl->modelsMutex);

        const auto it = pImpl->availableModels.find(name);
        if (it == pImpl->availableModels.end()) {
            std::cerr << "Model '" << name << "' not found in available models" << std::endl;
            return false;
        }
        if (pImpl->loadedModels.find(name) != pImpl->loadedModels.end()) {
            return true;
        }

        path = it->second.path;
        type = it->second.type;
    }  // Lock released: the overload below takes the exclusive lock itself

    return loadModel(name, path, type);
}

/**
 * @brief Unload a model and release its wrapper
 *
 * @param name Name of the loaded model
 * @return true if the model was loaded and has been unloaded, false if it was not loaded
 */
bool ModelManager::unloadModel(const std::string& name) {
    std::unique_lock<std::shared_mutex> lock(pImpl->modelsMutex);

    const auto loadedIt = pImpl->loadedModels.find(name);
    if (loadedIt == pImpl->loadedModels.end()) {
        return false;
    }

    if (loadedIt->second) {
        loadedIt->second->unloadModel();
    }
    pImpl->loadedModels.erase(loadedIt);

    const auto availableIt = pImpl->availableModels.find(name);
    if (availableIt != pImpl->availableModels.end()) {
        availableIt->second.isLoaded = false;
    }
    return true;
}

/**
 * @brief Look up the wrapper of a loaded model
 *
 * The returned pointer stays owned by the manager; unloadModel() for the same name
 * destroys the wrapper it points to.
 *
 * @param name Name of the loaded model
 * @return StableDiffusionWrapper* The wrapper, or nullptr if the model is not loaded
 */
StableDiffusionWrapper* ModelManager::getModel(const std::string& name) {
    std::shared_lock<std::shared_mutex> lock(pImpl->modelsMutex);

    const auto it = pImpl->loadedModels.find(name);
    return it == pImpl->loadedModels.end() ? nullptr : it->second.get();
}

/// @brief Return a copy of all available models, keyed by name.
std::map<std::string, ModelManager::ModelInfo> ModelManager::getAllModels() const {
    std::shared_lock<std::shared_mutex> lock(pImpl->modelsMutex);
    return pImpl->availableModels;
}

/**
 * @brief Return copies of all available models of one type
 *
 * @param type The model type to filter by
 * @return std::vector<ModelInfo> Matching models, in name order
 */
std::vector<ModelManager::ModelInfo> ModelManager::getModelsByType(ModelType type) const {
    std::shared_lock<std::shared_mutex> lock(pImpl->modelsMutex);

    std::vector<ModelInfo> result;
    for (const auto& [name, info] : pImpl->availableModels) {
        if (info.type == type) {
            result.push_back(info);
        }
    }
    return result;
}

/**
 * @brief Return a copy of the info for one model
 *
 * @param name Name of the model
 * @return ModelInfo The model info, or a default-constructed ModelInfo if not found
 */
ModelManager::ModelInfo ModelManager::getModelInfo(const std::string& name) const {
    // Read-only access: shared lock, consistent with the other getters
    std::shared_lock<std::shared_mutex> lock(pImpl->modelsMutex);

    const auto it = pImpl->availableModels.find(name);
    if (it == pImpl->availableModels.end()) {
        return ModelInfo{};
    }
    return it->second;
}

/// @brief Report whether a model with this name is currently loaded.
bool ModelManager::isModelLoaded(const std::string& name) const {
    std::shared_lock<std::shared_mutex> lock(pImpl->modelsMutex);
    return pImpl->loadedModels.find(name) != pImpl->loadedModels.end();
}

/// @brief Number of currently loaded models.
size_t ModelManager::getLoadedModelsCount() const {
    std::shared_lock<std::shared_mutex> lock(pImpl->modelsMutex);
    return pImpl->loadedModels.size();
}

/// @brief Number of models in the available list.
size_t ModelManager::getAvailableModelsCount() const {
    std::shared_lock<std::shared_mutex> lock(pImpl->modelsMutex);
    return pImpl->availableModels.size();
}

/**
 * @brief Set the base models directory
 *
 * Takes the exclusive lock because configureFromServerConfig() writes the same field
 * under that lock.
 *
 * @param path The new base directory (not validated here)
 */
void ModelManager::setModelsDirectory(const std::string& path) {
    std::unique_lock<std::shared_mutex> lock(pImpl->modelsMutex);
    pImpl->modelsDirectory = path;
}

/// @brief Return the base models directory.
std::string ModelManager::getModelsDirectory() const {
    std::shared_lock<std::shared_mutex> lock(pImpl->modelsMutex);
    return pImpl->modelsDirectory;
}

/**
 * @brief Convert a model type to its canonical string form
 *
 * @param type The model type
 * @return std::string The type name, "unknown" for unmapped values
 */
std::string ModelManager::modelTypeToString(ModelType type) {
    switch (type) {
        case ModelType::LORA:        return "lora";
        case ModelType::CHECKPOINT:  return "checkpoint";
        case ModelType::VAE:         return "vae";
        case ModelType::PRESETS:     return "presets";
        case ModelType::PROMPTS:     return "prompts";
        case ModelType::NEG_PROMPTS: return "neg_prompts";
        case ModelType::TAESD:       return "taesd";
        case ModelType::ESRGAN:      return "esrgan";
        case ModelType::CONTROLNET:  return "controlnet";
        case ModelType::UPSCALER:    return "upscaler";
        case ModelType::EMBEDDING:   return "embedding";
        default:                     return "unknown";
    }
}

/**
 * @brief Parse a model type name (case-insensitive, with a few aliases)
 *
 * @param typeStr The type name
 * @return ModelType The parsed type, ModelType::NONE if the name is not recognised
 */
ModelType ModelManager::stringToModelType(const std::string& typeStr) {
    static const std::pair<const char*, ModelType> kTypeNames[] = {
        {"lora", ModelType::LORA},
        {"checkpoint", ModelType::CHECKPOINT},
        {"stable-diffusion", ModelType::CHECKPOINT},
        {"vae", ModelType::VAE},
        {"presets", ModelType::PRESETS},
        {"prompts", ModelType::PROMPTS},
        {"neg_prompts", ModelType::NEG_PROMPTS},
        {"negative_prompts", ModelType::NEG_PROMPTS},
        {"taesd", ModelType::TAESD},
        {"esrgan", ModelType::ESRGAN},
        {"controlnet", ModelType::CONTROLNET},
        {"upscaler", ModelType::UPSCALER},
        {"embedding", ModelType::EMBEDDING},
        {"textual-inversion", ModelType::EMBEDDING},
    };

    const std::string lowerType = toLowerAscii(typeStr);
    for (const auto& [typeName, type] : kTypeNames) {
        if (lowerType == typeName) {
            return type;
        }
    }
    return ModelType::NONE;
}

/**
 * @brief Configure the directory for one model type and leave legacy mode
 *
 * @param type The model type
 * @param path An existing directory
 * @return true if the directory was accepted, false if validation failed
 */
bool ModelManager::setModelTypeDirectory(ModelType type, const std::string& path) {
    std::unique_lock<std::shared_mutex> lock(pImpl->modelsMutex);

    if (!pImpl->validateDirectory(path)) {
        return false;
    }
    pImpl->modelTypeDirectories[type] = path;
    pImpl->legacyMode = false;
    return true;
}

/// @brief Return the directory used for a model type (empty if none applies).
std::string ModelManager::getModelTypeDirectory(ModelType type) const {
    std::shared_lock<std::shared_mutex> lock(pImpl->modelsMutex);
    return pImpl->getModelTypeDirectory(type);
}

/**
 * @brief Replace all per-type directories at once and leave legacy mode
 *
 * Nothing is changed unless every non-empty path validates.
 *
 * @param directories Map of model type to directory
 * @return true if all directories were accepted, false otherwise
 */
bool ModelManager::setAllModelTypeDirectories(const std::map<ModelType, std::string>& directories) {
    std::unique_lock<std::shared_mutex> lock(pImpl->modelsMutex);

    for (const auto& [type, path] : directories) {
        if (!path.empty() && !pImpl->validateDirectory(path)) {
            return false;
        }
    }
    pImpl->modelTypeDirectories = directories;
    pImpl->legacyMode = false;
    return true;
}

/// @brief Return a copy of the explicitly configured per-type directories.
std::map<ModelType, std::string> ModelManager::getAllModelTypeDirectories() const {
    std::shared_lock<std::shared_mutex> lock(pImpl->modelsMutex);
    return pImpl->modelTypeDirectories;
}

/**
 * @brief Drop all per-type directories and return to legacy mode
 *
 * Acquires the exclusive lock itself; code in this file that already holds the lock
 * calls Impl::resetToLegacy() directly instead.
 */
void ModelManager::resetToLegacyDirectories() {
    std::unique_lock<std::shared_mutex> lock(pImpl->modelsMutex);
    pImpl->resetToLegacy();
}

/**
 * @brief Apply directory settings from the server configuration
 *
 * The base directory is always updated. In explicit mode the per-type directories are
 * only replaced when all of them validate.
 *
 * @param config The server configuration
 * @return true on success, false if a configured directory is invalid
 */
bool ModelManager::configureFromServerConfig(const ServerConfig& config) {
    std::unique_lock<std::shared_mutex> lock(pImpl->modelsMutex);

    pImpl->modelsDirectory = config.modelsDir;

    if (config.legacyMode) {
        // Legacy mode: use single models directory. The lock is already held, so the
        // unlocked helper is used rather than the public, self-locking method.
        pImpl->resetToLegacy();
        return true;
    }

    // Explicit mode: collect the per-type directories that are actually set
    const std::pair<ModelType, const std::string*> configured[] = {
        {ModelType::CHECKPOINT, &config.checkpoints},
        {ModelType::CONTROLNET, &config.controlnetDir},
        {ModelType::EMBEDDING, &config.embeddingsDir},
        {ModelType::ESRGAN, &config.esrganDir},
        {ModelType::LORA, &config.loraDir},
        {ModelType::TAESD, &config.taesdDir},
        {ModelType::VAE, &config.vaeDir},
    };
    std::map<ModelType, std::string> directories;
    for (const auto& [type, dir] : configured) {
        if (!dir->empty()) {
            directories[type] = *dir;
        }
    }

    // Validate everything before changing any state
    for (const auto& [type, path] : directories) {
        if (!pImpl->validateDirectory(path)) {
            return false;
        }
    }

    // Assigned here directly: setAllModelTypeDirectories() would re-lock and deadlock
    pImpl->modelTypeDirectories = std::move(directories);
    pImpl->legacyMode = false;
    return true;
}

/// @brief Ask a running scan to stop at its next cancellation check.
void ModelManager::cancelScan() {
    pImpl->scanCancelled.store(true);
}

// SHA256 Hashing Implementation

/**
 * @brief Compute the SHA256 of a model file
 *
 * The file is streamed in fixed-size chunks with progress printed every 100 MB. The
 * mutex is only held while the path is looked up.
 *
 * @param modelName Name of the model in the available list
 * @return std::string Lower-case hex digest, empty on any failure
 */
std::string ModelManager::computeModelHash(const std::string& modelName) {
    std::string filePath;
    {
        std::shared_lock<std::shared_mutex> lock(pImpl->modelsMutex);
        const auto it = pImpl->availableModels.find(modelName);
        if (it == pImpl->availableModels.end()) {
            std::cerr << "Model not found: " << modelName << std::endl;
            return "";
        }
        filePath = it->second.fullPath;
    }

    std::ifstream file(filePath, std::ios::binary);
    if (!file.is_open()) {
        std::cerr << "Failed to open file for hashing: " << filePath << std::endl;
        return "";
    }

    // The context is owned by a smart pointer, so no return path below can leak it
    const EvpMdCtxPtr mdctx(EVP_MD_CTX_new());
    if (!mdctx) {
        std::cerr << "Failed to create EVP context" << std::endl;
        return "";
    }
    if (EVP_DigestInit_ex(mdctx.get(), EVP_sha256(), nullptr) != 1) {
        std::cerr << "Failed to initialize SHA256 digest" << std::endl;
        return "";
    }

    std::array<char, 8192> buffer{};

    std::cout << "Computing SHA256 for: " << modelName << std::endl;

    size_t totalRead = 0;
    size_t lastReportedMB = 0;

    // read() fails on the final short chunk, hence the additional gcount() check
    while (file.read(buffer.data(), static_cast<std::streamsize>(buffer.size())) || file.gcount() > 0) {
        const size_t bytesRead = static_cast<size_t>(file.gcount());
        if (EVP_DigestUpdate(mdctx.get(), buffer.data(), bytesRead) != 1) {
            std::cerr << "Failed to update digest" << std::endl;
            return "";
        }

        totalRead += bytesRead;

        // Progress reporting every 100MB
        const size_t currentMB = totalRead / (1024 * 1024);
        if (currentMB >= lastReportedMB + 100) {
            std::cout << " Hashed " << currentMB << " MB..." << std::endl;
            lastReportedMB = currentMB;
        }
    }

    // A hard I/O error ends the loop just like EOF does; refuse to report a digest
    // computed over a truncated read.
    if (file.bad()) {
        std::cerr << "Failed to read file for hashing: " << filePath << std::endl;
        return "";
    }
    file.close();

    unsigned char hash[EVP_MAX_MD_SIZE];
    unsigned int hashLen = 0;
    if (EVP_DigestFinal_ex(mdctx.get(), hash, &hashLen) != 1) {
        std::cerr << "Failed to finalize digest" << std::endl;
        return "";
    }

    // Convert to hex string
    std::ostringstream oss;
    oss << std::hex << std::setfill('0');
    for (unsigned int i = 0; i < hashLen; ++i) {
        oss << std::setw(2) << static_cast<int>(hash[i]);
    }

    std::string hashStr = oss.str();
    std::cout << "Hash computed: " << hashStr.substr(0, 16) << "..." << std::endl;
    return hashStr;
}

/**
 * @brief Read a model's cached SHA256 from its "<fullPath>.json" sidecar
 *
 * @param modelName Name of the model in the available list
 * @return std::string The cached hash, empty if the model or a valid sidecar is missing
 */
std::string ModelManager::loadModelHashFromFile(const std::string& modelName) {
    std::string jsonPath;
    {
        std::shared_lock<std::shared_mutex> lock(pImpl->modelsMutex);
        const auto it = pImpl->availableModels.find(modelName);
        if (it == pImpl->availableModels.end()) {
            return "";
        }
        jsonPath = it->second.fullPath + ".json";
    }

    std::error_code ec;
    if (!fs::exists(jsonPath, ec)) {
        return "";
    }

    try {
        std::ifstream jsonFile(jsonPath);
        if (!jsonFile.is_open()) {
            return "";
        }

        nlohmann::json j;
        jsonFile >> j;

        if (j.contains("sha256") && j["sha256"].is_string()) {
            return j["sha256"].get<std::string>();
        }
    } catch (const std::exception& e) {
        std::cerr << "Error loading hash from JSON: " << e.what() << std::endl;
    }
    return "";
}

/**
 * @brief Write a model's SHA256 to its "<fullPath>.json" sidecar
 *
 * The sidecar also records the file size known to the manager and the time of writing
 * (system_clock ticks since epoch).
 *
 * @param modelName Name of the model in the available list
 * @param hash The hash to store
 * @return true if the sidecar was written, false otherwise
 */
bool ModelManager::saveModelHashToFile(const std::string& modelName, const std::string& hash) {
    std::string jsonPath;
    size_t fileSize = 0;
    {
        std::shared_lock<std::shared_mutex> lock(pImpl->modelsMutex);
        const auto it = pImpl->availableModels.find(modelName);
        if (it == pImpl->availableModels.end()) {
            return false;
        }
        jsonPath = it->second.fullPath + ".json";
        fileSize = it->second.fileSize;
    }

    try {
        nlohmann::json j;
        j["sha256"] = hash;
        j["file_size"] = fileSize;
        j["computed_at"] = std::chrono::system_clock::now().time_since_epoch().count();

        std::ofstream jsonFile(jsonPath);
        if (!jsonFile.is_open()) {
            std::cerr << "Failed to open file for writing: " << jsonPath << std::endl;
            return false;
        }

        jsonFile << j.dump(2);
        jsonFile.close();

        std::cout << "Saved hash to: " << jsonPath << std::endl;
        return true;
    } catch (const std::exception& e) {
        std::cerr << "Error saving hash to JSON: " << e.what() << std::endl;
        return false;
    }
}

/**
 * @brief Find a model by full hash or hash prefix
 *
 * @param hash Full SHA256 or a prefix of at least 10 characters
 * @return std::string Name of the first model (in name order) whose hash starts with
 *         @p hash, empty if none matches or the prefix is too short
 */
std::string ModelManager::findModelByHash(const std::string& hash) {
    if (hash.length() < 10) {
        std::cerr << "Hash must be at least 10 characters" << std::endl;
        return "";
    }

    std::shared_lock<std::shared_mutex> lock(pImpl->modelsMutex);

    for (const auto& [name, info] : pImpl->availableModels) {
        // A prefix comparison also covers the full-length match, and an empty or shorter
        // stored hash can never compare equal.
        if (info.sha256.compare(0, hash.length(), hash) == 0) {
            return name;
        }
    }
    return "";
}

/**
 * @brief Make sure a model has a SHA256, computing and caching it when needed
 *
 * @param modelName Name of the model in the available list
 * @param forceCompute When true, ignore any cached hash and recompute
 * @return std::string The hash, empty if it could not be obtained
 */
std::string ModelManager::ensureModelHash(const std::string& modelName, bool forceCompute) {
    // Stores the hash on the in-memory entry, if the model is still listed
    const auto rememberHash = [this, &modelName](const std::string& value) {
        std::unique_lock<std::shared_mutex> lock(pImpl->modelsMutex);
        const auto it = pImpl->availableModels.find(modelName);
        if (it != pImpl->availableModels.end()) {
            it->second.sha256 = value;
        }
    };

    // Try to load existing hash if not forcing recompute
    if (!forceCompute) {
        const std::string existingHash = loadModelHashFromFile(modelName);
        if (!existingHash.empty()) {
            rememberHash(existingHash);
            return existingHash;
        }
    }

    // Compute new hash
    const std::string hash = computeModelHash(modelName);
    if (hash.empty()) {
        return "";
    }

    // Persisting is best effort: a failed write still returns the computed hash
    saveModelHashToFile(modelName, hash);
    rememberHash(hash);
    return hash;
}