| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462 |
- #ifndef MODEL_MANAGER_H
- #define MODEL_MANAGER_H
#include <cstddef>
#include <cstdint>
#include <filesystem>
#include <functional>
#include <map>
#include <memory>
#include <mutex>
#include <shared_mutex>
#include <string>
#include <vector>
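// Forward declarations. NOTE(review): the classes actually defined in this header are the
// private nested ModelManager::ModelPathSelector and ModelManager::ModelDetectionCache; the
// global ModelPathSelector / ModelDetectionCache declared below are distinct names not used here.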
- class StableDiffusionWrapper;
- class ModelPathSelector;
- class ModelDetectionCache;
- #include "model_detector.h"
- #include "server_config.h"
/**
 * @brief Model type enumeration
 *
 * Each enumerator occupies its own bit, so values can be combined with the
 * bitwise operators below to build type filters, e.g.
 * `ModelType::CHECKPOINT | ModelType::DIFFUSION_MODELS`.
 */
enum class ModelType : uint32_t {
    NONE             = 0,
    LORA             = 1u << 0,   // 1
    CHECKPOINT       = 1u << 1,   // 2
    VAE              = 1u << 2,   // 4
    PRESETS          = 1u << 3,   // 8
    PROMPTS          = 1u << 4,   // 16
    NEG_PROMPTS      = 1u << 5,   // 32
    TAESD            = 1u << 6,   // 64
    ESRGAN           = 1u << 7,   // 128
    CONTROLNET       = 1u << 8,   // 256
    UPSCALER         = 1u << 9,   // 512
    EMBEDDING        = 1u << 10,  // 1024
    DIFFUSION_MODELS = 1u << 11   // 2048
};

// Bitwise operations so ModelType can be used as a flag set. They are constexpr
// (and therefore implicitly inline, which keeps them ODR-safe in a header), so
// filter masks can also be composed at compile time.

/// @brief Union of two flag sets.
constexpr ModelType operator|(ModelType a, ModelType b) noexcept {
    return static_cast<ModelType>(static_cast<uint32_t>(a) | static_cast<uint32_t>(b));
}

/// @brief Intersection of two flag sets (NONE when they share no bits).
constexpr ModelType operator&(ModelType a, ModelType b) noexcept {
    return static_cast<ModelType>(static_cast<uint32_t>(a) & static_cast<uint32_t>(b));
}

/// @brief Complement of a flag set; intended for clearing flags via `mask & ~flag`.
constexpr ModelType operator~(ModelType a) noexcept {
    return static_cast<ModelType>(~static_cast<uint32_t>(a));
}

/// @brief Adds the flags in @p b to @p a in place.
constexpr ModelType& operator|=(ModelType& a, ModelType b) noexcept {
    a = a | b;
    return a;
}

/// @brief Keeps only the flags of @p a that are also set in @p b.
constexpr ModelType& operator&=(ModelType& a, ModelType b) noexcept {
    a = a & b;
    return a;
}
- /**
- * @brief Model manager class for loading and managing stable-diffusion models
- *
- * This class handles loading, unloading, and managing multiple stable-diffusion models.
- * It provides thread-safe access to models and manages model resources efficiently.
- */
- class ModelManager {
- public:
- /**
- * @brief Model information structure
- */
- struct ModelInfo {
- std::string name; ///< Model name
- std::string path; ///< Model file path
- std::string fullPath; ///< Absolute path to the model file
- ModelType type; ///< Model type
- bool isLoaded; ///< Whether the model is currently loaded
- size_t fileSize; ///< File size in bytes
- std::string sha256; ///< SHA256 hash of the file
- std::filesystem::file_time_type createdAt; ///< File creation time
- std::filesystem::file_time_type modifiedAt; ///< Last modification time
- std::string description; ///< Model description
- std::map<std::string, std::string> metadata; ///< Additional metadata
- // Architecture detection fields
- std::string architecture; ///< Detected architecture (e.g., "Stable Diffusion XL Base", "Flux Dev")
- std::string recommendedVAE; ///< Recommended VAE for this model
- int recommendedWidth = 0; ///< Recommended image width
- int recommendedHeight = 0; ///< Recommended image height
- int recommendedSteps = 0; ///< Recommended number of steps
- std::string recommendedSampler; ///< Recommended sampler
- std::vector<std::string> requiredModels; ///< List of required auxiliary models (VAE, CLIP, etc.)
- std::vector<std::string> missingModels; ///< List of missing required models
- // Caching-related fields
- bool cacheValid = false; ///< Whether cached detection results are valid
- std::filesystem::file_time_type cacheModifiedAt; ///< Modification time when cache was created
- std::string cachePathType; ///< Path type used for this model ("model_path" or "diffusion_model_path")
- bool useFolderBasedDetection = false; ///< Whether folder-based detection was used
- std::string detectionSource; ///< Source of detection: "folder", "architecture", "fallback"
- };
- /**
- * @brief Model details structure for existence checking
- */
- struct ModelDetails {
- std::string name; ///< Model name
- bool exists; ///< Whether the model exists
- std::string type; ///< Model type ("VAE", "CLIP-L", "CLIP-G", "T5XXL", "CLIP-Vision", "Qwen2VL")
- std::string path; ///< Absolute path to the model file (empty if doesn't exist)
- size_t file_size; ///< File size in bytes (0 if doesn't exist)
- std::string sha256; ///< SHA256 hash (empty if doesn't exist)
- bool is_required; ///< True for required models
- bool is_recommended; ///< True for recommended models
- };
- /**
- * @brief Construct a new Model Manager object
- */
- ModelManager();
- /**
- * @brief Destroy the Model Manager object
- */
- virtual ~ModelManager();
- /**
- * @brief Scan the models directory to discover available models
- *
- * Recursively scans all subdirectories within the models directory to find
- * model files. For each model found, constructs the display name as
- * 'relative_path/model_name' where relative_path is the path from the models
- * root directory to the file's containing folder (using forward slashes).
- * Models in the root directory appear without a prefix.
- *
- * @return true if scanning was successful, false otherwise
- */
- bool scanModelsDirectory();
- /**
- * @brief Cancel any ongoing model directory scanning
- */
- void cancelScan();
- /**
- * @brief Load a model from the specified path
- *
- * @param name The name to assign to the model
- * @param path The file path to the model
- * @param type The type of model
- * @return true if the model was loaded successfully, false otherwise
- */
- bool loadModel(const std::string& name, const std::string& path, ModelType type);
- /**
- * @brief Load a model by name (must be discovered first)
- *
- * @param name The name of the model to load
- * @return true if the model was loaded successfully, false otherwise
- */
- bool loadModel(const std::string& name);
- /**
- * @brief Load a model by name with progress callback
- *
- * @param name The name of the model to load
- * @param progressCallback Callback function for progress updates (0.0-1.0)
- * @return true if the model was loaded successfully, false otherwise
- */
- bool loadModel(const std::string& name, std::function<void(float)> progressCallback);
- /**
- * @brief Unload a model
- *
- * @param name The name of the model to unload
- * @return true if the model was unloaded successfully, false otherwise
- */
- bool unloadModel(const std::string& name);
- /**
- * @brief Unload all currently loaded models
- *
- * This method unloads all models that are currently loaded, ensuring that
- * all contexts and parameters are properly freed. This is useful during
- * graceful shutdown to prevent memory leaks.
- */
- void unloadAllModels();
- /**
- * @brief Get a pointer to a loaded model
- *
- * @param name The name of the model
- * @return StableDiffusionWrapper* Pointer to the model wrapper, or nullptr if not found
- */
- StableDiffusionWrapper* getModel(const std::string& name);
- /**
- * @brief Get information about all models
- *
- * @return std::map<std::string, ModelInfo> Map of model names to their information
- */
- std::map<std::string, ModelInfo> getAllModels() const;
- /**
- * @brief Get information about models of a specific type
- *
- * @param type The model type to filter by
- * @return std::vector<ModelInfo> List of model information
- */
- std::vector<ModelInfo> getModelsByType(ModelType type) const;
- /**
- * @brief Get information about a specific model
- *
- * @param name The name of the model
- * @return ModelInfo Model information, or empty if not found
- */
- ModelInfo getModelInfo(const std::string& name) const;
- /**
- * @brief Check if a model is loaded
- *
- * @param name The name of the model
- * @return true if the model is loaded, false otherwise
- */
- bool isModelLoaded(const std::string& name) const;
- /**
- * @brief Get the number of loaded models
- *
- * @return size_t Number of loaded models
- */
- size_t getLoadedModelsCount() const;
- /**
- * @brief Get the number of available models
- *
- * @return size_t Number of available models
- */
- size_t getAvailableModelsCount() const;
- /**
- * @brief Set the models directory path
- *
- * @param path The path to the models directory
- */
- void setModelsDirectory(const std::string& path);
- /**
- * @brief Get the models directory path
- *
- * @return std::string The models directory path
- */
- std::string getModelsDirectory() const;
- /**
- * @brief Set directory for a specific model type
- *
- * @param type The model type
- * @param path The directory path
- * @return true if the directory was set successfully, false otherwise
- */
- bool setModelTypeDirectory(ModelType type, const std::string& path);
- /**
- * @brief Get directory for a specific model type
- *
- * @param type The model type
- * @return std::string The directory path, empty if not set
- */
- std::string getModelTypeDirectory(ModelType type) const;
- /**
- * @brief Set all model type directories at once
- *
- * @param directories Map of model types to directory paths
- * @return true if all directories were set successfully, false otherwise
- */
- bool setAllModelTypeDirectories(const std::map<ModelType, std::string>& directories);
- /**
- * @brief Get all model type directories
- *
- * @return std::map<ModelType, std::string> Map of model types to directory paths
- */
- std::map<ModelType, std::string> getAllModelTypeDirectories() const;
- // Legacy methods removed - using explicit directory configuration only
- /**
- * @brief Configure ModelManager with ServerConfig
- *
- * @param config The server configuration
- * @return true if configuration was successful, false otherwise
- */
- bool configureFromServerConfig(const struct ServerConfig& config);
- /**
- * @brief Convert ModelType to string
- *
- * @param type The model type
- * @return std::string String representation of the model type
- */
- static std::string modelTypeToString(ModelType type);
- /**
- * @brief Convert string to ModelType
- *
- * @param typeStr String representation of the model type
- * @return ModelType The model type
- */
- static ModelType stringToModelType(const std::string& typeStr);
- /**
- * @brief Compute SHA256 hash of a model file
- *
- * @param modelName The name of the model
- * @return std::string The SHA256 hash, or empty string on error
- */
- std::string computeModelHash(const std::string& modelName);
- /**
- * @brief Load hash from JSON file for a model
- *
- * @param modelName The name of the model
- * @return std::string The loaded hash, or empty string if not found
- */
- std::string loadModelHashFromFile(const std::string& modelName);
- /**
- * @brief Save hash to JSON file for a model
- *
- * @param modelName The name of the model
- * @param hash The SHA256 hash to save
- * @return true if saved successfully, false otherwise
- */
- bool saveModelHashToFile(const std::string& modelName, const std::string& hash);
- /**
- * @brief Find model by hash (full or partial - minimum 10 chars)
- *
- * @param hash Full or partial SHA256 hash (minimum 10 characters)
- * @return std::string Model name, or empty string if not found
- */
- std::string findModelByHash(const std::string& hash);
- /**
- * @brief Load hash for a model (from file or compute if missing)
- *
- * @param modelName The name of the model
- * @param forceCompute Force recomputation even if hash file exists
- * @return std::string The SHA256 hash, or empty string on error
- */
- std::string ensureModelHash(const std::string& modelName, bool forceCompute = false);
- /**
- * @brief Check if required models exist in the appropriate directories
- *
- * @param requiredModels List of required model names with types (e.g., "VAE: model.safetensors")
- * @return std::vector<ModelDetails> List of model details with existence information
- */
- std::vector<ModelDetails> checkRequiredModelsExistence(const std::vector<std::string>& requiredModels);
- private:
- class Impl;
- std::unique_ptr<Impl> pImpl; // Pimpl idiom
- /**
- * @brief Model path selector class for folder-based detection
- */
- class ModelPathSelector {
- public:
- /**
- * @brief Determine which path to use based on folder location
- *
- * @param modelPath The absolute path to the model file
- * @param checkpointsDir The checkpoints directory path
- * @param diffusionModelsDir The diffusion_models directory path
- * @return std::string "model_path" for checkpoints, "diffusion_model_path" for diffusion_models
- */
- static std::string selectPathType(
- const std::string& modelPath,
- const std::string& checkpointsDir,
- const std::string& diffusionModelsDir,
- bool verbose = false
- );
- /**
- * @brief Check if model is in a specific directory
- *
- * @param modelPath The absolute path to the model file
- * @param directory The directory to check
- * @return true if model is in the specified directory
- */
- static bool isModelInDirectory(const std::string& modelPath, const std::string& directory);
- };
- /**
- * @brief Model detection cache class for caching detection results
- */
- class ModelDetectionCache {
- public:
- /**
- * @brief Cache entry for model detection results
- */
- struct CacheEntry {
- std::string architecture;
- std::string recommendedVAE;
- int recommendedWidth = 0;
- int recommendedHeight = 0;
- int recommendedSteps = 0;
- std::string recommendedSampler;
- std::vector<std::string> requiredModels;
- std::vector<std::string> missingModels;
- std::string pathType;
- std::string detectionSource;
- std::filesystem::file_time_type cachedAt;
- std::filesystem::file_time_type fileModifiedAt;
- bool isValid = false;
- };
- /**
- * @brief Get cached detection result for a model
- *
- * @param modelPath The path to the model file
- * @param currentModifiedTime Current file modification time
- * @return CacheEntry Cached entry, or invalid if not found/expired
- */
- static CacheEntry getCachedResult(const std::string& modelPath,
- const std::filesystem::file_time_type& currentModifiedTime);
- /**
- * @brief Cache detection result for a model
- *
- * @param modelPath The path to the model file
- * @param detection The detection result to cache
- * @param pathType The path type used
- * @param detectionSource The source of detection
- * @param fileModifiedTime Current file modification time
- */
- static void cacheDetectionResult(
- const std::string& modelPath,
- const ModelDetectionResult& detection,
- const std::string& pathType,
- const std::string& detectionSource,
- const std::filesystem::file_time_type& fileModifiedTime
- );
- /**
- * @brief Invalidate cache for a model
- *
- * @param modelPath The path to the model file
- */
- static void invalidateCache(const std::string& modelPath);
- /**
- * @brief Clear all cached results
- */
- static void clearAllCache();
- private:
- static std::map<std::string, CacheEntry> cache_;
- static std::mutex cacheMutex_;
- };
- };
- #endif // MODEL_MANAGER_H
|