| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442 |
- #ifndef SERVER_H
- #define SERVER_H
#include <atomic>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <map>
#include <memory>
#include <mutex>
#include <string>
#include <thread>
#include <tuple>
#include <utility>
#include <vector>

#include <nlohmann/json.hpp>

#include "generation_queue.h"
#include "model_manager.h"
// Forward declarations.
// These keep the header light: only pointers/references to these types are
// needed for the ones declared here, so the full definitions are not required.
class ModelManager;
class GenerationQueue;
class UserManager;
class AuthMiddleware;

// cpp-httplib types used in handler signatures.
// The class-key must match the library's own definitions: httplib declares
// Request and Response as `struct` and Server as `class`. A mismatched key
// triggers -Wmismatched-tags / MSVC C4099 and can alter MSVC name mangling.
namespace httplib {
class Server;
struct Request;
struct Response;
}  // namespace httplib
- /**
- * @brief HTTP server class for handling REST API requests
- *
- * This class implements the HTTP server that exposes the stable-diffusion.cpp
- * functionality through a REST API. It handles incoming requests, validates
- * parameters, and coordinates with the model manager and generation queue.
- * The server runs in a separate thread to handle HTTP requests independently
- * from the generation process.
- */
- class Server {
- public:
- /**
- * @brief Construct a new Server object
- *
- * @param modelManager Pointer to the model manager instance
- * @param generationQueue Pointer to the generation queue instance
- * @param outputDir Directory where generated output files are stored
- * @param uiDir Directory containing static web UI files (optional)
- */
- Server(ModelManager* modelManager, GenerationQueue* generationQueue, const std::string& outputDir = "./output", const std::string& uiDir = "");
- /**
- * @brief Destroy the Server object
- */
- virtual ~Server();
- /**
- * @brief Start the HTTP server
- *
- * @param host The host address to bind to
- * @param port The port number to listen on
- * @return true if the server started successfully, false otherwise
- */
- bool start(const std::string& host = "0.0.0.0", int port = 8080);
- /**
- * @brief Stop the HTTP server
- */
- void stop();
- /**
- * @brief Check if the server is running
- *
- * @return true if the server is running, false otherwise
- */
- bool isRunning() const;
- /**
- * @brief Wait for the server thread to finish
- */
- void waitForStop();
- /**
- * @brief Set authentication components
- */
- void setAuthComponents(std::shared_ptr<UserManager> userManager, std::shared_ptr<AuthMiddleware> authMiddleware);
- private:
- /**
- * @brief Register all API endpoints
- */
- void registerEndpoints();
- /**
- * @brief Register authentication endpoints
- */
- void registerAuthEndpoints();
- /**
- * @brief Set up CORS headers for responses
- */
- void setupCORS();
- /**
- * @brief Health check endpoint handler
- */
- void handleHealthCheck(const httplib::Request& req, httplib::Response& res);
- /**
- * @brief API status endpoint handler
- */
- void handleApiStatus(const httplib::Request& req, httplib::Response& res);
- /**
- * @brief Models list endpoint handler
- */
- void handleModelsList(const httplib::Request& req, httplib::Response& res);
- /**
- * @brief Hash models endpoint handler
- */
- void handleHashModels(const httplib::Request& req, httplib::Response& res);
- /**
- * @brief Convert/quantize model endpoint handler
- */
- void handleConvertModel(const httplib::Request& req, httplib::Response& res);
- // Enhanced model management endpoints
- /**
- * @brief Get detailed information about a specific model
- */
- void handleModelInfo(const httplib::Request& req, httplib::Response& res);
- /**
- * @brief Load a specific model by ID
- */
- void handleLoadModelById(const httplib::Request& req, httplib::Response& res);
- /**
- * @brief Unload a specific model by ID
- */
- void handleUnloadModelById(const httplib::Request& req, httplib::Response& res);
- /**
- * @brief List all available model types
- */
- void handleModelTypes(const httplib::Request& req, httplib::Response& res);
- /**
- * @brief List model directories and their contents
- */
- void handleModelDirectories(const httplib::Request& req, httplib::Response& res);
- /**
- * @brief Force refresh of model cache
- */
- void handleRefreshModels(const httplib::Request& req, httplib::Response& res);
- /**
- * @brief Get statistics about loaded models
- */
- void handleModelStats(const httplib::Request& req, httplib::Response& res);
- /**
- * @brief Batch operations on multiple models
- */
- void handleBatchModels(const httplib::Request& req, httplib::Response& res);
- /**
- * @brief Validate model files and format
- */
- void handleValidateModel(const httplib::Request& req, httplib::Response& res);
- /**
- * @brief Check model compatibility with current configuration
- */
- void handleCheckCompatibility(const httplib::Request& req, httplib::Response& res);
- /**
- * @brief Get system requirements for specific models
- */
- void handleModelRequirements(const httplib::Request& req, httplib::Response& res);
- /**
- * @brief Queue status endpoint handler
- */
- void handleQueueStatus(const httplib::Request& req, httplib::Response& res);
- /**
- * @brief Job status endpoint handler
- */
- void handleJobStatus(const httplib::Request& req, httplib::Response& res);
- /**
- * @brief Cancel job endpoint handler
- */
- void handleCancelJob(const httplib::Request& req, httplib::Response& res);
- /**
- * @brief Clear queue endpoint handler
- */
- void handleClearQueue(const httplib::Request& req, httplib::Response& res);
- /**
- * @brief Download job output file endpoint handler
- */
- void handleDownloadOutput(const httplib::Request& req, httplib::Response& res);
- // Specialized generation endpoints
- /**
- * @brief Text-to-image generation endpoint handler
- */
- void handleText2Img(const httplib::Request& req, httplib::Response& res);
- /**
- * @brief Image-to-image generation endpoint handler
- */
- void handleImg2Img(const httplib::Request& req, httplib::Response& res);
- /**
- * @brief ControlNet generation endpoint handler
- */
- void handleControlNet(const httplib::Request& req, httplib::Response& res);
- /**
- * @brief Upscaler endpoint handler
- */
- void handleUpscale(const httplib::Request& req, httplib::Response& res);
- /**
- * @brief Inpainting endpoint handler
- */
- void handleInpainting(const httplib::Request& req, httplib::Response& res);
- // Utility endpoints
- /**
- * @brief List available sampling methods endpoint handler
- */
- void handleSamplers(const httplib::Request& req, httplib::Response& res);
- /**
- * @brief List available schedulers endpoint handler
- */
- void handleSchedulers(const httplib::Request& req, httplib::Response& res);
- /**
- * @brief Get parameter schema and validation rules endpoint handler
- */
- void handleParameters(const httplib::Request& req, httplib::Response& res);
- /**
- * @brief Validate generation parameters endpoint handler
- */
- void handleValidate(const httplib::Request& req, httplib::Response& res);
- /**
- * @brief Estimate generation time and memory usage endpoint handler
- */
- void handleEstimate(const httplib::Request& req, httplib::Response& res);
- /**
- * @brief Get/set server configuration endpoint handler
- */
- void handleConfig(const httplib::Request& req, httplib::Response& res);
- /**
- * @brief System information and capabilities endpoint handler
- */
- void handleSystem(const httplib::Request& req, httplib::Response& res);
- /**
- * @brief System restart endpoint handler
- */
- void handleSystemRestart(const httplib::Request& req, httplib::Response& res);
- // Authentication endpoint handlers
- /**
- * @brief Login endpoint handler
- */
- void handleLogin(const httplib::Request& req, httplib::Response& res);
- /**
- * @brief Logout endpoint handler
- */
- void handleLogout(const httplib::Request& req, httplib::Response& res);
- /**
- * @brief Token validation endpoint handler
- */
- void handleValidateToken(const httplib::Request& req, httplib::Response& res);
- /**
- * @brief Token refresh endpoint handler
- */
- void handleRefreshToken(const httplib::Request& req, httplib::Response& res);
- /**
- * @brief Get current user endpoint handler
- */
- void handleGetCurrentUser(const httplib::Request& req, httplib::Response& res);
- /**
- * @brief Send JSON response with proper headers
- */
- void sendJsonResponse(httplib::Response& res, const nlohmann::json& json, int status_code = 200);
- /**
- * @brief Send error response with proper headers
- */
- void sendErrorResponse(httplib::Response& res, const std::string& message, int status_code = 400,
- const std::string& error_code = "", const std::string& request_id = "");
- /**
- * @brief Validate generation parameters
- */
- std::pair<bool, std::string> validateGenerationParameters(const nlohmann::json& params);
- /**
- * @brief Parse sampling method from string
- */
- SamplingMethod parseSamplingMethod(const std::string& method);
- /**
- * @brief Parse scheduler from string
- */
- Scheduler parseScheduler(const std::string& scheduler);
- /**
- * @brief Generate unique request ID
- */
- std::string generateRequestId();
- /**
- * @brief Get sampling method as string
- */
- std::string samplingMethodToString(SamplingMethod method);
- /**
- * @brief Get scheduler as string
- */
- std::string schedulerToString(Scheduler scheduler);
- /**
- * @brief Estimate generation time based on parameters
- */
- uint64_t estimateGenerationTime(const GenerationRequest& request);
- /**
- * @brief Estimate memory usage based on parameters
- */
- size_t estimateMemoryUsage(const GenerationRequest& request);
- /**
- * @brief Get model capabilities based on type
- */
- nlohmann::json getModelCapabilities(ModelType type);
- /**
- * @brief Get statistics for each model type
- */
- nlohmann::json getModelTypeStatistics();
- // Additional helper methods for model management
- /**
- * @brief Get model compatibility information
- */
- nlohmann::json getModelCompatibility(const ModelManager::ModelInfo& modelInfo);
- /**
- * @brief Get model requirements based on type
- */
- nlohmann::json getModelRequirements(ModelType type);
- /**
- * @brief Get recommended usage parameters for model type
- */
- nlohmann::json getRecommendedUsage(ModelType type);
- /**
- * @brief Load image from base64 or file path
- * @return tuple of (data, width, height, channels, success, error_message)
- */
- std::tuple<std::vector<uint8_t>, int, int, int, bool, std::string>
- loadImageFromInput(const std::string& input);
- /**
- * @brief Get model type from directory name
- */
- std::string getModelTypeFromDirectoryName(const std::string& dirName);
- /**
- * @brief Get description for model directory
- */
- std::string getDirectoryDescription(const std::string& dirName);
- /**
- * @brief Get contents of a directory
- */
- nlohmann::json getDirectoryContents(const std::string& dirPath);
- /**
- * @brief Get largest model from collection
- */
- nlohmann::json getLargestModel(const std::map<std::string, ModelManager::ModelInfo>& allModels);
- /**
- * @brief Get smallest model from collection
- */
- nlohmann::json getSmallestModel(const std::map<std::string, ModelManager::ModelInfo>& allModels);
- /**
- * @brief Validate model file and format
- */
- nlohmann::json validateModelFile(const std::string& modelPath, const std::string& modelType);
- /**
- * @brief Check model compatibility with system
- */
- nlohmann::json checkModelCompatibility(const ModelManager::ModelInfo& modelInfo, const std::string& systemInfo);
- /**
- * @brief Calculate specific requirements for model configuration
- */
- nlohmann::json calculateSpecificRequirements(const std::string& modelType, const std::string& resolution, const std::string& batchSize);
- /**
- * @brief Server thread function
- */
- void serverThreadFunction(const std::string& host, int port);
- ModelManager* m_modelManager; ///< Pointer to model manager
- GenerationQueue* m_generationQueue; ///< Pointer to generation queue
- std::unique_ptr<httplib::Server> m_httpServer; ///< HTTP server instance
- std::thread m_serverThread; ///< Thread for running the server
- std::atomic<bool> m_isRunning; ///< Flag indicating if server is running
- std::atomic<bool> m_startupFailed; ///< Flag indicating if server startup failed
- std::string m_host; ///< Host address
- int m_port; ///< Port number
- std::string m_outputDir; ///< Output directory for generated files
- std::string m_uiDir; ///< Directory containing static web UI files
- std::string m_currentlyLoadedModel; ///< Currently loaded model name
- mutable std::mutex m_currentModelMutex; ///< Mutex for thread-safe access to current model
- std::shared_ptr<UserManager> m_userManager; ///< User manager instance
- std::shared_ptr<AuthMiddleware> m_authMiddleware; ///< Authentication middleware instance
- };
- #endif // SERVER_H
|