#ifndef SERVER_H #define SERVER_H #include #include #include #include #include #include #include "generation_queue.h" #include "model_manager.h" #include "server_config.h" // Forward declarations class ModelManager; class GenerationQueue; class UserManager; class AuthMiddleware; namespace httplib { class Server; class Request; class Response; } /** * @brief HTTP server class for handling REST API requests * * This class implements the HTTP server that exposes the stable-diffusion.cpp * functionality through a REST API. It handles incoming requests, validates * parameters, and coordinates with the model manager and generation queue. * The server runs in a separate thread to handle HTTP requests independently * from the generation process. */ class Server { public: /** * @brief Construct a new Server object * * @param modelManager Pointer to the model manager instance * @param generationQueue Pointer to the generation queue instance * @param outputDir Directory where generated output files are stored * @param uiDir Directory containing static web UI files (optional) * @param config Server configuration */ Server(ModelManager* modelManager, GenerationQueue* generationQueue, const std::string& outputDir = "./output", const std::string& uiDir = "", const ServerConfig& config = ServerConfig{}); /** * @brief Destroy the Server object */ virtual ~Server(); /** * @brief Start the HTTP server * * @param host The host address to bind to * @param port The port number to listen on * @return true if the server started successfully, false otherwise */ bool start(const std::string& host = "0.0.0.0", int port = 8080); /** * @brief Stop the HTTP server */ void stop(); /** * @brief Check if the server is running * * @return true if the server is running, false otherwise */ bool isRunning() const; /** * @brief Wait for the server thread to finish */ void waitForStop(); /** * @brief Set authentication components */ void setAuthComponents(std::shared_ptr userManager, std::shared_ptr authMiddleware); private: /** * @brief Register all API endpoints */ void registerEndpoints(); /** * @brief Register authentication endpoints */ void registerAuthEndpoints(); /** * @brief Set up CORS headers for responses */ void setupCORS(); /** * @brief Log HTTP access request */ void logHttpAccess(const httplib::Request& req, const httplib::Response& res, const std::string& endpoint = ""); /** * @brief Health check endpoint handler */ void handleHealthCheck(const httplib::Request& req, httplib::Response& res); /** * @brief API status endpoint handler */ void handleApiStatus(const httplib::Request& req, httplib::Response& res); /** * @brief Version information endpoint handler */ void handleVersion(const httplib::Request& req, httplib::Response& res); /** * @brief Models list endpoint handler */ void handleModelsList(const httplib::Request& req, httplib::Response& res); /** * @brief Hash models endpoint handler */ void handleHashModels(const httplib::Request& req, httplib::Response& res); /** * @brief Convert/quantize model endpoint handler */ void handleConvertModel(const httplib::Request& req, httplib::Response& res); // Enhanced model management endpoints /** * @brief Get detailed information about a specific model */ void handleModelInfo(const httplib::Request& req, httplib::Response& res); /** * @brief Load a specific model by ID */ void handleLoadModelById(const httplib::Request& req, httplib::Response& res); /** * @brief Unload a specific model by ID */ void handleUnloadModelById(const httplib::Request& req, httplib::Response& res); /** * @brief List all available model types */ void handleModelTypes(const httplib::Request& req, httplib::Response& res); /** * @brief List model directories and their contents */ void handleModelDirectories(const httplib::Request& req, httplib::Response& res); /** * @brief Force refresh of model cache */ void handleRefreshModels(const httplib::Request& req, httplib::Response& res); /** * @brief Get statistics about loaded models */ void handleModelStats(const httplib::Request& req, httplib::Response& res); /** * @brief Batch operations on multiple models */ void handleBatchModels(const httplib::Request& req, httplib::Response& res); /** * @brief Validate model files and format */ void handleValidateModel(const httplib::Request& req, httplib::Response& res); /** * @brief Check model compatibility with current configuration */ void handleCheckCompatibility(const httplib::Request& req, httplib::Response& res); /** * @brief Get system requirements for specific models */ void handleModelRequirements(const httplib::Request& req, httplib::Response& res); /** * @brief Queue status endpoint handler */ void handleQueueStatus(const httplib::Request& req, httplib::Response& res); /** * @brief Job status endpoint handler */ void handleJobStatus(const httplib::Request& req, httplib::Response& res); /** * @brief Cancel job endpoint handler */ void handleCancelJob(const httplib::Request& req, httplib::Response& res); /** * @brief Clear queue endpoint handler */ void handleClearQueue(const httplib::Request& req, httplib::Response& res); /** * @brief Download job output file endpoint handler */ void handleDownloadOutput(const httplib::Request& req, httplib::Response& res); /** * @brief Get job output by job ID endpoint handler */ void handleJobOutput(const httplib::Request& req, httplib::Response& res); /** * @brief Get specific job output file by filename endpoint handler */ void handleJobOutputFile(const httplib::Request& req, httplib::Response& res); /** * @brief Download image from URL and return as base64 endpoint handler */ void handleDownloadImageFromUrl(const httplib::Request& req, httplib::Response& res); /** * @brief Resize image endpoint handler */ void handleImageResize(const httplib::Request& req, httplib::Response& res); /** * @brief Crop image endpoint handler */ void handleImageCrop(const httplib::Request& req, httplib::Response& res); /** * @brief Serve temporary image endpoint handler */ void handleTempImage(const httplib::Request& req, httplib::Response& res); // Specialized generation endpoints /** * @brief Text-to-image generation endpoint handler */ void handleText2Img(const httplib::Request& req, httplib::Response& res); /** * @brief Image-to-image generation endpoint handler */ void handleImg2Img(const httplib::Request& req, httplib::Response& res); /** * @brief ControlNet generation endpoint handler */ void handleControlNet(const httplib::Request& req, httplib::Response& res); /** * @brief Upscaler endpoint handler */ void handleUpscale(const httplib::Request& req, httplib::Response& res); /** * @brief Inpainting endpoint handler */ void handleInpainting(const httplib::Request& req, httplib::Response& res); // Utility endpoints /** * @brief List available sampling methods endpoint handler */ void handleSamplers(const httplib::Request& req, httplib::Response& res); /** * @brief List available schedulers endpoint handler */ void handleSchedulers(const httplib::Request& req, httplib::Response& res); /** * @brief Get parameter schema and validation rules endpoint handler */ void handleParameters(const httplib::Request& req, httplib::Response& res); /** * @brief Validate generation parameters endpoint handler */ void handleValidate(const httplib::Request& req, httplib::Response& res); /** * @brief Estimate generation time and memory usage endpoint handler */ void handleEstimate(const httplib::Request& req, httplib::Response& res); /** * @brief Get/set server configuration endpoint handler */ void handleConfig(const httplib::Request& req, httplib::Response& res); /** * @brief System information and capabilities endpoint handler */ void handleSystem(const httplib::Request& req, httplib::Response& res); /** * @brief System restart endpoint handler */ void handleSystemRestart(const httplib::Request& req, httplib::Response& res); // Authentication endpoint handlers /** * @brief Login endpoint handler */ void handleLogin(const httplib::Request& req, httplib::Response& res); /** * @brief Logout endpoint handler */ void handleLogout(const httplib::Request& req, httplib::Response& res); /** * @brief Token validation endpoint handler */ void handleValidateToken(const httplib::Request& req, httplib::Response& res); /** * @brief Token refresh endpoint handler */ void handleRefreshToken(const httplib::Request& req, httplib::Response& res); /** * @brief Get current user endpoint handler */ void handleGetCurrentUser(const httplib::Request& req, httplib::Response& res); /** * @brief Send JSON response with proper headers */ void sendJsonResponse(httplib::Response& res, const nlohmann::json& json, int status_code = 200); /** * @brief Send error response with proper headers */ void sendErrorResponse(httplib::Response& res, const std::string& message, int status_code = 400, const std::string& error_code = "", const std::string& request_id = ""); /** * @brief Validate generation parameters */ std::pair validateGenerationParameters(const nlohmann::json& params); /** * @brief Parse sampling method from string */ SamplingMethod parseSamplingMethod(const std::string& method); /** * @brief Parse scheduler from string */ Scheduler parseScheduler(const std::string& scheduler); /** * @brief Generate unique request ID */ std::string generateRequestId(); /** * @brief Get sampling method as string */ std::string samplingMethodToString(SamplingMethod method); /** * @brief Get scheduler as string */ std::string schedulerToString(Scheduler scheduler); /** * @brief Estimate generation time based on parameters */ uint64_t estimateGenerationTime(const GenerationRequest& request); /** * @brief Estimate memory usage based on parameters */ size_t estimateMemoryUsage(const GenerationRequest& request); /** * @brief Get model capabilities based on type */ nlohmann::json getModelCapabilities(ModelType type); /** * @brief Get statistics for each model type */ nlohmann::json getModelTypeStatistics(); // Additional helper methods for model management /** * @brief Get model compatibility information */ nlohmann::json getModelCompatibility(const ModelManager::ModelInfo& modelInfo); /** * @brief Get model requirements based on type */ nlohmann::json getModelRequirements(ModelType type); /** * @brief Get recommended usage parameters for model type */ nlohmann::json getRecommendedUsage(ModelType type); /** * @brief Load image from base64 or file path * @return tuple of (data, width, height, channels, success, error_message) */ std::tuple, int, int, int, bool, std::string> loadImageFromInput(const std::string& input); /** * @brief Get model type from directory name */ std::string getModelTypeFromDirectoryName(const std::string& dirName); /** * @brief Get description for model directory */ std::string getDirectoryDescription(const std::string& dirName); /** * @brief Get contents of a directory */ nlohmann::json getDirectoryContents(const std::string& dirPath); /** * @brief Get largest model from collection */ nlohmann::json getLargestModel(const std::map& allModels); /** * @brief Get smallest model from collection */ nlohmann::json getSmallestModel(const std::map& allModels); /** * @brief Validate model file and format */ nlohmann::json validateModelFile(const std::string& modelPath, const std::string& modelType); /** * @brief Check model compatibility with system */ nlohmann::json checkModelCompatibility(const ModelManager::ModelInfo& modelInfo, const std::string& systemInfo); /** * @brief Calculate specific requirements for model configuration */ nlohmann::json calculateSpecificRequirements(const std::string& modelType, const std::string& resolution, const std::string& batchSize); /** * @brief Convert ModelDetails vector to JSON array */ nlohmann::json modelDetailsToJson(const std::vector& modelDetails); /** * @brief Determine which recommended fields to include based on architecture */ std::map getRecommendedModelFields(const std::string& architecture); /** * @brief Populate recommended models with existence information */ void populateRecommendedModels(nlohmann::json& response, const ModelManager::ModelInfo& modelInfo); /** * @brief Server thread function */ void serverThreadFunction(const std::string& host, int port); ModelManager* m_modelManager; ///< Pointer to model manager GenerationQueue* m_generationQueue; ///< Pointer to generation queue std::unique_ptr m_httpServer; ///< HTTP server instance std::thread m_serverThread; ///< Thread for running the server std::atomic m_isRunning; ///< Flag indicating if server is running std::atomic m_startupFailed; ///< Flag indicating if server startup failed std::string m_host; ///< Host address int m_port; ///< Port number std::string m_outputDir; ///< Output directory for generated files std::string m_uiDir; ///< Directory containing static web UI files struct LoadedModels { std::string checkpoint; ///< Currently loaded checkpoint model (for text2img, img2img, etc.) std::string esrgan; ///< Currently loaded ESRGAN/upscaler model }; LoadedModels m_loadedModels; ///< Currently loaded models by type mutable std::mutex m_loadedModelsMutex; ///< Mutex for thread-safe access to loaded models std::shared_ptr m_userManager; ///< User manager instance /** * @brief Get reference to the appropriate model field based on model type * @param type The model type * @return Reference to the model field */ std::string& getModelField(ModelType type); std::shared_ptr m_authMiddleware; ///< Authentication middleware instance ServerConfig m_config; ///< Server configuration /** * @brief Generate thumbnail for image file * * @param imagePath Path to the original image file * @param size Thumbnail size (width and height) * @return JPEG thumbnail data as string, empty string if failed */ std::string generateThumbnail(const std::string& imagePath, int size); }; #endif // SERVER_H