#include "server.h"
#include "model_manager.h"
#include "generation_queue.h"
#include "utils.h"
#include "auth_middleware.h"
#include "user_manager.h"
#include "version.h"

#include <algorithm>
#include <atomic>
#include <cctype>
#include <chrono>
#include <climits>
#include <cstdint>
#include <cstdlib>
#include <exception>
#include <filesystem>
#include <fstream>
#include <functional>
#include <initializer_list>
#include <iomanip>
#include <iostream>
#include <iterator>
#include <map>
#include <memory>
#include <random>
#include <regex>
#include <sstream>
#include <stdexcept>
#include <string>
#include <thread>
#include <utility>
#include <vector>

// Include stb_image for loading images (implementation is in generation_queue.cpp)
#include "../stable-diffusion.cpp-src/thirdparty/stb_image.h"

#include <cstring>
#include <set>
#include <unordered_map>
#include <unordered_set>

namespace {
namespace server_detail {

// Whole seconds since the Unix epoch; used for health-check timestamps and for the
// placeholder auth tokens issued by the login handler.
long long secondsSinceEpoch() {
    return static_cast<long long>(std::chrono::duration_cast<std::chrono::seconds>(
                                      std::chrono::system_clock::now().time_since_epoch())
                                      .count());
}

// True when a /ui request path looks like a static asset (scripts, styles, images, Next.js
// bundles). Such paths are served without authentication so the login page itself can load.
bool isUiStaticAssetPath(const std::string& path) {
    static const char* const kMarkers[] = {".js", ".css", ".png", ".jpg", ".jpeg", ".svg", ".ico", "/_next/"};
    return std::any_of(std::begin(kMarkers), std::end(kMarkers),
                       [&path](const char* marker) { return path.find(marker) != std::string::npos; });
}

// True when a /ui request expects an HTML page (so an unauthenticated client gets the login page).
bool isUiHtmlRequest(const std::string& path) {
    return path.find(".html") != std::string::npos || path == "/ui/" || path == "/ui";
}

// Returns false for request paths that would resolve outside the UI directory after
// normalization (e.g. "/../../etc/passwd"). Regex routes such as "/ui/.*" receive the decoded
// request path as-is, so the containment check has to happen here.
bool isSafeUiRelativePath(const std::string& requestPath) {
    std::string relative = requestPath;
    relative.erase(0, relative.find_first_not_of('/'));  // all-slash input becomes ""
    const std::filesystem::path normalized = std::filesystem::path(relative).lexically_normal();
    if (normalized.has_root_path()) {
        return false;  // e.g. a Windows drive prefix
    }
    // After lexically_normal(), any remaining ".." components can only be leading ones.
    for (const auto& component : normalized) {
        if (component == std::filesystem::path("..")) {
            return false;
        }
    }
    return true;
}

// Maps a UI file's extension (case-insensitively) to a Content-Type; unknown types are text/plain.
std::string uiContentTypeFor(const std::string& filePath) {
    static const std::map<std::string, std::string> kContentTypes = {
        {".html", "text/html"},           {".htm", "text/html"},
        {".js", "application/javascript"}, {".mjs", "application/javascript"},
        {".css", "text/css"},             {".json", "application/json"},
        {".map", "application/json"},     {".png", "image/png"},
        {".jpg", "image/jpeg"},           {".jpeg", "image/jpeg"},
        {".gif", "image/gif"},            {".webp", "image/webp"},
        {".svg", "image/svg+xml"},        {".ico", "image/x-icon"},
        {".woff", "font/woff"},           {".woff2", "font/woff2"},
        {".ttf", "font/ttf"},             {".txt", "text/plain"},
    };
    std::string extension = std::filesystem::path(filePath).extension().string();
    std::transform(extension.begin(), extension.end(), extension.begin(),
                   [](unsigned char c) { return static_cast<char>(std::tolower(c)); });
    const auto it = kContentTypes.find(extension);
    return it != kContentTypes.end() ? it->second : std::string("text/plain");
}

// Reads the whole file at `path` into `out`; returns false if the file cannot be opened.
bool readWholeFile(const std::string& path, std::string& out) {
    std::ifstream file(path, std::ios::binary);
    if (!file.is_open()) {
        return false;
    }
    out.assign(std::istreambuf_iterator<char>(file), std::istreambuf_iterator<char>());
    return true;
}

}  // namespace server_detail
}  // namespace

Server::Server(ModelManager* modelManager, GenerationQueue* generationQueue, const std::string& outputDir,
               const std::string& uiDir, const ServerConfig& config)
    : m_modelManager(modelManager)
    , m_generationQueue(generationQueue)
    , m_isRunning(false)
    , m_startupFailed(false)
    , m_port(config.port)
    , m_outputDir(outputDir)
    , m_uiDir(uiDir)
    , m_userManager(nullptr)
    , m_authMiddleware(nullptr)
    , m_config(config) {
    m_httpServer = std::make_unique<httplib::Server>();
}

Server::~Server() {
    stop();
}

// Starts the HTTP server on a background thread and waits up to ~10 s for it to come up.
// Returns true once the server reports running; false if it is already running, if host/port
// are invalid, if the server thread signals a startup failure, or if startup times out.
bool Server::start(const std::string& host, int port) {
    if (m_isRunning.load()) {
        return false;
    }

    // Validate before mutating any state so a rejected call leaves m_host/m_port untouched.
    if (host.empty() || port < 1 || port > 65535) {
        return false;
    }

    m_host = host;
    m_port = port;

    // Set up CORS headers, then register API endpoints.
    setupCORS();
    registerEndpoints();

    m_startupFailed.store(false);

    // A previous server thread that ended on its own may still be joinable; assigning a new
    // thread over a joinable std::thread would call std::terminate().
    if (m_serverThread.joinable()) {
        m_serverThread.join();
    }
    m_serverThread = std::thread(&Server::serverThreadFunction, this, host, port);

    // Poll every 100 ms, up to 100 times (~10 s), for the server to bind and report running.
    for (int attempt = 0; attempt < 100; ++attempt) {
        std::this_thread::sleep_for(std::chrono::milliseconds(100));

        if (m_startupFailed.load()) {
            if (m_serverThread.joinable()) {
                m_serverThread.join();
            }
            return false;
        }

        if (m_isRunning.load()) {
            // Give it a moment more to ensure the server is fully started.
            std::this_thread::sleep_for(std::chrono::milliseconds(500));
            if (m_isRunning.load()) {
                return true;
            }
        }
    }

    if (m_isRunning.load()) {
        return true;
    }

    // Timed out: ask httplib to stop (best effort) before joining, to reduce the chance that
    // the join blocks on a listen() that is still in progress.
    if (m_httpServer) {
        m_httpServer->stop();
    }
    if (m_serverThread.joinable()) {
        m_serverThread.join();
    }
    return false;
}

// Stops the server and joins the server thread. Safe to call repeatedly; ~Server() calls it.
void Server::stop() {
    // exchange() lets exactly one caller perform the running -> stopped transition.
    const bool wasRunning = m_isRunning.exchange(false);

    if (wasRunning && m_httpServer) {
        m_httpServer->stop();

        // Give the server a moment to leave the blocking listen call.
        std::this_thread::sleep_for(std::chrono::milliseconds(100));

        // If the server thread is still alive, make a throwaway request to our own port to
        // unblock listen().
        if (m_serverThread.joinable()) {
            try {
                // httplib timeouts take (seconds, microseconds); split the configured millisecond
                // values so the microsecond part always stays below one second.
                const long long connectMs = m_config.connectionTimeoutMs;
                const long long readMs = m_config.readTimeoutMs;
                const long long writeMs = m_config.writeTimeoutMs;

                httplib::Client client("127.0.0.1", m_port);
                client.set_connection_timeout(connectMs / 1000, (connectMs % 1000) * 1000);
                client.set_read_timeout(readMs / 1000, (readMs % 1000) * 1000);
                client.set_write_timeout(writeMs / 1000, (writeMs % 1000) * 1000);
                (void)client.Get("/api/health");  // the response itself is irrelevant
            } catch (...) {
                // Best effort: connection errors are expected here and deliberately ignored.
            }
        }
    }

    // Join even when the server was not flagged as running: a server thread that exited on its
    // own leaves a joinable std::thread, and destroying that would call std::terminate().
    if (m_serverThread.joinable()) {
        if (!wasRunning && m_httpServer) {
            m_httpServer->stop();  // intended to be harmless if listen() has already returned
        }
        m_serverThread.join();
    }
}

bool Server::isRunning() const {
    return m_isRunning.load();
}

void Server::waitForStop() {
    if (m_serverThread.joinable()) {
        m_serverThread.join();
    }
}

// Registers every HTTP route. cpp-httplib dispatches a request to the FIRST registered pattern
// that matches, so literal routes must be registered before overlapping regex catch-alls.
void Server::registerEndpoints() {
    using Request = httplib::Request;
    using Response = httplib::Response;
    using HttpHandler = std::function<void(const Request&, Response&)>;

    // Authentication endpoints first; they are never wrapped by withAuth.
    registerAuthEndpoints();

    // ---- Public endpoints ---------------------------------------------------------------
    m_httpServer->Get("/api/health", [this](const Request& req, Response& res) { handleHealthCheck(req, res); });
    m_httpServer->Get("/api/status", [this](const Request& req, Response& res) { handleApiStatus(req, res); });
    m_httpServer->Get("/api/version", [this](const Request& req, Response& res) { handleVersion(req, res); });

    // Wraps a handler so it only runs for authenticated requests. With no auth middleware
    // installed, the handler runs unconditionally.
    auto withAuth = [this](HttpHandler handler) {
        return [this, handler = std::move(handler)](const Request& req, Response& res) {
            if (m_authMiddleware) {
                AuthContext authContext = m_authMiddleware->authenticate(req, res);
                if (!authContext.authenticated) {
                    m_authMiddleware->sendAuthError(res, authContext.errorMessage, authContext.errorCode);
                    return;
                }
            }
            handler(req, res);
        };
    };

    // ---- Generation endpoints (protected) -------------------------------------------------
    m_httpServer->Post("/api/generate/text2img", withAuth([this](const Request& req, Response& res) { handleText2Img(req, res); }));
    m_httpServer->Post("/api/generate/img2img", withAuth([this](const Request& req, Response& res) { handleImg2Img(req, res); }));
    m_httpServer->Post("/api/generate/controlnet", withAuth([this](const Request& req, Response& res) { handleControlNet(req, res); }));
    m_httpServer->Post("/api/generate/upscale", withAuth([this](const Request& req, Response& res) { handleUpscale(req, res); }));
    m_httpServer->Post("/api/generate/inpainting", withAuth([this](const Request& req, Response& res) { handleInpainting(req, res); }));

    // ---- Utility endpoints (protected) ----------------------------------------------------
    m_httpServer->Get("/api/samplers", withAuth([this](const Request& req, Response& res) { handleSamplers(req, res); }));
    m_httpServer->Get("/api/schedulers", withAuth([this](const Request& req, Response& res) { handleSchedulers(req, res); }));
    m_httpServer->Get("/api/parameters", withAuth([this](const Request& req, Response& res) { handleParameters(req, res); }));
    m_httpServer->Post("/api/validate", withAuth([this](const Request& req, Response& res) { handleValidate(req, res); }));
    m_httpServer->Post("/api/estimate", withAuth([this](const Request& req, Response& res) { handleEstimate(req, res); }));
    m_httpServer->Get("/api/config", withAuth([this](const Request& req, Response& res) { handleConfig(req, res); }));
    m_httpServer->Get("/api/system", withAuth([this](const Request& req, Response& res) { handleSystem(req, res); }));
    m_httpServer->Post("/api/system/restart", withAuth([this](const Request& req, Response& res) { handleSystemRestart(req, res); }));

    // ---- Model endpoints (protected) ------------------------------------------------------
    m_httpServer->Get("/api/models", withAuth([this](const Request& req, Response& res) { handleModelsList(req, res); }));

    // Literal /api/models/<name> GET routes MUST precede the /api/models/(.*) catch-all below,
    // otherwise the catch-all matches "types", "directories" and "stats" first.
    m_httpServer->Get("/api/models/types", withAuth([this](const Request& req, Response& res) { handleModelTypes(req, res); }));
    m_httpServer->Get("/api/models/directories", withAuth([this](const Request& req, Response& res) { handleModelDirectories(req, res); }));
    m_httpServer->Get("/api/models/stats", withAuth([this](const Request& req, Response& res) { handleModelStats(req, res); }));

    // Model management endpoints
    m_httpServer->Post("/api/models/refresh", withAuth([this](const Request& req, Response& res) { handleRefreshModels(req, res); }));
    m_httpServer->Post("/api/models/hash", withAuth([this](const Request& req, Response& res) { handleHashModels(req, res); }));
    m_httpServer->Post("/api/models/convert", withAuth([this](const Request& req, Response& res) { handleConvertModel(req, res); }));
    m_httpServer->Post("/api/models/batch", withAuth([this](const Request& req, Response& res) { handleBatchModels(req, res); }));

    // Model validation endpoints
    m_httpServer->Post("/api/models/validate", withAuth([this](const Request& req, Response& res) { handleValidateModel(req, res); }));
    m_httpServer->Post("/api/models/compatible", withAuth([this](const Request& req, Response& res) { handleCheckCompatibility(req, res); }));
    m_httpServer->Post("/api/models/requirements", withAuth([this](const Request& req, Response& res) { handleModelRequirements(req, res); }));

    // Model-specific endpoints (catch-alls, registered after the literal routes above).
    // Model info is protected like the model list, which exposes the same information.
    m_httpServer->Get("/api/models/(.*)", withAuth([this](const Request& req, Response& res) { handleModelInfo(req, res); }));
    m_httpServer->Post("/api/models/(.*)/load", withAuth([this](const Request& req, Response& res) { handleLoadModelById(req, res); }));
    m_httpServer->Post("/api/models/(.*)/unload", withAuth([this](const Request& req, Response& res) { handleUnloadModelById(req, res); }));

    // ---- Queue and image endpoints ----------------------------------------------------------
    m_httpServer->Get("/api/queue/status", withAuth([this](const Request& req, Response& res) { handleQueueStatus(req, res); }));

    // Job output download. Must be registered before the job status route so the more specific
    // pattern matches first. Public so the frontend can display generated images without auth.
    m_httpServer->Get("/api/queue/job/(.*)/output/(.*)", [this](const Request& req, Response& res) { handleDownloadOutput(req, res); });

    // Job output by job ID (public for the same reason as above).
    m_httpServer->Get("/api/v1/jobs/(.*)/output", [this](const Request& req, Response& res) { handleJobOutput(req, res); });

    // Download image from URL (public, for CORS-free image handling).
    m_httpServer->Get("/api/image/download", [this](const Request& req, Response& res) { handleDownloadImageFromUrl(req, res); });

    m_httpServer->Post("/api/image/resize", withAuth([this](const Request& req, Response& res) { handleImageResize(req, res); }));
    m_httpServer->Post("/api/image/crop", withAuth([this](const Request& req, Response& res) { handleImageCrop(req, res); }));

    m_httpServer->Get("/api/queue/job/(.*)", withAuth([this](const Request& req, Response& res) { handleJobStatus(req, res); }));
    m_httpServer->Post("/api/queue/cancel", withAuth([this](const Request& req, Response& res) { handleCancelJob(req, res); }));
    m_httpServer->Post("/api/queue/clear", withAuth([this](const Request& req, Response& res) { handleClearQueue(req, res); }));

    // ---- Static web UI (only when a UI directory is configured and present) -----------------
    if (m_uiDir.empty() || !std::filesystem::exists(m_uiDir)) {
        return;
    }

    std::cout << "Serving static UI files from: " << m_uiDir << " at /ui" << std::endl;

    // Read the UI version from version.json if available.
    std::string uiVersion = "unknown";
    const std::string versionFilePath = m_uiDir + "/version.json";
    if (std::filesystem::exists(versionFilePath)) {
        try {
            std::ifstream versionFile(versionFilePath);
            if (versionFile.is_open()) {
                nlohmann::json versionData = nlohmann::json::parse(versionFile);
                if (versionData.contains("version")) {
                    uiVersion = versionData["version"].get<std::string>();
                }
            }
        } catch (const std::exception& e) {
            std::cerr << "Failed to read UI version: " << e.what() << std::endl;
        }
    }
    std::cout << "UI version: " << uiVersion << std::endl;

    // Dynamic config.js giving the web UI its runtime configuration. Registered before the
    // /ui/.* catch-all so it takes precedence.
    m_httpServer->Get("/ui/config.js", [this, uiVersion](const Request& /*req*/, Response& res) {
        std::ostringstream configJs;
        configJs << "// Auto-generated configuration\n"
                 << "window.__SERVER_CONFIG__ = {\n"
                 << " apiUrl: 'http://" << m_host << ":" << m_port << "',\n"
                 << " apiBasePath: '/api',\n"
                 << " host: '" << m_host << "',\n"
                 << " port: " << m_port << ",\n"
                 << " uiVersion: '" << uiVersion << "',\n";

        if (m_authMiddleware) {
            const auto authConfig = m_authMiddleware->getConfig();
            std::string authMethod = "none";
            switch (authConfig.authMethod) {
                case AuthMethod::UNIX:
                    authMethod = "unix";
                    break;
                case AuthMethod::JWT:
                    authMethod = "jwt";
                    break;
                default:
                    authMethod = "none";
                    break;
            }
            configJs << " authMethod: '" << authMethod << "',\n"
                     << " authEnabled: " << (authConfig.authMethod != AuthMethod::NONE ? "true" : "false") << "\n";
        } else {
            configJs << " authMethod: 'none',\n"
                     << " authEnabled: false\n";
        }
        configJs << "};\n";

        // Never cache config.js; always fetch fresh.
        res.set_header("Cache-Control", "no-cache, no-store, must-revalidate");
        res.set_header("Pragma", "no-cache");
        res.set_header("Expires", "0");
        res.set_content(configJs.str(), "application/javascript");
    });

    // Cache headers for static assets, keyed on the UI version.
    // NOTE(review): cpp-httplib invokes the file request handler only for files served through
    // set_mount_point(); the /ui/.* route below reads files itself, so these headers do not
    // appear to reach UI responses. Confirm before relying on this caching behaviour.
    m_httpServer->set_file_request_handler([uiVersion](const Request& req, Response& res) {
        const std::string& path = req.path;
        const std::string etag = "\"" + uiVersion + "\"";

        static const char* const kLongCacheMarkers[] = {"/_next/", ".js",  ".css",  ".png",  ".jpg",
                                                        ".svg",    ".ico", ".woff", ".woff2", ".ttf"};
        const bool isVersionedAsset =
            std::any_of(std::begin(kLongCacheMarkers), std::end(kLongCacheMarkers),
                        [&path](const char* marker) { return path.find(marker) != std::string::npos; });

        if (isVersionedAsset) {
            // Long cache (1 year) for static assets, validated by a version ETag.
            res.set_header("Cache-Control", "public, max-age=31536000, immutable");
            res.set_header("ETag", etag);
            if (req.has_header("If-None-Match") && req.get_header_value("If-None-Match") == etag) {
                res.status = 304;  // Not Modified
                return;
            }
        } else if (server_detail::isUiHtmlRequest(path)) {
            // HTML may be cached briefly but must always revalidate.
            res.set_header("Cache-Control", "public, max-age=0, must-revalidate");
            res.set_header("ETag", etag);
        }
    });

    // Serves files from m_uiDir under /ui, enforcing authentication for pages when enabled.
    auto uiHandler = [this](const Request& req, Response& res) {
        if (m_authMiddleware) {
            const auto authConfig = m_authMiddleware->getConfig();
            if (authConfig.authMethod != AuthMethod::NONE) {
                AuthContext authContext = m_authMiddleware->authenticate(req, res);
                // For Unix auth, authenticateUnix may return a guest context for UI requests
                // without an Authorization header; unauthenticated users still get the login page.
                // Static assets are exempt so the login page can load its scripts/styles.
                if (!authContext.authenticated && !server_detail::isUiStaticAssetPath(req.path)) {
                    if (!server_detail::isUiHtmlRequest(req.path)) {
                        // Non-HTML, non-asset request: plain unauthorized response.
                        m_authMiddleware->sendAuthError(res, "Authentication required", "AUTH_REQUIRED");
                        return;
                    }

                    // HTML request: serve the login page instead of the requested page.
                    const std::string loginPagePath = m_uiDir + "/login.html";
                    std::string loginContent;
                    if (std::filesystem::exists(loginPagePath) &&
                        server_detail::readWholeFile(loginPagePath, loginContent)) {
                        res.set_content(loginContent, "text/html");
                        return;
                    }

                    // login.html is missing: serve a simple built-in page.
                    std::string simpleLoginPage = R"( Login Required

Login Required

Please enter your username to continue.

)";
                    res.set_content(simpleLoginPage, "text/html");
                    return;
                }
            }
        }

        // Auth is disabled, the user is authenticated, or this is a static asset.
        std::string filePath = req.path.substr(3);  // strip the "/ui" prefix
        if (filePath.empty() || filePath == "/") {
            filePath = "/index.html";
        }

        // Refuse anything that would escape m_uiDir (e.g. "/ui/../../etc/passwd").
        if (!server_detail::isSafeUiRelativePath(filePath)) {
            res.status = 404;
            res.set_content("File not found", "text/plain");
            return;
        }

        const std::string fullPath = m_uiDir + filePath;
        std::string content;
        if (std::filesystem::exists(fullPath) && std::filesystem::is_regular_file(fullPath)) {
            if (server_detail::readWholeFile(fullPath, content)) {
                res.set_content(content, server_detail::uiContentTypeFor(filePath));
            } else {
                res.status = 404;
                res.set_content("File not found", "text/plain");
            }
            return;
        }

        // SPA routing: unknown paths get index.html so Next.js can route client-side.
        const std::string indexPath = m_uiDir + "/index.html";
        if (std::filesystem::exists(indexPath) && server_detail::readWholeFile(indexPath, content)) {
            res.set_content(content, "text/html");
            return;
        }
        res.status = 404;
        res.set_content("File not found", "text/plain");
    };

    m_httpServer->Get("/ui/.*", uiHandler);

    // Redirect /ui to /ui/ to ensure proper routing.
    m_httpServer->Get("/ui", [](const Request& /*req*/, Response& res) { res.set_redirect("/ui/"); });
}

// Installs the user manager and auth middleware used by the auth endpoints and withAuth.
void Server::setAuthComponents(std::shared_ptr<UserManager> userManager,
                               std::shared_ptr<AuthMiddleware> authMiddleware) {
    m_userManager = std::move(userManager);
    m_authMiddleware = std::move(authMiddleware);
}

// Registers the (unauthenticated) /api/auth/* endpoints.
void Server::registerAuthEndpoints() {
    using Request = httplib::Request;
    using Response = httplib::Response;

    m_httpServer->Post("/api/auth/login", [this](const Request& req, Response& res) { handleLogin(req, res); });
    m_httpServer->Post("/api/auth/logout", [this](const Request& req, Response& res) { handleLogout(req, res); });
    m_httpServer->Get("/api/auth/validate", [this](const Request& req, Response& res) { handleValidateToken(req, res); });
    m_httpServer->Post("/api/auth/refresh", [this](const Request& req, Response& res) { handleRefreshToken(req, res); });
    m_httpServer->Get("/api/auth/me", [this](const Request& req, Response& res) { handleGetCurrentUser(req, res); });
}

// POST /api/auth/login. Body: {"username": string, "password": string}.
// SECURITY NOTE: the tokens issued here are unsigned placeholders ("token_<ts>_<user>" /
// "unix_token_<ts>_<user>") and can be forged by anyone who knows a username; replace them
// with signed tokens (e.g. JWT) before exposing this server to untrusted networks.
void Server::handleLogin(const httplib::Request& req, httplib::Response& res) {
    std::string requestId = generateRequestId();

    try {
        if (!m_userManager || !m_authMiddleware) {
            sendErrorResponse(res, "Authentication system not available", 500, "AUTH_UNAVAILABLE", requestId);
            return;
        }

        nlohmann::json requestJson;
        try {
            requestJson = nlohmann::json::parse(req.body);
        } catch (const nlohmann::json::parse_error& e) {
            sendErrorResponse(res, std::string("Invalid JSON: ") + e.what(), 400, "JSON_PARSE_ERROR", requestId);
            return;
        }

        // Reads a string field. Missing, non-string, or non-object input yields false and leaves
        // `out` unchanged, so bad input becomes a 400 rather than a json type_error (500).
        auto readStringField = [&requestJson](const char* key, std::string& out) {
            const auto it = requestJson.find(key);
            if (it == requestJson.end() || !it->is_string()) {
                return false;
            }
            out = it->get<std::string>();
            return true;
        };

        auto makeLoginResponse = [](const auto& result, const std::string& token, const char* message) {
            return nlohmann::json{{"token", token},
                                  {"user",
                                   {{"id", result.userId},
                                    {"username", result.username},
                                    {"role", result.role},
                                    {"permissions", result.permissions}}},
                                  {"message", message}};
        };

        const auto authMethod = m_authMiddleware->getConfig().authMethod;

        if (authMethod == AuthMethod::UNIX) {
            std::string username;
            std::string password;
            readStringField("username", username);
            readStringField("password", password);

            if (username.empty()) {
                sendErrorResponse(res, "Missing username", 400, "MISSING_USERNAME", requestId);
                return;
            }

            // With PAM enabled a password is mandatory.
            if (m_userManager->isPamAuthEnabled() && password.empty()) {
                sendErrorResponse(res, "Password is required for Unix authentication", 400, "MISSING_PASSWORD", requestId);
                return;
            }

            auto result = m_userManager->authenticateUnix(username, password);
            if (!result.success) {
                sendErrorResponse(res, result.errorMessage, 401, "UNIX_AUTH_FAILED", requestId);
                return;
            }

            const std::string token =
                "unix_token_" + std::to_string(server_detail::secondsSinceEpoch()) + "_" + username;
            sendJsonResponse(res, makeLoginResponse(result, token, "Unix authentication successful"));
            return;
        }

        // Non-Unix auth: both fields are required strings.
        std::string username;
        std::string password;
        if (!readStringField("username", username) || !readStringField("password", password)) {
            sendErrorResponse(res, "Missing username or password", 400, "MISSING_CREDENTIALS", requestId);
            return;
        }

        auto result = m_userManager->authenticateUser(username, password);
        if (!result.success) {
            sendErrorResponse(res, result.errorMessage, 401, "INVALID_CREDENTIALS", requestId);
            return;
        }

        // Only JWT mode issues a token; it is a placeholder, not a signed JWT.
        std::string token;
        if (authMethod == AuthMethod::JWT) {
            token = "token_" + std::to_string(server_detail::secondsSinceEpoch()) + "_" + username;
        }

        sendJsonResponse(res, makeLoginResponse(result, token, "Login successful"));
    } catch (const std::exception& e) {
        sendErrorResponse(res, std::string("Login failed: ") + e.what(), 500, "LOGIN_ERROR", requestId);
    }
}

// POST /api/auth/logout. Currently always succeeds; tokens are not invalidated server-side.
void Server::handleLogout(const httplib::Request& /*req*/, httplib::Response& res) {
    std::string requestId = generateRequestId();
    try {
        nlohmann::json response = {{"message", "Logout successful"}};
        sendJsonResponse(res, response);
    } catch (const std::exception& e) {
        sendErrorResponse(res, std::string("Logout failed: ") + e.what(), 500, "LOGOUT_ERROR", requestId);
    }
}

// GET /api/auth/validate. Expects "Authorization: Bearer <token>" holding a placeholder token
// issued by handleLogin ("token_<ts>_<user>" or "unix_token_<ts>_<user>") and returns the user.
void Server::handleValidateToken(const httplib::Request& req, httplib::Response& res) {
    std::string requestId = generateRequestId();

    try {
        if (!m_userManager || !m_authMiddleware) {
            sendErrorResponse(res, "Authentication system not available", 500, "AUTH_UNAVAILABLE", requestId);
            return;
        }

        const std::string authHeader = req.get_header_value("Authorization");
        if (authHeader.empty()) {
            sendErrorResponse(res, "Missing authorization token", 401, "MISSING_TOKEN", requestId);
            return;
        }
        if (authHeader.rfind("Bearer ", 0) != 0) {
            sendErrorResponse(res, "Invalid authorization header format", 401, "INVALID_HEADER", requestId);
            return;
        }
        const std::string token = authHeader.substr(7);  // drop "Bearer "

        std::string prefix;
        if (token.rfind("token_", 0) == 0) {
            prefix = "token_";
        } else if (token.rfind("unix_token_", 0) == 0) {
            prefix = "unix_token_";
        } else {
            sendErrorResponse(res, "Invalid token", 401, "INVALID_TOKEN", requestId);
            return;
        }

        // The timestamp segment is numeric, so the first '_' after the prefix ends it; the rest is
        // the username, which may itself contain '_' (the old last-'_' split truncated such names).
        const std::size_t separator = token.find('_', prefix.size());
        if (separator == std::string::npos || separator + 1 >= token.size()) {
            sendErrorResponse(res, "Invalid token format", 401, "INVALID_TOKEN", requestId);
            return;
        }
        const std::string username = token.substr(separator + 1);

        auto userInfo = m_userManager->getUserInfoByUsername(username);
        if (userInfo.id.empty()) {
            sendErrorResponse(res, "User not found", 401, "USER_NOT_FOUND", requestId);
            return;
        }

        nlohmann::json response = {{"user",
                                    {{"id", userInfo.id},
                                     {"username", userInfo.username},
                                     {"role", userInfo.role},
                                     {"permissions", userInfo.permissions}}},
                                   {"valid", true}};
        sendJsonResponse(res, response);
    } catch (const std::exception& e) {
        sendErrorResponse(res, std::string("Token validation failed: ") + e.what(), 500, "VALIDATION_ERROR", requestId);
    }
}

// POST /api/auth/refresh. Stub: returns a fresh timestamped value without checking the caller.
// NOTE(review): the returned "new_token_..." is not tied to a user and does not start with
// "token_", so handleValidateToken rejects it.
void Server::handleRefreshToken(const httplib::Request& /*req*/, httplib::Response& res) {
    std::string requestId = generateRequestId();
    try {
        nlohmann::json response = {
            {"token", "new_token_" + std::to_string(server_detail::secondsSinceEpoch())},
            {"message", "Token refreshed successfully"}};
        sendJsonResponse(res, response);
    } catch (const std::exception& e) {
        sendErrorResponse(res, std::string("Token refresh failed: ") + e.what(), 500, "REFRESH_ERROR", requestId);
    }
}

// GET /api/auth/me. Authenticates the request via the middleware and echoes the user context.
void Server::handleGetCurrentUser(const httplib::Request& req, httplib::Response& res) {
    std::string requestId = generateRequestId();

    try {
        if (!m_userManager || !m_authMiddleware) {
            sendErrorResponse(res, "Authentication system not available", 500, "AUTH_UNAVAILABLE", requestId);
            return;
        }

        AuthContext authContext = m_authMiddleware->authenticate(req, res);
        if (!authContext.authenticated) {
            sendErrorResponse(res, "Authentication required", 401, "AUTH_REQUIRED", requestId);
            return;
        }

        nlohmann::json response = {{"user",
                                    {{"id", authContext.userId},
                                     {"username", authContext.username},
                                     {"role", authContext.role},
                                     {"permissions", authContext.permissions}}}};
        sendJsonResponse(res, response);
    } catch (const std::exception& e) {
        sendErrorResponse(res, std::string("Get current user failed: ") + e.what(), 500, "USER_ERROR", requestId);
    }
}

// Installs permissive CORS headers on every response and answers API preflight requests.
void Server::setupCORS() {
    // Post-routing: add each CORS header only if a handler has not already set it.
    m_httpServer->set_post_routing_handler([](const httplib::Request& /*req*/, httplib::Response& res) {
        if (!res.has_header("Access-Control-Allow-Origin")) {
            res.set_header("Access-Control-Allow-Origin", "*");
        }
        if (!res.has_header("Access-Control-Allow-Methods")) {
            res.set_header("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS");
        }
        if (!res.has_header("Access-Control-Allow-Headers")) {
            res.set_header("Access-Control-Allow-Headers", "Content-Type, Authorization");
        }
    });

    // CORS preflight (API endpoints only).
    m_httpServer->Options("/api/.*", [](const httplib::Request&, httplib::Response& res) {
        res.set_header("Access-Control-Allow-Origin", "*");
        res.set_header("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS");
        res.set_header("Access-Control-Allow-Headers", "Content-Type, Authorization");
        res.status = 200;
    });
}

// GET /api/health: liveness probe with a timestamp and the build version.
void Server::handleHealthCheck(const httplib::Request& /*req*/, httplib::Response& res) {
    try {
        nlohmann::json response = {{"status", "healthy"},
                                   {"timestamp", server_detail::secondsSinceEpoch()},
                                   {"version", sd_rest::VERSION_INFO.version_full}};
        sendJsonResponse(res, response);
    } catch (const std::exception& e) {
        sendErrorResponse(res, std::string("Health check failed: ") + e.what(), 500);
    }
}

// GET /api/status: server, generation-queue and model counters; absent components report 0/false.
void Server::handleApiStatus(const httplib::Request& /*req*/, httplib::Response& res) {
    try {
        nlohmann::json response = {
            {"server", {{"running", m_isRunning.load()}, {"host", m_host}, {"port", m_port}}},
            {"generation_queue",
             {{"running", m_generationQueue ? m_generationQueue->isRunning() : false},
              {"queue_size", m_generationQueue ? m_generationQueue->getQueueSize() : 0},
              {"active_generations", m_generationQueue ? m_generationQueue->getActiveGenerations() : 0}}},
            {"models",
             {{"loaded_count", m_modelManager ? m_modelManager->getLoadedModelsCount() : 0},
              {"available_count", m_modelManager ? m_modelManager->getAvailableModelsCount() : 0}}}};
        sendJsonResponse(res, response);
    } catch (const std::exception& e) {
        sendErrorResponse(res, std::string("Status check failed: ") + e.what(), 500);
    }
}

// GET /api/version: build metadata from sd_rest::VERSION_INFO.
void Server::handleVersion(const httplib::Request& /*req*/, httplib::Response& res) {
    try {
        const auto& info = sd_rest::VERSION_INFO;
        nlohmann::json response = {{"version", info.version_full},
                                   {"type", info.version_type},
                                   {"commit", {{"short", info.commit_short}, {"full", info.commit_full}}},
                                   {"branch", info.branch},
                                   {"clean", info.is_clean},
                                   {"build_time", info.build_time}};
        sendJsonResponse(res, response);
    } catch (const std::exception& e) {
        sendErrorResponse(res, std::string("Version check failed: ") + e.what(), 500);
    }
}

// Converts model-detail records into a JSON array. Missing models get "path": null and an empty
// sha256; is_required / is_recommended are emitted only when true.
nlohmann::json Server::modelDetailsToJson(const std::vector<ModelManager::ModelDetails>& modelDetails) {
    nlohmann::json jsonArray = nlohmann::json::array();
    for (const auto& detail : modelDetails) {
        nlohmann::json modelJson = {
            {"name", detail.name}, {"exists", detail.exists}, {"type", detail.type}, {"file_size", detail.file_size}};

        // path/sha256 are set separately to avoid an initializer-list type mismatch.
        if (detail.exists) {
            modelJson["path"] = detail.path;
            modelJson["sha256"] = detail.sha256;
        } else {
            modelJson["path"] = nullptr;
            modelJson["sha256"] = "";
        }

        if (detail.is_required) {
            modelJson["is_required"] = true;
        }
        if (detail.is_recommended) {
            modelJson["is_recommended"] = true;
        }
        jsonArray.push_back(modelJson);
    }
    return jsonArray;
}

// Decides which recommended_* response fields apply to an architecture string. Fields left
// false are later emitted as null by populateRecommendedModels(); unknown architectures get
// all fields false.
std::map<std::string, bool> Server::getRecommendedModelFields(const std::string& architecture) {
    std::map<std::string, bool> recommendedFields = {
        {"recommended_vae", false},   {"recommended_clip_l", false},      {"recommended_clip_g", false},
        {"recommended_t5xxl", false}, {"recommended_clip_vision", false}, {"recommended_qwen2vl", false},
    };

    const auto mentions = [&architecture](const char* needle) {
        return architecture.find(needle) != std::string::npos;
    };
    const auto enable = [&recommendedFields](std::initializer_list<const char*> fields) {
        for (const char* field : fields) {
            recommendedFields[field] = true;
        }
    };

    // First matching family wins; the order matches the original if/else-if chain.
    if (mentions("Stable Diffusion 1.5") || mentions("Stable Diffusion XL")) {
        enable({"recommended_vae"});
    } else if (mentions("Modern Architecture") || mentions("Flux Dev") || mentions("Flux Chroma")) {
        enable({"recommended_vae", "recommended_clip_l", "recommended_t5xxl"});
    } else if (mentions("SD 3")) {
        enable({"recommended_vae", "recommended_clip_l", "recommended_clip_g", "recommended_t5xxl"});
    } else if (mentions("Wan")) {
        enable({"recommended_vae", "recommended_t5xxl", "recommended_clip_vision"});
    } else if (mentions("Qwen")) {
        enable({"recommended_vae", "recommended_qwen2vl"});
    }

    return recommendedFields;
}

// Adds recommended_* arrays to `response` for the model's required companion models, grouped
// by type. Fields that do not apply to the architecture, or have no models, are set to null.
// Does nothing when the model lists no required models.
void Server::populateRecommendedModels(nlohmann::json& response, const ModelManager::ModelInfo& modelInfo) {
    if (modelInfo.requiredModels.empty()) {
        return;
    }

    auto requiredModelsDetails = m_modelManager->checkRequiredModelsExistence(modelInfo.requiredModels);
    auto recommendedFields = getRecommendedModelFields(modelInfo.architecture);

    // Group the detail records by model type.
    std::map<std::string, decltype(requiredModelsDetails)> modelsByType;
    for (const auto& detail : requiredModelsDetails) {
        modelsByType[detail.type].push_back(detail);
    }

    // Model type -> response field that lists models of that type.
    static const std::map<std::string, std::string> kFieldForType = {
        {"VAE", "recommended_vae"},         {"CLIP-L", "recommended_clip_l"},
        {"CLIP-G", "recommended_clip_g"},   {"T5XXL", "recommended_t5xxl"},
        {"CLIP-Vision", "recommended_clip_vision"}, {"Qwen2VL", "recommended_qwen2vl"},
    };
    for (const auto& [type, models] : modelsByType) {
        const auto fieldIt = kFieldForType.find(type);
        if (fieldIt != kFieldForType.end() && recommendedFields[fieldIt->second]) {
            response[fieldIt->second] = modelDetailsToJson(models);
        }
    }

    // Non-applicable or empty fields are explicitly null.
    for (const auto& [fieldName, shouldInclude] : recommendedFields) {
        if (!shouldInclude || !response.contains(fieldName)) {
            response[fieldName] = nlohmann::json(nullptr);
        }
    }
}

void Server::handleModelsList(const httplib::Request& req, httplib::Response& res) {
    std::string requestId = generateRequestId();
    try {
        if (!m_modelManager) {
            sendErrorResponse(res, "Model manager not available", 500, "MODEL_MANAGER_UNAVAILABLE", requestId);
            return;
        }
        // Parse query parameters for enhanced filtering
        std::string typeFilter = req.get_param_value("type");
        std::string searchQuery = req.get_param_value("search");
        std::string sortBy = req.get_param_value("sort_by");
        std::string sortOrder = req.get_param_value("sort_order");
        std::string dateFilter = req.get_param_value("date");
        std::string sizeFilter = req.get_param_value("size");
        // Pagination parameters - only apply if limit is explicitly provided
        int page = 1;
        int limit = 50;
        bool usePagination = false;
        try {
            if (!req.get_param_value("limit").empty()) {
                limit = std::stoi(req.get_param_value("limit"));
                // Special case: limit<=0 means return all models (no pagination)
                if (limit <= 0) {
                    usePagination = false;
                    limit = INT_MAX; // Set to very large number to effectively disable pagination
                } else {
                    usePagination = true;
                    if (!req.get_param_value("page").empty()) {
                        page = std::stoi(req.get_param_value("page"));
                        if (page < 1) page = 1;
                    }
                }
            }
        } catch (const std::exception& e) {
            sendErrorResponse(res, "Invalid pagination parameters", 400, "INVALID_PAGINATION", requestId);
            return;
        }
        // Filter parameters
        bool includeLoaded = req.get_param_value("loaded") == "true";
        bool includeUnloaded = req.get_param_value("unloaded") == "true";
        (void)req.get_param_value("include_metadata"); // unused but kept for API compatibility
        (void)req.get_param_value("include_thumbnails"); // unused but kept for API compatibility
        // Get all models
        auto allModels = m_modelManager->getAllModels();
        nlohmann::json models =
nlohmann::json::array(); // Apply filters and build response for (const auto& pair : allModels) { const auto& modelInfo = pair.second; // Apply type filter if (!typeFilter.empty()) { ModelType filterType = ModelManager::stringToModelType(typeFilter); if (modelInfo.type != filterType) continue; } // Apply loaded/unloaded filters if (includeLoaded && !modelInfo.isLoaded) continue; if (includeUnloaded && modelInfo.isLoaded) continue; // Apply search filter (case-insensitive search in name and description) if (!searchQuery.empty()) { std::string searchLower = searchQuery; std::transform(searchLower.begin(), searchLower.end(), searchLower.begin(), ::tolower); std::string nameLower = modelInfo.name; std::transform(nameLower.begin(), nameLower.end(), nameLower.begin(), ::tolower); std::string descLower = modelInfo.description; std::transform(descLower.begin(), descLower.end(), descLower.begin(), ::tolower); if (nameLower.find(searchLower) == std::string::npos && descLower.find(searchLower) == std::string::npos) { continue; } } // Apply date filter (simplified - expects "recent", "old", or YYYY-MM-DD) if (!dateFilter.empty()) { auto now = std::filesystem::file_time_type::clock::now(); auto modelTime = modelInfo.modifiedAt; auto duration = std::chrono::duration_cast(now - modelTime).count(); if (dateFilter == "recent" && duration > 24 * 7) continue; // Older than 1 week if (dateFilter == "old" && duration < 24 * 30) continue; // Newer than 1 month } // Apply size filter (expects "small", "medium", "large", or size in MB) if (!sizeFilter.empty()) { double sizeMB = modelInfo.fileSize / (1024.0 * 1024.0); if (sizeFilter == "small" && sizeMB > 1024) continue; // > 1GB if (sizeFilter == "medium" && (sizeMB < 1024 || sizeMB > 4096)) continue; // < 1GB or > 4GB if (sizeFilter == "large" && sizeMB < 4096) continue; // < 4GB // Try to parse as specific size in MB try { double maxSizeMB = std::stod(sizeFilter); if (sizeMB > maxSizeMB) continue; } catch (...) 
{ // Ignore if parsing fails } } // Build model JSON with enhanced structure nlohmann::json modelJson = { {"name", modelInfo.name}, {"type", ModelManager::modelTypeToString(modelInfo.type)}, {"file_size", modelInfo.fileSize}, {"file_size_mb", modelInfo.fileSize / (1024.0 * 1024.0)}, {"sha256", modelInfo.sha256.empty() ? nullptr : nlohmann::json(modelInfo.sha256)}, {"sha256_short", (modelInfo.sha256.empty() || modelInfo.sha256.length() < 10) ? nullptr : nlohmann::json(modelInfo.sha256.substr(0, 10))} }; // Add architecture information if available (checkpoints only) if (!modelInfo.architecture.empty()) { modelJson["architecture"] = modelInfo.architecture; modelJson["recommended_vae"] = modelInfo.recommendedVAE.empty() ? nullptr : nlohmann::json(modelInfo.recommendedVAE); if (modelInfo.recommendedWidth > 0) { modelJson["recommended_width"] = modelInfo.recommendedWidth; } if (modelInfo.recommendedHeight > 0) { modelJson["recommended_height"] = modelInfo.recommendedHeight; } if (modelInfo.recommendedSteps > 0) { modelJson["recommended_steps"] = modelInfo.recommendedSteps; } if (!modelInfo.recommendedSampler.empty()) { modelJson["recommended_sampler"] = modelInfo.recommendedSampler; } // Enhanced model information with existence checking if (!modelInfo.requiredModels.empty()) { auto requiredModelsDetails = m_modelManager->checkRequiredModelsExistence(modelInfo.requiredModels); modelJson["required_models"] = modelDetailsToJson(requiredModelsDetails); // Populate recommended models based on architecture populateRecommendedModels(modelJson, modelInfo); } // Backward compatibility - keep existing fields if (!modelInfo.missingModels.empty()) { modelJson["missing_models"] = modelInfo.missingModels; modelJson["has_missing_dependencies"] = true; } else { modelJson["has_missing_dependencies"] = false; } } models.push_back(modelJson); } // Apply sorting if (!sortBy.empty()) { std::sort(models.begin(), models.end(), [&sortBy, &sortOrder](const nlohmann::json& a, const 
nlohmann::json& b) { bool ascending = sortOrder != "desc"; if (sortBy == "name") { return ascending ? a["name"] < b["name"] : a["name"] > b["name"]; } else if (sortBy == "size") { return ascending ? a["file_size"] < b["file_size"] : a["file_size"] > b["file_size"]; } else if (sortBy == "date") { return ascending ? a["last_modified"] < b["last_modified"] : a["last_modified"] > b["last_modified"]; } else if (sortBy == "type") { return ascending ? a["type"] < b["type"] : a["type"] > b["type"]; } else if (sortBy == "loaded") { return ascending ? a["is_loaded"] < b["is_loaded"] : a["is_loaded"] > b["is_loaded"]; } return false; }); } // Apply pagination only if limit parameter was provided int totalCount = models.size(); nlohmann::json paginatedModels = nlohmann::json::array(); nlohmann::json paginationInfo = nlohmann::json::object(); if (usePagination) { // Apply pagination int totalPages = (totalCount + limit - 1) / limit; int startIndex = (page - 1) * limit; int endIndex = std::min(startIndex + limit, totalCount); for (int i = startIndex; i < endIndex; ++i) { paginatedModels.push_back(models[i]); } paginationInfo = { {"page", page}, {"limit", limit}, {"total_count", totalCount}, {"total_pages", totalPages}, {"has_next", page < totalPages}, {"has_prev", page > 1} }; } else { // Return all models without pagination paginatedModels = models; paginationInfo = { {"page", 1}, {"limit", totalCount}, {"total_count", totalCount}, {"total_pages", 1}, {"has_next", false}, {"has_prev", false} }; } // Build comprehensive response nlohmann::json response = { {"models", paginatedModels}, {"pagination", paginationInfo}, {"filters_applied", { {"type", typeFilter.empty() ? nlohmann::json(nullptr) : nlohmann::json(typeFilter)}, {"search", searchQuery.empty() ? nlohmann::json(nullptr) : nlohmann::json(searchQuery)}, {"date", dateFilter.empty() ? nlohmann::json(nullptr) : nlohmann::json(dateFilter)}, {"size", sizeFilter.empty() ? 
nlohmann::json(nullptr) : nlohmann::json(sizeFilter)}, {"loaded", includeLoaded ? nlohmann::json(true) : nlohmann::json(nullptr)}, {"unloaded", includeUnloaded ? nlohmann::json(true) : nlohmann::json(nullptr)} }}, {"sorting", { {"sort_by", sortBy.empty() ? "name" : nlohmann::json(sortBy)}, {"sort_order", sortOrder.empty() ? "asc" : nlohmann::json(sortOrder)} }}, {"statistics", { {"loaded_count", m_modelManager->getLoadedModelsCount()}, {"available_count", m_modelManager->getAvailableModelsCount()} }}, {"request_id", requestId} }; sendJsonResponse(res, response); } catch (const std::exception& e) { sendErrorResponse(res, std::string("Failed to list models: ") + e.what(), 500, "MODEL_LIST_ERROR", requestId); } } void Server::handleQueueStatus(const httplib::Request& /*req*/, httplib::Response& res) { try { if (!m_generationQueue) { sendErrorResponse(res, "Generation queue not available", 500); return; } // Get detailed queue status auto jobs = m_generationQueue->getQueueStatus(); // Convert jobs to JSON nlohmann::json jobsJson = nlohmann::json::array(); for (const auto& job : jobs) { std::string statusStr; switch (job.status) { case GenerationStatus::QUEUED: statusStr = "queued"; break; case GenerationStatus::PROCESSING: statusStr = "processing"; break; case GenerationStatus::COMPLETED: statusStr = "completed"; break; case GenerationStatus::FAILED: statusStr = "failed"; break; } // Convert time points to timestamps auto queuedTime = std::chrono::duration_cast( job.queuedTime.time_since_epoch()).count(); auto startTime = std::chrono::duration_cast( job.startTime.time_since_epoch()).count(); auto endTime = std::chrono::duration_cast( job.endTime.time_since_epoch()).count(); jobsJson.push_back({ {"id", job.id}, {"status", statusStr}, {"prompt", job.prompt}, {"queued_time", queuedTime}, {"start_time", startTime > 0 ? nlohmann::json(startTime) : nlohmann::json(nullptr)}, {"end_time", endTime > 0 ? 
nlohmann::json(endTime) : nlohmann::json(nullptr)}, {"position", job.position}, {"progress", job.progress} }); } nlohmann::json response = { {"queue", { {"size", m_generationQueue->getQueueSize()}, {"active_generations", m_generationQueue->getActiveGenerations()}, {"running", m_generationQueue->isRunning()}, {"jobs", jobsJson} }} }; sendJsonResponse(res, response); } catch (const std::exception& e) { sendErrorResponse(res, std::string("Queue status check failed: ") + e.what(), 500); } } void Server::handleJobStatus(const httplib::Request& req, httplib::Response& res) { try { if (!m_generationQueue) { sendErrorResponse(res, "Generation queue not available", 500); return; } // Extract job ID from URL path std::string jobId = req.matches[1].str(); if (jobId.empty()) { sendErrorResponse(res, "Missing job ID", 400); return; } // Get job information auto jobInfo = m_generationQueue->getJobInfo(jobId); if (jobInfo.id.empty()) { sendErrorResponse(res, "Job not found", 404); return; } // Convert status to string std::string statusStr; switch (jobInfo.status) { case GenerationStatus::QUEUED: statusStr = "queued"; break; case GenerationStatus::PROCESSING: statusStr = "processing"; break; case GenerationStatus::COMPLETED: statusStr = "completed"; break; case GenerationStatus::FAILED: statusStr = "failed"; break; } // Convert time points to timestamps auto queuedTime = std::chrono::duration_cast( jobInfo.queuedTime.time_since_epoch()).count(); auto startTime = std::chrono::duration_cast( jobInfo.startTime.time_since_epoch()).count(); auto endTime = std::chrono::duration_cast( jobInfo.endTime.time_since_epoch()).count(); // Create download URLs for output files nlohmann::json outputUrls = nlohmann::json::array(); for (const auto& filePath : jobInfo.outputFiles) { // Extract filename from full path std::filesystem::path p(filePath); std::string filename = p.filename().string(); // Create download URL std::string url = "/api/queue/job/" + jobInfo.id + "/output/" + filename; 
nlohmann::json fileInfo = { {"filename", filename}, {"url", url}, {"path", filePath} }; outputUrls.push_back(fileInfo); } nlohmann::json response = { {"job", { {"id", jobInfo.id}, {"status", statusStr}, {"prompt", jobInfo.prompt}, {"queued_time", queuedTime}, {"start_time", startTime > 0 ? nlohmann::json(startTime) : nlohmann::json(nullptr)}, {"end_time", endTime > 0 ? nlohmann::json(endTime) : nlohmann::json(nullptr)}, {"position", jobInfo.position}, {"outputs", outputUrls}, {"error_message", jobInfo.errorMessage}, {"progress", jobInfo.progress} }} }; sendJsonResponse(res, response); } catch (const std::exception& e) { sendErrorResponse(res, std::string("Job status check failed: ") + e.what(), 500); } } void Server::handleCancelJob(const httplib::Request& req, httplib::Response& res) { try { if (!m_generationQueue) { sendErrorResponse(res, "Generation queue not available", 500); return; } // Parse JSON request body nlohmann::json requestJson = nlohmann::json::parse(req.body); // Validate required fields if (!requestJson.contains("job_id") || !requestJson["job_id"].is_string()) { sendErrorResponse(res, "Missing or invalid 'job_id' field", 400); return; } std::string jobId = requestJson["job_id"]; // Try to cancel the job bool cancelled = m_generationQueue->cancelJob(jobId); if (cancelled) { nlohmann::json response = { {"status", "success"}, {"message", "Job cancelled successfully"}, {"job_id", jobId} }; sendJsonResponse(res, response); } else { nlohmann::json response = { {"status", "error"}, {"message", "Job not found or already processing"}, {"job_id", jobId} }; sendJsonResponse(res, response, 404); } } catch (const nlohmann::json::parse_error& e) { sendErrorResponse(res, std::string("Invalid JSON: ") + e.what(), 400); } catch (const std::exception& e) { sendErrorResponse(res, std::string("Job cancellation failed: ") + e.what(), 500); } } void Server::handleClearQueue(const httplib::Request& /*req*/, httplib::Response& res) { try { if (!m_generationQueue) { 
sendErrorResponse(res, "Generation queue not available", 500); return; } // Clear the queue m_generationQueue->clearQueue(); nlohmann::json response = { {"status", "success"}, {"message", "Queue cleared successfully"} }; sendJsonResponse(res, response); } catch (const std::exception& e) { sendErrorResponse(res, std::string("Queue clear failed: ") + e.what(), 500); } } void Server::handleDownloadOutput(const httplib::Request& req, httplib::Response& res) { try { // Extract job ID and filename from URL path if (req.matches.size() < 3) { sendErrorResponse(res, "Invalid request: job ID and filename required", 400, "INVALID_REQUEST", ""); return; } std::string jobId = req.matches[1]; std::string filename = req.matches[2]; // Validate inputs if (jobId.empty() || filename.empty()) { sendErrorResponse(res, "Job ID and filename cannot be empty", 400, "INVALID_PARAMETERS", ""); return; } // Construct absolute file path using the same logic as when saving: // {outputDir}/{jobId}/{filename} std::string fullPath = std::filesystem::absolute(m_outputDir + "/" + jobId + "/" + filename).string(); // Log the request for debugging std::cout << "Image download request: jobId=" << jobId << ", filename=" << filename << ", fullPath=" << fullPath << std::endl; // Check if file exists if (!std::filesystem::exists(fullPath)) { std::cerr << "Output file not found: " << fullPath << std::endl; sendErrorResponse(res, "Output file not found: " + filename, 404, "FILE_NOT_FOUND", ""); return; } // Check file size to detect zero-byte files auto fileSize = std::filesystem::file_size(fullPath); if (fileSize == 0) { std::cerr << "Output file is zero bytes: " << fullPath << std::endl; sendErrorResponse(res, "Output file is empty (corrupted generation)", 500, "EMPTY_FILE", ""); return; } // Check if file is accessible std::ifstream file(fullPath, std::ios::binary); if (!file.is_open()) { std::cerr << "Failed to open output file: " << fullPath << std::endl; sendErrorResponse(res, "Output file not 
accessible", 500, "FILE_ACCESS_ERROR", ""); return; } // Read file contents std::string fileContent; try { fileContent = std::string( std::istreambuf_iterator(file), std::istreambuf_iterator() ); file.close(); } catch (const std::exception& e) { std::cerr << "Failed to read file content: " << e.what() << std::endl; sendErrorResponse(res, "Failed to read file content", 500, "FILE_READ_ERROR", ""); return; } // Verify we actually read data if (fileContent.empty()) { std::cerr << "File content is empty after read: " << fullPath << std::endl; sendErrorResponse(res, "File content is empty after read", 500, "EMPTY_CONTENT", ""); return; } // Determine content type based on file extension std::string contentType = "application/octet-stream"; if (Utils::endsWith(filename, ".png")) { contentType = "image/png"; } else if (Utils::endsWith(filename, ".jpg") || Utils::endsWith(filename, ".jpeg")) { contentType = "image/jpeg"; } else if (Utils::endsWith(filename, ".mp4")) { contentType = "video/mp4"; } else if (Utils::endsWith(filename, ".gif")) { contentType = "image/gif"; } else if (Utils::endsWith(filename, ".webp")) { contentType = "image/webp"; } // Set response headers for proper browser handling res.set_header("Content-Type", contentType); res.set_header("Content-Length", std::to_string(fileContent.length())); res.set_header("Cache-Control", "public, max-age=3600"); // Cache for 1 hour res.set_header("Access-Control-Allow-Origin", "*"); // CORS for image access // Uncomment if you want to force download instead of inline display: // res.set_header("Content-Disposition", "attachment; filename=\"" + filename + "\""); // Set the content res.set_content(fileContent, contentType); res.status = 200; std::cout << "Successfully served image: " << filename << " (" << fileContent.length() << " bytes)" << std::endl; } catch (const std::exception& e) { std::cerr << "Exception in handleDownloadOutput: " << e.what() << std::endl; sendErrorResponse(res, std::string("Failed to download 
file: ") + e.what(), 500, "DOWNLOAD_ERROR", ""); } } void Server::handleJobOutput(const httplib::Request& req, httplib::Response& res) { std::string requestId = generateRequestId(); try { // Extract job ID from URL path if (req.matches.size() < 2) { sendErrorResponse(res, "Invalid request: job ID required", 400, "INVALID_REQUEST", requestId); return; } std::string jobId = req.matches[1].str(); // Validate job ID if (jobId.empty()) { sendErrorResponse(res, "Job ID cannot be empty", 400, "INVALID_PARAMETERS", requestId); return; } // Log the request for debugging std::cout << "Job output request: jobId=" << jobId << std::endl; // Get job information to check if it exists and is completed if (!m_generationQueue) { sendErrorResponse(res, "Generation queue not available", 500, "QUEUE_UNAVAILABLE", requestId); return; } auto jobInfo = m_generationQueue->getJobInfo(jobId); if (jobInfo.id.empty()) { sendErrorResponse(res, "Job not found", 404, "JOB_NOT_FOUND", requestId); return; } // Check if job is completed if (jobInfo.status != GenerationStatus::COMPLETED) { std::string statusStr; switch (jobInfo.status) { case GenerationStatus::QUEUED: statusStr = "queued"; break; case GenerationStatus::PROCESSING: statusStr = "processing"; break; case GenerationStatus::FAILED: statusStr = "failed"; break; default: statusStr = "unknown"; break; } nlohmann::json response = { {"error", { {"message", "Job not completed yet"}, {"status_code", 400}, {"error_code", "JOB_NOT_COMPLETED"}, {"request_id", requestId}, {"timestamp", std::chrono::duration_cast( std::chrono::system_clock::now().time_since_epoch()).count()}, {"job_status", statusStr} }} }; sendJsonResponse(res, response, 400); return; } // Check if job has output files if (jobInfo.outputFiles.empty()) { sendErrorResponse(res, "No output files found for completed job", 404, "NO_OUTPUT_FILES", requestId); return; } // For simplicity, return the first output file // In a more complex implementation, we could return all files or allow 
file selection std::string firstOutputFile = jobInfo.outputFiles[0]; // Extract filename from full path std::filesystem::path filePath(firstOutputFile); std::string filename = filePath.filename().string(); // Construct absolute file path std::string fullPath = std::filesystem::absolute(firstOutputFile).string(); // Check if file exists if (!std::filesystem::exists(fullPath)) { std::cerr << "Output file not found: " << fullPath << std::endl; sendErrorResponse(res, "Output file not found: " + filename, 404, "FILE_NOT_FOUND", requestId); return; } // Check file size to detect zero-byte files auto fileSize = std::filesystem::file_size(fullPath); if (fileSize == 0) { std::cerr << "Output file is zero bytes: " << fullPath << std::endl; sendErrorResponse(res, "Output file is empty (corrupted generation)", 500, "EMPTY_FILE", requestId); return; } // Check if file is accessible std::ifstream file(fullPath, std::ios::binary); if (!file.is_open()) { std::cerr << "Failed to open output file: " << fullPath << std::endl; sendErrorResponse(res, "Output file not accessible", 500, "FILE_ACCESS_ERROR", requestId); return; } // Read file contents std::string fileContent; try { fileContent = std::string( std::istreambuf_iterator(file), std::istreambuf_iterator() ); file.close(); } catch (const std::exception& e) { std::cerr << "Failed to read file content: " << e.what() << std::endl; sendErrorResponse(res, "Failed to read file content", 500, "FILE_READ_ERROR", requestId); return; } // Verify we actually read data if (fileContent.empty()) { std::cerr << "File content is empty after read: " << fullPath << std::endl; sendErrorResponse(res, "File content is empty after read", 500, "EMPTY_CONTENT", requestId); return; } // Determine content type based on file extension std::string contentType = "application/octet-stream"; if (Utils::endsWith(filename, ".png")) { contentType = "image/png"; } else if (Utils::endsWith(filename, ".jpg") || Utils::endsWith(filename, ".jpeg")) { contentType = 
"image/jpeg"; } else if (Utils::endsWith(filename, ".mp4")) { contentType = "video/mp4"; } else if (Utils::endsWith(filename, ".gif")) { contentType = "image/gif"; } else if (Utils::endsWith(filename, ".webp")) { contentType = "image/webp"; } // Set response headers for proper browser handling res.set_header("Content-Type", contentType); res.set_header("Content-Length", std::to_string(fileContent.length())); res.set_header("Cache-Control", "public, max-age=3600"); // Cache for 1 hour res.set_header("Access-Control-Allow-Origin", "*"); // CORS for image access // Set additional metadata headers res.set_header("X-Job-ID", jobId); res.set_header("X-Filename", filename); res.set_header("X-File-Size", std::to_string(fileSize)); // If there are multiple files, indicate this if (jobInfo.outputFiles.size() > 1) { res.set_header("X-Total-Files", std::to_string(jobInfo.outputFiles.size())); res.set_header("X-File-Index", "1"); } // Set the content res.set_content(fileContent, contentType); res.status = 200; std::cout << "Successfully served job output: jobId=" << jobId << ", filename=" << filename << " (" << fileContent.length() << " bytes)" << std::endl; } catch (const std::exception& e) { std::cerr << "Exception in handleJobOutput: " << e.what() << std::endl; sendErrorResponse(res, std::string("Failed to get job output: ") + e.what(), 500, "OUTPUT_ERROR", requestId); } } void Server::handleImageResize(const httplib::Request& req, httplib::Response& res) { std::string requestId = generateRequestId(); try { // Parse JSON request body nlohmann::json requestJson = nlohmann::json::parse(req.body); // Validate required fields if (!requestJson.contains("image") || !requestJson["image"].is_string()) { sendErrorResponse(res, "Missing or invalid 'image' field", 400, "INVALID_PARAMETERS", requestId); return; } if (!requestJson.contains("width") || !requestJson["width"].is_number_integer()) { sendErrorResponse(res, "Missing or invalid 'width' field", 400, "INVALID_PARAMETERS", 
requestId); return; } if (!requestJson.contains("height") || !requestJson["height"].is_number_integer()) { sendErrorResponse(res, "Missing or invalid 'height' field", 400, "INVALID_PARAMETERS", requestId); return; } std::string imageInput = requestJson["image"]; int targetWidth = requestJson["width"]; int targetHeight = requestJson["height"]; // Validate dimensions if (targetWidth < 1 || targetWidth > 4096) { sendErrorResponse(res, "Width must be between 1 and 4096", 400, "INVALID_DIMENSIONS", requestId); return; } if (targetHeight < 1 || targetHeight > 4096) { sendErrorResponse(res, "Height must be between 1 and 4096", 400, "INVALID_DIMENSIONS", requestId); return; } // Load the source image auto [imageData, sourceWidth, sourceHeight, sourceChannels, success, loadError] = loadImageFromInput(imageInput); if (!success) { sendErrorResponse(res, "Failed to load image: " + loadError, 400, "IMAGE_LOAD_ERROR", requestId); return; } // Convert image data to stb_image format for processing int channels = 3; // Force RGB size_t sourceSize = sourceWidth * sourceHeight * channels; std::vector sourcePixels(sourceSize); std::memcpy(sourcePixels.data(), imageData.data(), std::min(imageData.size(), sourceSize)); // Resize the image using stb_image_resize if available, otherwise use simple scaling std::vector resizedPixels(targetWidth * targetHeight * channels); // Simple nearest-neighbor scaling for now (can be improved with better algorithms) float xScale = static_cast(sourceWidth) / targetWidth; float yScale = static_cast(sourceHeight) / targetHeight; for (int y = 0; y < targetHeight; y++) { for (int x = 0; x < targetWidth; x++) { int sourceX = static_cast(x * xScale); int sourceY = static_cast(y * yScale); // Clamp to source bounds sourceX = std::min(sourceX, sourceWidth - 1); sourceY = std::min(sourceY, sourceHeight - 1); for (int c = 0; c < channels; c++) { resizedPixels[(y * targetWidth + x) * channels + c] = sourcePixels[(sourceY * sourceWidth + sourceX) * channels + c]; } 
} } // Convert resized image to base64 std::string base64Data = Utils::base64Encode(resizedPixels); // Determine MIME type based on input std::string mimeType = "image/jpeg"; // default if (Utils::startsWith(imageInput, "data:image/png")) { mimeType = "image/png"; } else if (Utils::startsWith(imageInput, "data:image/gif")) { mimeType = "image/gif"; } else if (Utils::startsWith(imageInput, "data:image/webp")) { mimeType = "image/webp"; } else if (Utils::startsWith(imageInput, "data:image/bmp")) { mimeType = "image/bmp"; } // Create data URL format std::string dataUrl = "data:" + mimeType + ";base64," + base64Data; // Build response nlohmann::json response = { {"success", true}, {"original_width", sourceWidth}, {"original_height", sourceHeight}, {"resized_width", targetWidth}, {"resized_height", targetHeight}, {"mime_type", mimeType}, {"base64_data", dataUrl}, {"file_size_bytes", resizedPixels.size()}, {"request_id", requestId} }; sendJsonResponse(res, response, 200); std::cout << "Successfully resized image from " << sourceWidth << "x" << sourceHeight << " to " << targetWidth << "x" << targetHeight << " (" << resizedPixels.size() << " bytes)" << std::endl; } catch (const nlohmann::json::parse_error& e) { sendErrorResponse(res, std::string("Invalid JSON: ") + e.what(), 400, "JSON_PARSE_ERROR", requestId); } catch (const std::exception& e) { std::cerr << "Exception in handleImageResize: " << e.what() << std::endl; sendErrorResponse(res, std::string("Failed to resize image: ") + e.what(), 500, "INTERNAL_ERROR", requestId); } } void Server::handleImageCrop(const httplib::Request& req, httplib::Response& res) { std::string requestId = generateRequestId(); try { // Parse JSON request body nlohmann::json requestJson = nlohmann::json::parse(req.body); // Validate required fields if (!requestJson.contains("image") || !requestJson["image"].is_string()) { sendErrorResponse(res, "Missing or invalid 'image' field", 400, "INVALID_PARAMETERS", requestId); return; } if 
(!requestJson.contains("x") || !requestJson["x"].is_number_integer()) { sendErrorResponse(res, "Missing or invalid 'x' field", 400, "INVALID_PARAMETERS", requestId); return; } if (!requestJson.contains("y") || !requestJson["y"].is_number_integer()) { sendErrorResponse(res, "Missing or invalid 'y' field", 400, "INVALID_PARAMETERS", requestId); return; } if (!requestJson.contains("width") || !requestJson["width"].is_number_integer()) { sendErrorResponse(res, "Missing or invalid 'width' field", 400, "INVALID_PARAMETERS", requestId); return; } if (!requestJson.contains("height") || !requestJson["height"].is_number_integer()) { sendErrorResponse(res, "Missing or invalid 'height' field", 400, "INVALID_PARAMETERS", requestId); return; } std::string imageInput = requestJson["image"]; int cropX = requestJson["x"]; int cropY = requestJson["y"]; int cropWidth = requestJson["width"]; int cropHeight = requestJson["height"]; // Load the source image auto [imageData, sourceWidth, sourceHeight, sourceChannels, success, loadError] = loadImageFromInput(imageInput); if (!success) { sendErrorResponse(res, "Failed to load image: " + loadError, 400, "IMAGE_LOAD_ERROR", requestId); return; } // Validate crop dimensions if (cropX < 0 || cropY < 0) { sendErrorResponse(res, "Crop coordinates must be non-negative", 400, "INVALID_CROP_AREA", requestId); return; } if (cropX + cropWidth > sourceWidth || cropY + cropHeight > sourceHeight) { sendErrorResponse(res, "Crop area exceeds image dimensions", 400, "INVALID_CROP_AREA", requestId); return; } if (cropWidth < 1 || cropHeight < 1) { sendErrorResponse(res, "Crop width and height must be at least 1", 400, "INVALID_CROP_AREA", requestId); return; } // Convert image data to stb_image format for processing int channels = 3; // Force RGB size_t sourceSize = sourceWidth * sourceHeight * channels; std::vector sourcePixels(sourceSize); std::memcpy(sourcePixels.data(), imageData.data(), std::min(imageData.size(), sourceSize)); // Crop the image 
std::vector croppedPixels(cropWidth * cropHeight * channels); for (int y = 0; y < cropHeight; y++) { for (int x = 0; x < cropWidth; x++) { int sourceX = cropX + x; int sourceY = cropY + y; for (int c = 0; c < channels; c++) { croppedPixels[(y * cropWidth + x) * channels + c] = sourcePixels[(sourceY * sourceWidth + sourceX) * channels + c]; } } } // Convert cropped image to base64 std::string base64Data = Utils::base64Encode(croppedPixels); // Determine MIME type based on input std::string mimeType = "image/jpeg"; // default if (Utils::startsWith(imageInput, "data:image/png")) { mimeType = "image/png"; } else if (Utils::startsWith(imageInput, "data:image/gif")) { mimeType = "image/gif"; } else if (Utils::startsWith(imageInput, "data:image/webp")) { mimeType = "image/webp"; } else if (Utils::startsWith(imageInput, "data:image/bmp")) { mimeType = "image/bmp"; } // Create data URL format std::string dataUrl = "data:" + mimeType + ";base64," + base64Data; // Build response nlohmann::json response = { {"success", true}, {"original_width", sourceWidth}, {"original_height", sourceHeight}, {"crop_x", cropX}, {"crop_y", cropY}, {"cropped_width", cropWidth}, {"cropped_height", cropHeight}, {"mime_type", mimeType}, {"base64_data", dataUrl}, {"file_size_bytes", croppedPixels.size()}, {"request_id", requestId} }; sendJsonResponse(res, response, 200); std::cout << "Successfully cropped image from " << sourceWidth << "x" << sourceHeight << " to " << cropWidth << "x" << cropHeight << " at (" << cropX << "," << cropY << ")" << " (" << croppedPixels.size() << " bytes)" << std::endl; } catch (const nlohmann::json::parse_error& e) { sendErrorResponse(res, std::string("Invalid JSON: ") + e.what(), 400, "JSON_PARSE_ERROR", requestId); } catch (const std::exception& e) { std::cerr << "Exception in handleImageCrop: " << e.what() << std::endl; sendErrorResponse(res, std::string("Failed to crop image: ") + e.what(), 500, "INTERNAL_ERROR", requestId); } } void 
Server::handleDownloadImageFromUrl(const httplib::Request& req, httplib::Response& res) { std::string requestId = generateRequestId(); try { // Parse query parameters std::string imageUrl = req.get_param_value("url"); if (imageUrl.empty()) { sendErrorResponse(res, "Missing 'url' parameter", 400, "MISSING_URL", requestId); return; } // Basic URL format validation if (!Utils::startsWith(imageUrl, "http://") && !Utils::startsWith(imageUrl, "https://")) { sendErrorResponse(res, "Invalid URL format. URL must start with http:// or https://", 400, "INVALID_URL_FORMAT", requestId); return; } // Extract filename from URL for content type detection std::string filename = imageUrl; size_t lastSlash = imageUrl.find_last_of('/'); if (lastSlash != std::string::npos) { filename = imageUrl.substr(lastSlash + 1); } // Remove query parameters and fragments size_t questionMark = filename.find('?'); if (questionMark != std::string::npos) { filename = filename.substr(0, questionMark); } size_t hashMark = filename.find('#'); if (hashMark != std::string::npos) { filename = filename.substr(0, hashMark); } // Check if URL has image extension std::string extension; size_t lastDot = filename.find_last_of('.'); if (lastDot != std::string::npos) { extension = filename.substr(lastDot + 1); std::transform(extension.begin(), extension.end(), extension.begin(), ::tolower); } // Validate image extension const std::vector validExtensions = {"jpg", "jpeg", "png", "gif", "webp", "bmp"}; if (extension.empty() || std::find(validExtensions.begin(), validExtensions.end(), extension) == validExtensions.end()) { sendErrorResponse(res, "URL must point to an image file with a valid extension: " + std::accumulate(validExtensions.begin(), validExtensions.end(), std::string(), [](const std::string& a, const std::string& b) { return a.empty() ? 
b : a + ", " + b; }), 400, "INVALID_IMAGE_EXTENSION", requestId); return; } // Load image using existing loadImageFromInput function auto [imageData, width, height, channels, success, error] = loadImageFromInput(imageUrl); if (!success) { sendErrorResponse(res, "Failed to download image from URL: " + error, 400, "IMAGE_DOWNLOAD_FAILED", requestId); return; } // Convert image data to base64 std::string base64Data = Utils::base64Encode(imageData); // Determine MIME type based on extension std::string mimeType = "image/jpeg"; // default if (extension == "png") { mimeType = "image/png"; } else if (extension == "gif") { mimeType = "image/gif"; } else if (extension == "webp") { mimeType = "image/webp"; } else if (extension == "bmp") { mimeType = "image/bmp"; } else if (extension == "jpg" || extension == "jpeg") { mimeType = "image/jpeg"; } // Create data URL format std::string dataUrl = "data:" + mimeType + ";base64," + base64Data; // Build response nlohmann::json response = { {"success", true}, {"url", imageUrl}, {"filename", filename}, {"width", width}, {"height", height}, {"channels", channels}, {"mime_type", mimeType}, {"base64_data", dataUrl}, {"file_size_bytes", imageData.size()}, {"request_id", requestId} }; sendJsonResponse(res, response, 200); std::cout << "Successfully downloaded and encoded image from URL: " << imageUrl << " (" << width << "x" << height << ", " << imageData.size() << " bytes)" << std::endl; } catch (const nlohmann::json::parse_error& e) { sendErrorResponse(res, std::string("Invalid JSON: ") + e.what(), 400, "JSON_PARSE_ERROR", requestId); } catch (const std::exception& e) { std::cerr << "Exception in handleDownloadImageFromUrl: " << e.what() << std::endl; sendErrorResponse(res, std::string("Failed to download image from URL: ") + e.what(), 500, "INTERNAL_ERROR", requestId); } } void Server::sendJsonResponse(httplib::Response& res, const nlohmann::json& json, int status_code) { res.set_header("Content-Type", "application/json"); res.status = 
status_code; res.body = json.dump(); } void Server::sendErrorResponse(httplib::Response& res, const std::string& message, int status_code, const std::string& error_code, const std::string& request_id) { nlohmann::json errorResponse = { {"error", { {"message", message}, {"status_code", status_code}, {"error_code", error_code}, {"request_id", request_id}, {"timestamp", std::chrono::duration_cast( std::chrono::system_clock::now().time_since_epoch()).count()} }} }; sendJsonResponse(res, errorResponse, status_code); } std::pair Server::validateGenerationParameters(const nlohmann::json& params) { // Validate required fields if (!params.contains("prompt") || !params["prompt"].is_string()) { return {false, "Missing or invalid 'prompt' field"}; } const std::string& prompt = params["prompt"]; if (prompt.empty()) { return {false, "Prompt cannot be empty"}; } if (prompt.length() > m_config.maxPromptLength) { return {false, "Prompt too long (max " + std::to_string(m_config.maxPromptLength) + " characters)"}; } // Validate negative prompt if present if (params.contains("negative_prompt")) { if (!params["negative_prompt"].is_string()) { return {false, "Invalid 'negative_prompt' field, must be string"}; } if (params["negative_prompt"].get().length() > m_config.maxNegativePromptLength) { return {false, "Negative prompt too long (max " + std::to_string(m_config.maxNegativePromptLength) + " characters)"}; } } // Validate width if (params.contains("width")) { if (!params["width"].is_number_integer()) { return {false, "Invalid 'width' field, must be integer"}; } int width = params["width"]; if (width < 64 || width > 2048 || width % 64 != 0) { return {false, "Width must be between 64 and 2048 and divisible by 64"}; } } // Validate height if (params.contains("height")) { if (!params["height"].is_number_integer()) { return {false, "Invalid 'height' field, must be integer"}; } int height = params["height"]; if (height < 64 || height > 2048 || height % 64 != 0) { return {false, "Height must 
be between 64 and 2048 and divisible by 64"}; } } // Validate batch count if (params.contains("batch_count")) { if (!params["batch_count"].is_number_integer()) { return {false, "Invalid 'batch_count' field, must be integer"}; } int batchCount = params["batch_count"]; if (batchCount < 1 || batchCount > 100) { return {false, "Batch count must be between 1 and 100"}; } } // Validate steps if (params.contains("steps")) { if (!params["steps"].is_number_integer()) { return {false, "Invalid 'steps' field, must be integer"}; } int steps = params["steps"]; if (steps < 1 || steps > 150) { return {false, "Steps must be between 1 and 150"}; } } // Validate CFG scale if (params.contains("cfg_scale")) { if (!params["cfg_scale"].is_number()) { return {false, "Invalid 'cfg_scale' field, must be number"}; } float cfgScale = params["cfg_scale"]; if (cfgScale < 1.0f || cfgScale > 30.0f) { return {false, "CFG scale must be between 1.0 and 30.0"}; } } // Validate seed if (params.contains("seed")) { if (!params["seed"].is_string() && !params["seed"].is_number_integer()) { return {false, "Invalid 'seed' field, must be string or integer"}; } } // Validate sampling method if (params.contains("sampling_method")) { if (!params["sampling_method"].is_string()) { return {false, "Invalid 'sampling_method' field, must be string"}; } std::string method = params["sampling_method"]; std::vector validMethods = { "euler", "euler_a", "heun", "dpm2", "dpm++2s_a", "dpm++2m", "dpm++2mv2", "ipndm", "ipndm_v", "lcm", "ddim_trailing", "tcd", "default" }; if (std::find(validMethods.begin(), validMethods.end(), method) == validMethods.end()) { return {false, "Invalid sampling method"}; } } // Validate scheduler if (params.contains("scheduler")) { if (!params["scheduler"].is_string()) { return {false, "Invalid 'scheduler' field, must be string"}; } std::string scheduler = params["scheduler"]; std::vector validSchedulers = { "discrete", "karras", "exponential", "ays", "gits", "smoothstep", "sgm_uniform", 
"simple", "default" }; if (std::find(validSchedulers.begin(), validSchedulers.end(), scheduler) == validSchedulers.end()) { return {false, "Invalid scheduler"}; } } // Validate strength if (params.contains("strength")) { if (!params["strength"].is_number()) { return {false, "Invalid 'strength' field, must be number"}; } float strength = params["strength"]; if (strength < 0.0f || strength > 1.0f) { return {false, "Strength must be between 0.0 and 1.0"}; } } // Validate control strength if (params.contains("control_strength")) { if (!params["control_strength"].is_number()) { return {false, "Invalid 'control_strength' field, must be number"}; } float controlStrength = params["control_strength"]; if (controlStrength < 0.0f || controlStrength > 1.0f) { return {false, "Control strength must be between 0.0 and 1.0"}; } } // Validate clip skip if (params.contains("clip_skip")) { if (!params["clip_skip"].is_number_integer()) { return {false, "Invalid 'clip_skip' field, must be integer"}; } int clipSkip = params["clip_skip"]; if (clipSkip < -1 || clipSkip > 12) { return {false, "Clip skip must be between -1 and 12"}; } } // Validate threads if (params.contains("threads")) { if (!params["threads"].is_number_integer()) { return {false, "Invalid 'threads' field, must be integer"}; } int threads = params["threads"]; if (threads < -1 || threads > 32) { return {false, "Threads must be between -1 (auto) and 32"}; } } return {true, ""}; } SamplingMethod Server::parseSamplingMethod(const std::string& method) { if (method == "euler") return SamplingMethod::EULER; else if (method == "euler_a") return SamplingMethod::EULER_A; else if (method == "heun") return SamplingMethod::HEUN; else if (method == "dpm2") return SamplingMethod::DPM2; else if (method == "dpm++2s_a") return SamplingMethod::DPMPP2S_A; else if (method == "dpm++2m") return SamplingMethod::DPMPP2M; else if (method == "dpm++2mv2") return SamplingMethod::DPMPP2MV2; else if (method == "ipndm") return SamplingMethod::IPNDM; 
else if (method == "ipndm_v") return SamplingMethod::IPNDM_V; else if (method == "lcm") return SamplingMethod::LCM; else if (method == "ddim_trailing") return SamplingMethod::DDIM_TRAILING; else if (method == "tcd") return SamplingMethod::TCD; else return SamplingMethod::DEFAULT; } Scheduler Server::parseScheduler(const std::string& scheduler) { if (scheduler == "discrete") return Scheduler::DISCRETE; else if (scheduler == "karras") return Scheduler::KARRAS; else if (scheduler == "exponential") return Scheduler::EXPONENTIAL; else if (scheduler == "ays") return Scheduler::AYS; else if (scheduler == "gits") return Scheduler::GITS; else if (scheduler == "smoothstep") return Scheduler::SMOOTHSTEP; else if (scheduler == "sgm_uniform") return Scheduler::SGM_UNIFORM; else if (scheduler == "simple") return Scheduler::SIMPLE; else return Scheduler::DEFAULT; } std::string Server::generateRequestId() { std::random_device rd; std::mt19937 gen(rd()); std::uniform_int_distribution<> dis(100000, 999999); return "req_" + std::to_string(dis(gen)); } std::tuple, int, int, int, bool, std::string> Server::loadImageFromInput(const std::string& input) { std::vector imageData; int width = 0, height = 0, channels = 0; // Auto-detect input source type // 1. 
Check if input is a URL (starts with http:// or https://) if (Utils::startsWith(input, "http://") || Utils::startsWith(input, "https://")) { // Parse URL to extract host and path std::string url = input; std::string scheme, host, path; int port = 80; // Determine scheme and port if (Utils::startsWith(url, "https://")) { scheme = "https"; port = 443; url = url.substr(8); // Remove "https://" } else { scheme = "http"; port = 80; url = url.substr(7); // Remove "http://" } // Extract host and path size_t slashPos = url.find('/'); if (slashPos != std::string::npos) { host = url.substr(0, slashPos); path = url.substr(slashPos); } else { host = url; path = "/"; } // Check for custom port size_t colonPos = host.find(':'); if (colonPos != std::string::npos) { try { port = std::stoi(host.substr(colonPos + 1)); host = host.substr(0, colonPos); } catch (...) { return {imageData, 0, 0, 0, false, "Invalid port in URL"}; } } // Download image using httplib try { httplib::Result res; if (scheme == "https") { #ifdef CPPHTTPLIB_OPENSSL_SUPPORT httplib::SSLClient client(host, port); client.set_follow_location(true); client.set_connection_timeout(30, 0); // 30 seconds client.set_read_timeout(60, 0); // 60 seconds res = client.Get(path.c_str()); #else return {imageData, 0, 0, 0, false, "HTTPS not supported (OpenSSL not available)"}; #endif } else { httplib::Client client(host, port); client.set_follow_location(true); client.set_connection_timeout(30, 0); // 30 seconds client.set_read_timeout(60, 0); // 60 seconds res = client.Get(path.c_str()); } if (!res) { return {imageData, 0, 0, 0, false, "Failed to download image from URL: Connection error"}; } if (res->status != 200) { return {imageData, 0, 0, 0, false, "Failed to download image from URL: HTTP " + std::to_string(res->status)}; } // Convert response body to vector std::vector downloadedData(res->body.begin(), res->body.end()); // Load image from memory int w, h, c; unsigned char* pixels = stbi_load_from_memory( 
downloadedData.data(), downloadedData.size(), &w, &h, &c, 3 // Force RGB ); if (!pixels) { return {imageData, 0, 0, 0, false, "Failed to decode image from URL"}; } width = w; height = h; channels = 3; size_t dataSize = width * height * channels; imageData.resize(dataSize); std::memcpy(imageData.data(), pixels, dataSize); stbi_image_free(pixels); } catch (const std::exception& e) { return {imageData, 0, 0, 0, false, "Failed to download image from URL: " + std::string(e.what())}; } } // 2. Check if input is base64 encoded data URI (starts with "data:image") else if (Utils::startsWith(input, "data:image")) { // Extract base64 data after the comma size_t commaPos = input.find(','); if (commaPos == std::string::npos) { return {imageData, 0, 0, 0, false, "Invalid data URI format"}; } std::string base64Data = input.substr(commaPos + 1); std::vector decodedData = Utils::base64Decode(base64Data); // Load image from memory using stb_image int w, h, c; unsigned char* pixels = stbi_load_from_memory( decodedData.data(), decodedData.size(), &w, &h, &c, 3 // Force RGB ); if (!pixels) { return {imageData, 0, 0, 0, false, "Failed to decode image from base64 data URI"}; } width = w; height = h; channels = 3; // We forced RGB // Copy pixel data size_t dataSize = width * height * channels; imageData.resize(dataSize); std::memcpy(imageData.data(), pixels, dataSize); stbi_image_free(pixels); } // 3. 
Check if input is raw base64 (long string without slashes, likely base64) else if (input.length() > 100 && input.find('/') == std::string::npos && input.find('.') == std::string::npos) { // Likely raw base64 without data URI prefix std::vector decodedData = Utils::base64Decode(input); int w, h, c; unsigned char* pixels = stbi_load_from_memory( decodedData.data(), decodedData.size(), &w, &h, &c, 3 // Force RGB ); if (!pixels) { return {imageData, 0, 0, 0, false, "Failed to decode image from base64"}; } width = w; height = h; channels = 3; size_t dataSize = width * height * channels; imageData.resize(dataSize); std::memcpy(imageData.data(), pixels, dataSize); stbi_image_free(pixels); } // 4. Treat as local file path else { int w, h, c; unsigned char* pixels = stbi_load(input.c_str(), &w, &h, &c, 3); if (!pixels) { return {imageData, 0, 0, 0, false, "Failed to load image from file: " + input}; } width = w; height = h; channels = 3; size_t dataSize = width * height * channels; imageData.resize(dataSize); std::memcpy(imageData.data(), pixels, dataSize); stbi_image_free(pixels); } return {imageData, width, height, channels, true, ""}; } std::string Server::samplingMethodToString(SamplingMethod method) { switch (method) { case SamplingMethod::EULER: return "euler"; case SamplingMethod::EULER_A: return "euler_a"; case SamplingMethod::HEUN: return "heun"; case SamplingMethod::DPM2: return "dpm2"; case SamplingMethod::DPMPP2S_A: return "dpm++2s_a"; case SamplingMethod::DPMPP2M: return "dpm++2m"; case SamplingMethod::DPMPP2MV2: return "dpm++2mv2"; case SamplingMethod::IPNDM: return "ipndm"; case SamplingMethod::IPNDM_V: return "ipndm_v"; case SamplingMethod::LCM: return "lcm"; case SamplingMethod::DDIM_TRAILING: return "ddim_trailing"; case SamplingMethod::TCD: return "tcd"; default: return "default"; } } std::string Server::schedulerToString(Scheduler scheduler) { switch (scheduler) { case Scheduler::DISCRETE: return "discrete"; case Scheduler::KARRAS: return "karras"; case 
Scheduler::EXPONENTIAL: return "exponential"; case Scheduler::AYS: return "ays"; case Scheduler::GITS: return "gits"; case Scheduler::SMOOTHSTEP: return "smoothstep"; case Scheduler::SGM_UNIFORM: return "sgm_uniform"; case Scheduler::SIMPLE: return "simple"; default: return "default"; } } uint64_t Server::estimateGenerationTime(const GenerationRequest& request) { // Basic estimation based on parameters uint64_t baseTime = 1000; // 1 second base time // Factor in steps baseTime *= request.steps; // Factor in resolution double resolutionFactor = (request.width * request.height) / (512.0 * 512.0); baseTime = static_cast(baseTime * resolutionFactor); // Factor in batch count baseTime *= request.batchCount; // Adjust for sampling method (some are faster than others) switch (request.samplingMethod) { case SamplingMethod::LCM: baseTime /= 4; // LCM is much faster break; case SamplingMethod::EULER: case SamplingMethod::EULER_A: baseTime *= 0.8; // Euler methods are faster break; case SamplingMethod::DPM2: case SamplingMethod::DPMPP2S_A: baseTime *= 1.2; // DPM methods are slower break; default: break; } return baseTime; } size_t Server::estimateMemoryUsage(const GenerationRequest& request) { // Basic memory estimation in bytes size_t baseMemory = 1024 * 1024 * 1024; // 1GB base // Factor in resolution double resolutionFactor = (request.width * request.height) / (512.0 * 512.0); baseMemory = static_cast(baseMemory * resolutionFactor); // Factor in batch count baseMemory *= request.batchCount; // Additional memory for certain features if (request.diffusionFlashAttn) { baseMemory += 512 * 1024 * 1024; // Extra 512MB for flash attention } if (!request.controlNetPath.empty()) { baseMemory += 1024 * 1024 * 1024; // Extra 1GB for ControlNet } return baseMemory; } // Specialized generation endpoints void Server::handleText2Img(const httplib::Request& req, httplib::Response& res) { std::string requestId = generateRequestId(); try { if (!m_generationQueue) { sendErrorResponse(res, 
"Generation queue not available", 500, "QUEUE_UNAVAILABLE", requestId); return; } nlohmann::json requestJson = nlohmann::json::parse(req.body); // Validate required fields for text2img if (!requestJson.contains("prompt") || !requestJson["prompt"].is_string()) { sendErrorResponse(res, "Missing or invalid 'prompt' field", 400, "INVALID_PARAMETERS", requestId); return; } // Validate all parameters auto [isValid, errorMessage] = validateGenerationParameters(requestJson); if (!isValid) { sendErrorResponse(res, errorMessage, 400, "INVALID_PARAMETERS", requestId); return; } // Check if any model is loaded if (!m_modelManager) { sendErrorResponse(res, "Model manager not available", 500, "MODEL_MANAGER_UNAVAILABLE", requestId); return; } // Get currently loaded checkpoint model auto allModels = m_modelManager->getAllModels(); std::string loadedModelName; for (const auto& [modelName, modelInfo] : allModels) { if (modelInfo.type == ModelType::CHECKPOINT && modelInfo.isLoaded) { loadedModelName = modelName; break; } } if (loadedModelName.empty()) { sendErrorResponse(res, "No checkpoint model loaded. 
Please load a checkpoint model first using POST /api/models/{hash}/load", 400, "NO_CHECKPOINT_LOADED", requestId); return; } // Create generation request specifically for text2img GenerationRequest genRequest; genRequest.id = requestId; genRequest.modelName = loadedModelName; // Use the currently loaded model genRequest.prompt = requestJson["prompt"]; genRequest.negativePrompt = requestJson.value("negative_prompt", ""); genRequest.width = requestJson.value("width", 512); genRequest.height = requestJson.value("height", 512); genRequest.batchCount = requestJson.value("batch_count", 1); genRequest.steps = requestJson.value("steps", 20); genRequest.cfgScale = requestJson.value("cfg_scale", 7.5f); genRequest.seed = requestJson.value("seed", "random"); // Parse optional parameters if (requestJson.contains("sampling_method")) { genRequest.samplingMethod = parseSamplingMethod(requestJson["sampling_method"]); } if (requestJson.contains("scheduler")) { genRequest.scheduler = parseScheduler(requestJson["scheduler"]); } // Set text2img specific defaults genRequest.strength = 1.0f; // Full strength for text2img // Optional VAE model if (requestJson.contains("vae_model") && requestJson["vae_model"].is_string()) { std::string vaeModelId = requestJson["vae_model"]; if (!vaeModelId.empty()) { auto vaeInfo = m_modelManager->getModelInfo(vaeModelId); if (!vaeInfo.name.empty() && vaeInfo.type == ModelType::VAE) { genRequest.vaePath = vaeInfo.path; } else { sendErrorResponse(res, "VAE model not found or invalid: " + vaeModelId, 400, "INVALID_VAE_MODEL", requestId); return; } } } // Optional TAESD model if (requestJson.contains("taesd_model") && requestJson["taesd_model"].is_string()) { std::string taesdModelId = requestJson["taesd_model"]; if (!taesdModelId.empty()) { auto taesdInfo = m_modelManager->getModelInfo(taesdModelId); if (!taesdInfo.name.empty() && taesdInfo.type == ModelType::TAESD) { genRequest.taesdPath = taesdInfo.path; } else { sendErrorResponse(res, "TAESD model not 
found or invalid: " + taesdModelId, 400, "INVALID_TAESD_MODEL", requestId); return; } } } // Enqueue request auto future = m_generationQueue->enqueueRequest(genRequest); nlohmann::json params = { {"prompt", genRequest.prompt}, {"negative_prompt", genRequest.negativePrompt}, {"model", genRequest.modelName}, {"width", genRequest.width}, {"height", genRequest.height}, {"batch_count", genRequest.batchCount}, {"steps", genRequest.steps}, {"cfg_scale", genRequest.cfgScale}, {"seed", genRequest.seed}, {"sampling_method", samplingMethodToString(genRequest.samplingMethod)}, {"scheduler", schedulerToString(genRequest.scheduler)} }; // Add VAE/TAESD if specified if (!genRequest.vaePath.empty()) { params["vae_model"] = requestJson.value("vae_model", ""); } if (!genRequest.taesdPath.empty()) { params["taesd_model"] = requestJson.value("taesd_model", ""); } nlohmann::json response = { {"request_id", requestId}, {"status", "queued"}, {"message", "Text-to-image generation request queued successfully"}, {"queue_position", m_generationQueue->getQueueSize()}, {"estimated_time_seconds", estimateGenerationTime(genRequest) / 1000}, {"estimated_memory_mb", estimateMemoryUsage(genRequest) / (1024 * 1024)}, {"type", "text2img"}, {"parameters", params} }; sendJsonResponse(res, response, 202); } catch (const nlohmann::json::parse_error& e) { sendErrorResponse(res, std::string("Invalid JSON: ") + e.what(), 400, "JSON_PARSE_ERROR", requestId); } catch (const std::exception& e) { sendErrorResponse(res, std::string("Text-to-image request failed: ") + e.what(), 500, "INTERNAL_ERROR", requestId); } } void Server::handleImg2Img(const httplib::Request& req, httplib::Response& res) { std::string requestId = generateRequestId(); try { if (!m_generationQueue) { sendErrorResponse(res, "Generation queue not available", 500, "QUEUE_UNAVAILABLE", requestId); return; } nlohmann::json requestJson = nlohmann::json::parse(req.body); // Validate required fields for img2img if (!requestJson.contains("prompt") || 
!requestJson["prompt"].is_string()) { sendErrorResponse(res, "Missing or invalid 'prompt' field", 400, "INVALID_PARAMETERS", requestId); return; } if (!requestJson.contains("init_image") || !requestJson["init_image"].is_string()) { sendErrorResponse(res, "Missing or invalid 'init_image' field", 400, "INVALID_PARAMETERS", requestId); return; } // Validate all parameters auto [isValid, errorMessage] = validateGenerationParameters(requestJson); if (!isValid) { sendErrorResponse(res, errorMessage, 400, "INVALID_PARAMETERS", requestId); return; } // Check if any model is loaded if (!m_modelManager) { sendErrorResponse(res, "Model manager not available", 500, "MODEL_MANAGER_UNAVAILABLE", requestId); return; } // Get currently loaded checkpoint model auto allModels = m_modelManager->getAllModels(); std::string loadedModelName; for (const auto& [modelName, modelInfo] : allModels) { if (modelInfo.type == ModelType::CHECKPOINT && modelInfo.isLoaded) { loadedModelName = modelName; break; } } if (loadedModelName.empty()) { sendErrorResponse(res, "No checkpoint model loaded. 
Please load a checkpoint model first using POST /api/models/{hash}/load", 400, "NO_CHECKPOINT_LOADED", requestId); return; } // Load the init image std::string initImageInput = requestJson["init_image"]; auto [imageData, imgWidth, imgHeight, imgChannels, success, loadError] = loadImageFromInput(initImageInput); if (!success) { sendErrorResponse(res, "Failed to load init image: " + loadError, 400, "IMAGE_LOAD_ERROR", requestId); return; } // Create generation request specifically for img2img GenerationRequest genRequest; genRequest.id = requestId; genRequest.requestType = GenerationRequest::RequestType::IMG2IMG; genRequest.modelName = loadedModelName; // Use the currently loaded model genRequest.prompt = requestJson["prompt"]; genRequest.negativePrompt = requestJson.value("negative_prompt", ""); genRequest.width = requestJson.value("width", imgWidth); // Default to input image dimensions genRequest.height = requestJson.value("height", imgHeight); genRequest.batchCount = requestJson.value("batch_count", 1); genRequest.steps = requestJson.value("steps", 20); genRequest.cfgScale = requestJson.value("cfg_scale", 7.5f); genRequest.seed = requestJson.value("seed", "random"); genRequest.strength = requestJson.value("strength", 0.75f); // Set init image data genRequest.initImageData = imageData; genRequest.initImageWidth = imgWidth; genRequest.initImageHeight = imgHeight; genRequest.initImageChannels = imgChannels; // Parse optional parameters if (requestJson.contains("sampling_method")) { genRequest.samplingMethod = parseSamplingMethod(requestJson["sampling_method"]); } if (requestJson.contains("scheduler")) { genRequest.scheduler = parseScheduler(requestJson["scheduler"]); } // Optional VAE model if (requestJson.contains("vae_model") && requestJson["vae_model"].is_string()) { std::string vaeModelId = requestJson["vae_model"]; if (!vaeModelId.empty()) { auto vaeInfo = m_modelManager->getModelInfo(vaeModelId); if (!vaeInfo.name.empty() && vaeInfo.type == ModelType::VAE) { 
genRequest.vaePath = vaeInfo.path; } else { sendErrorResponse(res, "VAE model not found or invalid: " + vaeModelId, 400, "INVALID_VAE_MODEL", requestId); return; } } } // Optional TAESD model if (requestJson.contains("taesd_model") && requestJson["taesd_model"].is_string()) { std::string taesdModelId = requestJson["taesd_model"]; if (!taesdModelId.empty()) { auto taesdInfo = m_modelManager->getModelInfo(taesdModelId); if (!taesdInfo.name.empty() && taesdInfo.type == ModelType::TAESD) { genRequest.taesdPath = taesdInfo.path; } else { sendErrorResponse(res, "TAESD model not found or invalid: " + taesdModelId, 400, "INVALID_TAESD_MODEL", requestId); return; } } } // Enqueue request auto future = m_generationQueue->enqueueRequest(genRequest); nlohmann::json params = { {"prompt", genRequest.prompt}, {"negative_prompt", genRequest.negativePrompt}, {"init_image", requestJson["init_image"]}, {"model", genRequest.modelName}, {"width", genRequest.width}, {"height", genRequest.height}, {"batch_count", genRequest.batchCount}, {"steps", genRequest.steps}, {"cfg_scale", genRequest.cfgScale}, {"seed", genRequest.seed}, {"strength", genRequest.strength}, {"sampling_method", samplingMethodToString(genRequest.samplingMethod)}, {"scheduler", schedulerToString(genRequest.scheduler)} }; // Add VAE/TAESD if specified if (!genRequest.vaePath.empty()) { params["vae_model"] = requestJson.value("vae_model", ""); } if (!genRequest.taesdPath.empty()) { params["taesd_model"] = requestJson.value("taesd_model", ""); } nlohmann::json response = { {"request_id", requestId}, {"status", "queued"}, {"message", "Image-to-image generation request queued successfully"}, {"queue_position", m_generationQueue->getQueueSize()}, {"estimated_time_seconds", estimateGenerationTime(genRequest) / 1000}, {"estimated_memory_mb", estimateMemoryUsage(genRequest) / (1024 * 1024)}, {"type", "img2img"}, {"parameters", params} }; sendJsonResponse(res, response, 202); } catch (const nlohmann::json::parse_error& e) { 
sendErrorResponse(res, std::string("Invalid JSON: ") + e.what(), 400, "JSON_PARSE_ERROR", requestId); } catch (const std::exception& e) { sendErrorResponse(res, std::string("Image-to-image request failed: ") + e.what(), 500, "INTERNAL_ERROR", requestId); } } void Server::handleControlNet(const httplib::Request& req, httplib::Response& res) { std::string requestId = generateRequestId(); try { if (!m_generationQueue) { sendErrorResponse(res, "Generation queue not available", 500, "QUEUE_UNAVAILABLE", requestId); return; } nlohmann::json requestJson = nlohmann::json::parse(req.body); // Validate required fields for ControlNet if (!requestJson.contains("prompt") || !requestJson["prompt"].is_string()) { sendErrorResponse(res, "Missing or invalid 'prompt' field", 400, "INVALID_PARAMETERS", requestId); return; } if (!requestJson.contains("control_image") || !requestJson["control_image"].is_string()) { sendErrorResponse(res, "Missing or invalid 'control_image' field", 400, "INVALID_PARAMETERS", requestId); return; } // Validate all parameters auto [isValid, errorMessage] = validateGenerationParameters(requestJson); if (!isValid) { sendErrorResponse(res, errorMessage, 400, "INVALID_PARAMETERS", requestId); return; } // Check if any model is loaded if (!m_modelManager) { sendErrorResponse(res, "Model manager not available", 500, "MODEL_MANAGER_UNAVAILABLE", requestId); return; } // Get currently loaded checkpoint model auto allModels = m_modelManager->getAllModels(); std::string loadedModelName; for (const auto& [modelName, modelInfo] : allModels) { if (modelInfo.type == ModelType::CHECKPOINT && modelInfo.isLoaded) { loadedModelName = modelName; break; } } if (loadedModelName.empty()) { sendErrorResponse(res, "No checkpoint model loaded. 
Please load a checkpoint model first using POST /api/models/{hash}/load", 400, "NO_CHECKPOINT_LOADED", requestId); return; } // Create generation request specifically for ControlNet GenerationRequest genRequest; genRequest.id = requestId; genRequest.modelName = loadedModelName; // Use the currently loaded model genRequest.prompt = requestJson["prompt"]; genRequest.negativePrompt = requestJson.value("negative_prompt", ""); genRequest.width = requestJson.value("width", 512); genRequest.height = requestJson.value("height", 512); genRequest.batchCount = requestJson.value("batch_count", 1); genRequest.steps = requestJson.value("steps", 20); genRequest.cfgScale = requestJson.value("cfg_scale", 7.5f); genRequest.seed = requestJson.value("seed", "random"); genRequest.controlStrength = requestJson.value("control_strength", 0.9f); genRequest.controlNetPath = requestJson.value("control_net_model", ""); // Parse optional parameters if (requestJson.contains("sampling_method")) { genRequest.samplingMethod = parseSamplingMethod(requestJson["sampling_method"]); } if (requestJson.contains("scheduler")) { genRequest.scheduler = parseScheduler(requestJson["scheduler"]); } // Optional VAE model if (requestJson.contains("vae_model") && requestJson["vae_model"].is_string()) { std::string vaeModelId = requestJson["vae_model"]; if (!vaeModelId.empty()) { auto vaeInfo = m_modelManager->getModelInfo(vaeModelId); if (!vaeInfo.name.empty() && vaeInfo.type == ModelType::VAE) { genRequest.vaePath = vaeInfo.path; } else { sendErrorResponse(res, "VAE model not found or invalid: " + vaeModelId, 400, "INVALID_VAE_MODEL", requestId); return; } } } // Optional TAESD model if (requestJson.contains("taesd_model") && requestJson["taesd_model"].is_string()) { std::string taesdModelId = requestJson["taesd_model"]; if (!taesdModelId.empty()) { auto taesdInfo = m_modelManager->getModelInfo(taesdModelId); if (!taesdInfo.name.empty() && taesdInfo.type == ModelType::TAESD) { genRequest.taesdPath = 
taesdInfo.path; } else { sendErrorResponse(res, "TAESD model not found or invalid: " + taesdModelId, 400, "INVALID_TAESD_MODEL", requestId); return; } } } // Store control image path (would be handled in actual implementation) genRequest.outputPath = requestJson.value("control_image", ""); // Enqueue request auto future = m_generationQueue->enqueueRequest(genRequest); nlohmann::json params = { {"prompt", genRequest.prompt}, {"negative_prompt", genRequest.negativePrompt}, {"control_image", requestJson["control_image"]}, {"control_net_model", genRequest.controlNetPath}, {"model", genRequest.modelName}, {"width", genRequest.width}, {"height", genRequest.height}, {"batch_count", genRequest.batchCount}, {"steps", genRequest.steps}, {"cfg_scale", genRequest.cfgScale}, {"seed", genRequest.seed}, {"control_strength", genRequest.controlStrength}, {"sampling_method", samplingMethodToString(genRequest.samplingMethod)}, {"scheduler", schedulerToString(genRequest.scheduler)} }; // Add VAE/TAESD if specified if (!genRequest.vaePath.empty()) { params["vae_model"] = requestJson.value("vae_model", ""); } if (!genRequest.taesdPath.empty()) { params["taesd_model"] = requestJson.value("taesd_model", ""); } nlohmann::json response = { {"request_id", requestId}, {"status", "queued"}, {"message", "ControlNet generation request queued successfully"}, {"queue_position", m_generationQueue->getQueueSize()}, {"estimated_time_seconds", estimateGenerationTime(genRequest) / 1000}, {"estimated_memory_mb", estimateMemoryUsage(genRequest) / (1024 * 1024)}, {"type", "controlnet"}, {"parameters", params} }; sendJsonResponse(res, response, 202); } catch (const nlohmann::json::parse_error& e) { sendErrorResponse(res, std::string("Invalid JSON: ") + e.what(), 400, "JSON_PARSE_ERROR", requestId); } catch (const std::exception& e) { sendErrorResponse(res, std::string("ControlNet request failed: ") + e.what(), 500, "INTERNAL_ERROR", requestId); } } void Server::handleUpscale(const httplib::Request& req, 
httplib::Response& res) { std::string requestId = generateRequestId(); try { if (!m_generationQueue) { sendErrorResponse(res, "Generation queue not available", 500, "QUEUE_UNAVAILABLE", requestId); return; } nlohmann::json requestJson = nlohmann::json::parse(req.body); // Validate required fields for upscaler if (!requestJson.contains("image") || !requestJson["image"].is_string()) { sendErrorResponse(res, "Missing or invalid 'image' field", 400, "INVALID_PARAMETERS", requestId); return; } if (!requestJson.contains("esrgan_model") || !requestJson["esrgan_model"].is_string()) { sendErrorResponse(res, "Missing or invalid 'esrgan_model' field (model hash or name)", 400, "INVALID_PARAMETERS", requestId); return; } // Check if model manager is available if (!m_modelManager) { sendErrorResponse(res, "Model manager not available", 500, "MODEL_MANAGER_UNAVAILABLE", requestId); return; } // Get the ESRGAN/upscaler model std::string esrganModelId = requestJson["esrgan_model"]; auto modelInfo = m_modelManager->getModelInfo(esrganModelId); if (modelInfo.name.empty()) { sendErrorResponse(res, "ESRGAN model not found: " + esrganModelId, 404, "MODEL_NOT_FOUND", requestId); return; } if (modelInfo.type != ModelType::ESRGAN && modelInfo.type != ModelType::UPSCALER) { sendErrorResponse(res, "Model is not an ESRGAN/upscaler model", 400, "INVALID_MODEL_TYPE", requestId); return; } // Load the input image std::string imageInput = requestJson["image"]; auto [imageData, imgWidth, imgHeight, imgChannels, success, loadError] = loadImageFromInput(imageInput); if (!success) { sendErrorResponse(res, "Failed to load image: " + loadError, 400, "IMAGE_LOAD_ERROR", requestId); return; } // Create upscaler request GenerationRequest genRequest; genRequest.id = requestId; genRequest.requestType = GenerationRequest::RequestType::UPSCALER; genRequest.esrganPath = modelInfo.path; genRequest.upscaleFactor = requestJson.value("upscale_factor", 4); genRequest.nThreads = requestJson.value("threads", -1); 
genRequest.offloadParamsToCpu = requestJson.value("offload_to_cpu", false); genRequest.diffusionConvDirect = requestJson.value("direct", false); // Set input image data genRequest.initImageData = imageData; genRequest.initImageWidth = imgWidth; genRequest.initImageHeight = imgHeight; genRequest.initImageChannels = imgChannels; // Enqueue request auto future = m_generationQueue->enqueueRequest(genRequest); nlohmann::json response = { {"request_id", requestId}, {"status", "queued"}, {"message", "Upscale request queued successfully"}, {"queue_position", m_generationQueue->getQueueSize()}, {"type", "upscale"}, {"parameters", { {"esrgan_model", esrganModelId}, {"upscale_factor", genRequest.upscaleFactor}, {"input_width", imgWidth}, {"input_height", imgHeight}, {"output_width", imgWidth * genRequest.upscaleFactor}, {"output_height", imgHeight * genRequest.upscaleFactor} }} }; sendJsonResponse(res, response, 202); } catch (const nlohmann::json::parse_error& e) { sendErrorResponse(res, std::string("Invalid JSON: ") + e.what(), 400, "JSON_PARSE_ERROR", requestId); } catch (const std::exception& e) { sendErrorResponse(res, std::string("Upscale request failed: ") + e.what(), 500, "INTERNAL_ERROR", requestId); } } void Server::handleInpainting(const httplib::Request& req, httplib::Response& res) { std::string requestId = generateRequestId(); try { if (!m_generationQueue) { sendErrorResponse(res, "Generation queue not available", 500, "QUEUE_UNAVAILABLE", requestId); return; } nlohmann::json requestJson = nlohmann::json::parse(req.body); // Validate required fields for inpainting if (!requestJson.contains("prompt") || !requestJson["prompt"].is_string()) { sendErrorResponse(res, "Missing or invalid 'prompt' field", 400, "INVALID_PARAMETERS", requestId); return; } if (!requestJson.contains("source_image") || !requestJson["source_image"].is_string()) { sendErrorResponse(res, "Missing or invalid 'source_image' field", 400, "INVALID_PARAMETERS", requestId); return; } if 
(!requestJson.contains("mask_image") || !requestJson["mask_image"].is_string()) { sendErrorResponse(res, "Missing or invalid 'mask_image' field", 400, "INVALID_PARAMETERS", requestId); return; } // Validate all parameters auto [isValid, errorMessage] = validateGenerationParameters(requestJson); if (!isValid) { sendErrorResponse(res, errorMessage, 400, "INVALID_PARAMETERS", requestId); return; } // Check if any model is loaded if (!m_modelManager) { sendErrorResponse(res, "Model manager not available", 500, "MODEL_MANAGER_UNAVAILABLE", requestId); return; } // Get currently loaded checkpoint model auto allModels = m_modelManager->getAllModels(); std::string loadedModelName; for (const auto& [modelName, modelInfo] : allModels) { if (modelInfo.type == ModelType::CHECKPOINT && modelInfo.isLoaded) { loadedModelName = modelName; break; } } if (loadedModelName.empty()) { sendErrorResponse(res, "No checkpoint model loaded. Please load a checkpoint model first using POST /api/models/{hash}/load", 400, "NO_CHECKPOINT_LOADED", requestId); return; } // Load the source image std::string sourceImageInput = requestJson["source_image"]; auto [sourceImageData, sourceImgWidth, sourceImgHeight, sourceImgChannels, sourceSuccess, sourceLoadError] = loadImageFromInput(sourceImageInput); if (!sourceSuccess) { sendErrorResponse(res, "Failed to load source image: " + sourceLoadError, 400, "IMAGE_LOAD_ERROR", requestId); return; } // Load the mask image std::string maskImageInput = requestJson["mask_image"]; auto [maskImageData, maskImgWidth, maskImgHeight, maskImgChannels, maskSuccess, maskLoadError] = loadImageFromInput(maskImageInput); if (!maskSuccess) { sendErrorResponse(res, "Failed to load mask image: " + maskLoadError, 400, "MASK_LOAD_ERROR", requestId); return; } // Validate that source and mask images have compatible dimensions if (sourceImgWidth != maskImgWidth || sourceImgHeight != maskImgHeight) { sendErrorResponse(res, "Source and mask images must have the same dimensions", 
400, "DIMENSION_MISMATCH", requestId); return; } // Create generation request specifically for inpainting GenerationRequest genRequest; genRequest.id = requestId; genRequest.requestType = GenerationRequest::RequestType::INPAINTING; genRequest.modelName = loadedModelName; // Use the currently loaded model genRequest.prompt = requestJson["prompt"]; genRequest.negativePrompt = requestJson.value("negative_prompt", ""); genRequest.width = requestJson.value("width", sourceImgWidth); // Default to input image dimensions genRequest.height = requestJson.value("height", sourceImgHeight); genRequest.batchCount = requestJson.value("batch_count", 1); genRequest.steps = requestJson.value("steps", 20); genRequest.cfgScale = requestJson.value("cfg_scale", 7.5f); genRequest.seed = requestJson.value("seed", "random"); genRequest.strength = requestJson.value("strength", 0.75f); // Set source image data genRequest.initImageData = sourceImageData; genRequest.initImageWidth = sourceImgWidth; genRequest.initImageHeight = sourceImgHeight; genRequest.initImageChannels = sourceImgChannels; // Set mask image data genRequest.maskImageData = maskImageData; genRequest.maskImageWidth = maskImgWidth; genRequest.maskImageHeight = maskImgHeight; genRequest.maskImageChannels = maskImgChannels; // Parse optional parameters if (requestJson.contains("sampling_method")) { genRequest.samplingMethod = parseSamplingMethod(requestJson["sampling_method"]); } if (requestJson.contains("scheduler")) { genRequest.scheduler = parseScheduler(requestJson["scheduler"]); } // Optional VAE model if (requestJson.contains("vae_model") && requestJson["vae_model"].is_string()) { std::string vaeModelId = requestJson["vae_model"]; if (!vaeModelId.empty()) { auto vaeInfo = m_modelManager->getModelInfo(vaeModelId); if (!vaeInfo.name.empty() && vaeInfo.type == ModelType::VAE) { genRequest.vaePath = vaeInfo.path; } else { sendErrorResponse(res, "VAE model not found or invalid: " + vaeModelId, 400, "INVALID_VAE_MODEL", 
requestId); return; } } } // Optional TAESD model if (requestJson.contains("taesd_model") && requestJson["taesd_model"].is_string()) { std::string taesdModelId = requestJson["taesd_model"]; if (!taesdModelId.empty()) { auto taesdInfo = m_modelManager->getModelInfo(taesdModelId); if (!taesdInfo.name.empty() && taesdInfo.type == ModelType::TAESD) { genRequest.taesdPath = taesdInfo.path; } else { sendErrorResponse(res, "TAESD model not found or invalid: " + taesdModelId, 400, "INVALID_TAESD_MODEL", requestId); return; } } } // Enqueue request auto future = m_generationQueue->enqueueRequest(genRequest); nlohmann::json params = { {"prompt", genRequest.prompt}, {"negative_prompt", genRequest.negativePrompt}, {"source_image", requestJson["source_image"]}, {"mask_image", requestJson["mask_image"]}, {"model", genRequest.modelName}, {"width", genRequest.width}, {"height", genRequest.height}, {"batch_count", genRequest.batchCount}, {"steps", genRequest.steps}, {"cfg_scale", genRequest.cfgScale}, {"seed", genRequest.seed}, {"strength", genRequest.strength}, {"sampling_method", samplingMethodToString(genRequest.samplingMethod)}, {"scheduler", schedulerToString(genRequest.scheduler)} }; // Add VAE/TAESD if specified if (!genRequest.vaePath.empty()) { params["vae_model"] = requestJson.value("vae_model", ""); } if (!genRequest.taesdPath.empty()) { params["taesd_model"] = requestJson.value("taesd_model", ""); } nlohmann::json response = { {"request_id", requestId}, {"status", "queued"}, {"message", "Inpainting generation request queued successfully"}, {"queue_position", m_generationQueue->getQueueSize()}, {"estimated_time_seconds", estimateGenerationTime(genRequest) / 1000}, {"estimated_memory_mb", estimateMemoryUsage(genRequest) / (1024 * 1024)}, {"type", "inpainting"}, {"parameters", params} }; sendJsonResponse(res, response, 202); } catch (const nlohmann::json::parse_error& e) { sendErrorResponse(res, std::string("Invalid JSON: ") + e.what(), 400, "JSON_PARSE_ERROR", requestId); 
} catch (const std::exception& e) { sendErrorResponse(res, std::string("Inpainting request failed: ") + e.what(), 500, "INTERNAL_ERROR", requestId); } } // Utility endpoints void Server::handleSamplers(const httplib::Request& /*req*/, httplib::Response& res) { try { nlohmann::json samplers = { {"samplers", { { {"name", "euler"}, {"description", "Euler sampler - fast and simple"}, {"recommended_steps", 20} }, { {"name", "euler_a"}, {"description", "Euler ancestral sampler - adds randomness"}, {"recommended_steps", 20} }, { {"name", "heun"}, {"description", "Heun sampler - more accurate but slower"}, {"recommended_steps", 20} }, { {"name", "dpm2"}, {"description", "DPM2 sampler - second-order DPM"}, {"recommended_steps", 20} }, { {"name", "dpm++2s_a"}, {"description", "DPM++ 2s ancestral sampler"}, {"recommended_steps", 20} }, { {"name", "dpm++2m"}, {"description", "DPM++ 2m sampler - multistep"}, {"recommended_steps", 20} }, { {"name", "dpm++2mv2"}, {"description", "DPM++ 2m v2 sampler - improved multistep"}, {"recommended_steps", 20} }, { {"name", "ipndm"}, {"description", "IPNDM sampler - improved noise prediction"}, {"recommended_steps", 20} }, { {"name", "ipndm_v"}, {"description", "IPNDM v sampler - variant of IPNDM"}, {"recommended_steps", 20} }, { {"name", "lcm"}, {"description", "LCM sampler - Latent Consistency Model, very fast"}, {"recommended_steps", 4} }, { {"name", "ddim_trailing"}, {"description", "DDIM trailing sampler - deterministic"}, {"recommended_steps", 20} }, { {"name", "tcd"}, {"description", "TCD sampler - Trajectory Consistency Distillation"}, {"recommended_steps", 8} }, { {"name", "default"}, {"description", "Use model's default sampler"}, {"recommended_steps", 20} } }} }; sendJsonResponse(res, samplers); } catch (const std::exception& e) { sendErrorResponse(res, std::string("Failed to get samplers: ") + e.what(), 500); } } void Server::handleSchedulers(const httplib::Request& /*req*/, httplib::Response& res) { try { nlohmann::json 
schedulers = { {"schedulers", { { {"name", "discrete"}, {"description", "Discrete scheduler - standard noise schedule"} }, { {"name", "karras"}, {"description", "Karras scheduler - improved noise schedule"} }, { {"name", "exponential"}, {"description", "Exponential scheduler - exponential noise decay"} }, { {"name", "ays"}, {"description", "AYS scheduler - Adaptive Your Scheduler"} }, { {"name", "gits"}, {"description", "GITS scheduler - Generalized Iterative Time Steps"} }, { {"name", "smoothstep"}, {"description", "Smoothstep scheduler - smooth transition function"} }, { {"name", "sgm_uniform"}, {"description", "SGM uniform scheduler - uniform noise schedule"} }, { {"name", "simple"}, {"description", "Simple scheduler - basic linear schedule"} }, { {"name", "default"}, {"description", "Use model's default scheduler"} } }} }; sendJsonResponse(res, schedulers); } catch (const std::exception& e) { sendErrorResponse(res, std::string("Failed to get schedulers: ") + e.what(), 500); } } void Server::handleParameters(const httplib::Request& /*req*/, httplib::Response& res) { try { nlohmann::json parameters = { {"parameters", { { {"name", "prompt"}, {"type", "string"}, {"required", true}, {"description", "Text prompt for image generation"}, {"min_length", 1}, {"max_length", 10000}, {"example", "a beautiful landscape with mountains"} }, { {"name", "negative_prompt"}, {"type", "string"}, {"required", false}, {"description", "Negative prompt to guide generation away from"}, {"min_length", 0}, {"max_length", 10000}, {"example", "blurry, low quality, distorted"} }, { {"name", "width"}, {"type", "integer"}, {"required", false}, {"description", "Image width in pixels"}, {"min", 64}, {"max", 2048}, {"multiple_of", 64}, {"default", 512} }, { {"name", "height"}, {"type", "integer"}, {"required", false}, {"description", "Image height in pixels"}, {"min", 64}, {"max", 2048}, {"multiple_of", 64}, {"default", 512} }, { {"name", "steps"}, {"type", "integer"}, {"required", false}, 
{"description", "Number of diffusion steps"}, {"min", 1}, {"max", 150}, {"default", 20} }, { {"name", "cfg_scale"}, {"type", "number"}, {"required", false}, {"description", "Classifier-Free Guidance scale"}, {"min", 1.0}, {"max", 30.0}, {"default", 7.5} }, { {"name", "seed"}, {"type", "string|integer"}, {"required", false}, {"description", "Seed for generation (use 'random' for random seed)"}, {"example", "42"} }, { {"name", "sampling_method"}, {"type", "string"}, {"required", false}, {"description", "Sampling method to use"}, {"enum", {"euler", "euler_a", "heun", "dpm2", "dpm++2s_a", "dpm++2m", "dpm++2mv2", "ipndm", "ipndm_v", "lcm", "ddim_trailing", "tcd", "default"}}, {"default", "default"} }, { {"name", "scheduler"}, {"type", "string"}, {"required", false}, {"description", "Scheduler to use"}, {"enum", {"discrete", "karras", "exponential", "ays", "gits", "smoothstep", "sgm_uniform", "simple", "default"}}, {"default", "default"} }, { {"name", "batch_count"}, {"type", "integer"}, {"required", false}, {"description", "Number of images to generate"}, {"min", 1}, {"max", 100}, {"default", 1} }, { {"name", "strength"}, {"type", "number"}, {"required", false}, {"description", "Strength for img2img (0.0-1.0)"}, {"min", 0.0}, {"max", 1.0}, {"default", 0.75} }, { {"name", "control_strength"}, {"type", "number"}, {"required", false}, {"description", "ControlNet strength (0.0-1.0)"}, {"min", 0.0}, {"max", 1.0}, {"default", 0.9} } }}, {"openapi", { {"version", "3.0.0"}, {"info", { {"title", "Stable Diffusion REST API"}, {"version", "1.0.0"}, {"description", "Comprehensive REST API for stable-diffusion.cpp functionality"} }}, {"components", { {"schemas", { {"GenerationRequest", { {"type", "object"}, {"required", {"prompt"}}, {"properties", { {"prompt", {{"type", "string"}, {"description", "Text prompt for generation"}}}, {"negative_prompt", {{"type", "string"}, {"description", "Negative prompt"}}}, {"width", {{"type", "integer"}, {"minimum", 64}, {"maximum", 2048}, 
{"default", 512}}}, {"height", {{"type", "integer"}, {"minimum", 64}, {"maximum", 2048}, {"default", 512}}}, {"steps", {{"type", "integer"}, {"minimum", 1}, {"maximum", 150}, {"default", 20}}}, {"cfg_scale", {{"type", "number"}, {"minimum", 1.0}, {"maximum", 30.0}, {"default", 7.5}}} }} }} }} }} }} }; sendJsonResponse(res, parameters); } catch (const std::exception& e) { sendErrorResponse(res, std::string("Failed to get parameters: ") + e.what(), 500); } } void Server::handleValidate(const httplib::Request& req, httplib::Response& res) { std::string requestId = generateRequestId(); try { nlohmann::json requestJson = nlohmann::json::parse(req.body); // Validate parameters auto [isValid, errorMessage] = validateGenerationParameters(requestJson); nlohmann::json response = { {"request_id", requestId}, {"valid", isValid}, {"message", isValid ? "Parameters are valid" : errorMessage}, {"errors", isValid ? nlohmann::json::array() : nlohmann::json::array({errorMessage})} }; sendJsonResponse(res, response, isValid ? 
200 : 400); } catch (const nlohmann::json::parse_error& e) { sendErrorResponse(res, std::string("Invalid JSON: ") + e.what(), 400, "JSON_PARSE_ERROR", requestId); } catch (const std::exception& e) { sendErrorResponse(res, std::string("Validation failed: ") + e.what(), 500, "INTERNAL_ERROR", requestId); } } void Server::handleEstimate(const httplib::Request& req, httplib::Response& res) { std::string requestId = generateRequestId(); try { nlohmann::json requestJson = nlohmann::json::parse(req.body); // Validate parameters first auto [isValid, errorMessage] = validateGenerationParameters(requestJson); if (!isValid) { sendErrorResponse(res, errorMessage, 400, "INVALID_PARAMETERS", requestId); return; } // Create a temporary request to estimate GenerationRequest genRequest; genRequest.prompt = requestJson["prompt"]; genRequest.width = requestJson.value("width", 512); genRequest.height = requestJson.value("height", 512); genRequest.batchCount = requestJson.value("batch_count", 1); genRequest.steps = requestJson.value("steps", 20); genRequest.diffusionFlashAttn = requestJson.value("diffusion_flash_attn", false); genRequest.controlNetPath = requestJson.value("control_net_path", ""); if (requestJson.contains("sampling_method")) { genRequest.samplingMethod = parseSamplingMethod(requestJson["sampling_method"]); } // Calculate estimates uint64_t estimatedTime = estimateGenerationTime(genRequest); size_t estimatedMemory = estimateMemoryUsage(genRequest); nlohmann::json response = { {"request_id", requestId}, {"estimated_time_seconds", estimatedTime / 1000}, {"estimated_memory_mb", estimatedMemory / (1024 * 1024)}, {"parameters", { {"resolution", std::to_string(genRequest.width) + "x" + std::to_string(genRequest.height)}, {"steps", genRequest.steps}, {"batch_count", genRequest.batchCount}, {"sampling_method", samplingMethodToString(genRequest.samplingMethod)} }} }; sendJsonResponse(res, response); } catch (const nlohmann::json::parse_error& e) { sendErrorResponse(res, 
std::string("Invalid JSON: ") + e.what(), 400, "JSON_PARSE_ERROR", requestId); } catch (const std::exception& e) { sendErrorResponse(res, std::string("Estimation failed: ") + e.what(), 500, "INTERNAL_ERROR", requestId); } } void Server::handleConfig(const httplib::Request& /*req*/, httplib::Response& res) { std::string requestId = generateRequestId(); try { // Get current configuration nlohmann::json config = { {"request_id", requestId}, {"config", { {"server", { {"host", m_host}, {"port", m_port}, {"max_concurrent_generations", 1} }}, {"generation", { {"default_width", 512}, {"default_height", 512}, {"default_steps", 20}, {"default_cfg_scale", 7.5}, {"max_batch_count", 100}, {"max_steps", 150}, {"max_resolution", 2048} }}, {"rate_limiting", { {"requests_per_minute", 60}, {"enabled", true} }} }} }; sendJsonResponse(res, config); } catch (const std::exception& e) { sendErrorResponse(res, std::string("Config operation failed: ") + e.what(), 500, "INTERNAL_ERROR", requestId); } } void Server::handleSystem(const httplib::Request& /*req*/, httplib::Response& res) { try { nlohmann::json system = { {"system", { {"version", "1.0.0"}, {"build", "stable-diffusion.cpp-rest"}, {"uptime", std::chrono::duration_cast( std::chrono::steady_clock::now().time_since_epoch()).count()}, {"capabilities", { {"text2img", true}, {"img2img", true}, {"controlnet", true}, {"batch_generation", true}, {"parameter_validation", true}, {"estimation", true} }}, {"supported_formats", { {"input", {"png", "jpg", "jpeg", "webp"}}, {"output", {"png", "jpg", "jpeg", "webp"}} }}, {"limits", { {"max_resolution", 2048}, {"max_steps", 150}, {"max_batch_count", 100}, {"max_prompt_length", 10000} }} }}, {"hardware", { {"cpu_threads", std::thread::hardware_concurrency()} }} }; sendJsonResponse(res, system); } catch (const std::exception& e) { sendErrorResponse(res, std::string("System info failed: ") + e.what(), 500); } } void Server::handleSystemRestart(const httplib::Request& /*req*/, httplib::Response& res) { 
try { nlohmann::json response = { {"message", "Server restart initiated. The server will shut down gracefully and exit. Please use a process manager to automatically restart it."}, {"status", "restarting"} }; sendJsonResponse(res, response); // Schedule server stop after response is sent // Using a separate thread to allow the response to be sent first std::thread([this]() { std::this_thread::sleep_for(std::chrono::seconds(1)); this->stop(); // Exit with code 42 to signal restart intent to process manager std::exit(42); }).detach(); } catch (const std::exception& e) { sendErrorResponse(res, std::string("Restart failed: ") + e.what(), 500); } } // Helper methods for model management nlohmann::json Server::getModelCapabilities(ModelType type) { nlohmann::json capabilities = nlohmann::json::object(); switch (type) { case ModelType::CHECKPOINT: capabilities = { {"text2img", true}, {"img2img", true}, {"inpainting", true}, {"outpainting", true}, {"controlnet", true}, {"lora", true}, {"vae", true}, {"sampling_methods", {"euler", "euler_a", "heun", "dpm2", "dpm++2s_a", "dpm++2m", "dpm++2mv2", "ipndm", "ipndm_v", "lcm", "ddim_trailing", "tcd"}}, {"schedulers", {"discrete", "karras", "exponential", "ays", "gits", "smoothstep", "sgm_uniform", "simple"}}, {"recommended_resolution", "512x512"}, {"max_resolution", "2048x2048"}, {"supports_batch", true} }; break; case ModelType::LORA: capabilities = { {"text2img", true}, {"img2img", true}, {"inpainting", true}, {"controlnet", false}, {"lora", true}, {"vae", false}, {"requires_checkpoint", true}, {"strength_range", {0.0, 2.0}}, {"recommended_strength", 1.0} }; break; case ModelType::CONTROLNET: capabilities = { {"text2img", false}, {"img2img", true}, {"inpainting", true}, {"controlnet", true}, {"requires_checkpoint", true}, {"control_modes", {"canny", "depth", "pose", "scribble", "hed", "mlsd", "normal", "seg"}}, {"strength_range", {0.0, 1.0}}, {"recommended_strength", 0.9} }; break; case ModelType::VAE: capabilities = { 
{"text2img", false}, {"img2img", false}, {"inpainting", false}, {"vae", true}, {"requires_checkpoint", true}, {"encoding", true}, {"decoding", true}, {"precision", {"fp16", "fp32"}} }; break; case ModelType::EMBEDDING: capabilities = { {"text2img", true}, {"img2img", true}, {"inpainting", true}, {"embedding", true}, {"requires_checkpoint", true}, {"token_count", 1}, {"compatible_with", {"checkpoint", "lora"}} }; break; case ModelType::TAESD: capabilities = { {"text2img", false}, {"img2img", false}, {"inpainting", false}, {"vae", true}, {"requires_checkpoint", true}, {"fast_decoding", true}, {"real_time", true}, {"precision", {"fp16", "fp32"}} }; break; case ModelType::ESRGAN: capabilities = { {"text2img", false}, {"img2img", false}, {"inpainting", false}, {"upscaling", true}, {"scale_factors", {2, 4}}, {"models", {"ESRGAN", "RealESRGAN", "SwinIR"}}, {"supports_alpha", false} }; break; default: capabilities = { {"text2img", false}, {"img2img", false}, {"inpainting", false}, {"capabilities", {}} }; break; } return capabilities; } nlohmann::json Server::getModelTypeStatistics() { if (!m_modelManager) return nlohmann::json::object(); nlohmann::json stats = nlohmann::json::object(); auto allModels = m_modelManager->getAllModels(); // Initialize counters for each type std::map typeCounts; std::map loadedCounts; std::map sizeByType; for (const auto& pair : allModels) { ModelType type = pair.second.type; typeCounts[type]++; if (pair.second.isLoaded) { loadedCounts[type]++; } sizeByType[type] += pair.second.fileSize; } // Build statistics JSON for (const auto& count : typeCounts) { std::string typeName = ModelManager::modelTypeToString(count.first); stats[typeName] = { {"total_count", count.second}, {"loaded_count", loadedCounts[count.first]}, {"total_size_bytes", sizeByType[count.first]}, {"total_size_mb", sizeByType[count.first] / (1024.0 * 1024.0)}, {"average_size_mb", count.second > 0 ? 
(sizeByType[count.first] / (1024.0 * 1024.0)) / count.second : 0.0} }; } return stats; } // Additional helper methods for model management nlohmann::json Server::getModelCompatibility(const ModelManager::ModelInfo& modelInfo) { nlohmann::json compatibility = { {"is_compatible", true}, {"compatibility_score", 100}, {"issues", nlohmann::json::array()}, {"warnings", nlohmann::json::array()}, {"requirements", { {"min_memory_mb", 1024}, {"recommended_memory_mb", 2048}, {"supported_formats", {"safetensors", "ckpt", "gguf"}}, {"required_dependencies", {}} }} }; // Check for specific compatibility issues based on model type if (modelInfo.type == ModelType::LORA) { compatibility["requirements"]["required_dependencies"] = {"checkpoint"}; } else if (modelInfo.type == ModelType::CONTROLNET) { compatibility["requirements"]["required_dependencies"] = {"checkpoint"}; } else if (modelInfo.type == ModelType::VAE) { compatibility["requirements"]["required_dependencies"] = {"checkpoint"}; } return compatibility; } nlohmann::json Server::getModelRequirements(ModelType type) { nlohmann::json requirements = { {"min_memory_mb", 1024}, {"recommended_memory_mb", 2048}, {"min_disk_space_mb", 1024}, {"supported_formats", {"safetensors", "ckpt", "gguf"}}, {"required_dependencies", nlohmann::json::array()}, {"optional_dependencies", nlohmann::json::array()}, {"system_requirements", { {"cpu_cores", 4}, {"cpu_architecture", "x86_64"}, {"os", "Linux/Windows/macOS"}, {"gpu_memory_mb", 2048}, {"gpu_compute_capability", "3.5+"} }} }; switch (type) { case ModelType::CHECKPOINT: requirements["min_memory_mb"] = 2048; requirements["recommended_memory_mb"] = 4096; requirements["min_disk_space_mb"] = 2048; requirements["supported_formats"] = {"safetensors", "ckpt", "gguf"}; break; case ModelType::LORA: requirements["min_memory_mb"] = 512; requirements["recommended_memory_mb"] = 1024; requirements["min_disk_space_mb"] = 100; requirements["supported_formats"] = {"safetensors", "ckpt"}; 
requirements["required_dependencies"] = {"checkpoint"}; break; case ModelType::CONTROLNET: requirements["min_memory_mb"] = 1024; requirements["recommended_memory_mb"] = 2048; requirements["min_disk_space_mb"] = 500; requirements["supported_formats"] = {"safetensors", "pth"}; requirements["required_dependencies"] = {"checkpoint"}; break; case ModelType::VAE: requirements["min_memory_mb"] = 512; requirements["recommended_memory_mb"] = 1024; requirements["min_disk_space_mb"] = 200; requirements["supported_formats"] = {"safetensors", "pt", "ckpt", "gguf"}; requirements["required_dependencies"] = {"checkpoint"}; break; case ModelType::EMBEDDING: requirements["min_memory_mb"] = 64; requirements["recommended_memory_mb"] = 256; requirements["min_disk_space_mb"] = 10; requirements["supported_formats"] = {"safetensors", "pt"}; requirements["required_dependencies"] = {"checkpoint"}; break; case ModelType::TAESD: requirements["min_memory_mb"] = 256; requirements["recommended_memory_mb"] = 512; requirements["min_disk_space_mb"] = 100; requirements["supported_formats"] = {"safetensors", "pth", "gguf"}; requirements["required_dependencies"] = {"checkpoint"}; break; case ModelType::ESRGAN: requirements["min_memory_mb"] = 1024; requirements["recommended_memory_mb"] = 2048; requirements["min_disk_space_mb"] = 500; requirements["supported_formats"] = {"pth", "pt"}; requirements["optional_dependencies"] = {"checkpoint"}; break; default: break; } return requirements; } nlohmann::json Server::getRecommendedUsage(ModelType type) { nlohmann::json usage = { {"text2img", false}, {"img2img", false}, {"inpainting", false}, {"controlnet", false}, {"lora", false}, {"vae", false}, {"recommended_resolution", "512x512"}, {"recommended_steps", 20}, {"recommended_cfg_scale", 7.5}, {"recommended_batch_size", 1} }; switch (type) { case ModelType::CHECKPOINT: usage = { {"text2img", true}, {"img2img", true}, {"inpainting", true}, {"controlnet", true}, {"lora", true}, {"vae", true}, 
{"recommended_resolution", "512x512"}, {"recommended_steps", 20}, {"recommended_cfg_scale", 7.5}, {"recommended_batch_size", 1} }; break; case ModelType::LORA: usage = { {"text2img", true}, {"img2img", true}, {"inpainting", true}, {"controlnet", false}, {"lora", true}, {"vae", false}, {"recommended_strength", 1.0}, {"recommended_usage", "Style transfer, character customization"} }; break; case ModelType::CONTROLNET: usage = { {"text2img", false}, {"img2img", true}, {"inpainting", true}, {"controlnet", true}, {"lora", false}, {"vae", false}, {"recommended_strength", 0.9}, {"recommended_usage", "Precise control over output"} }; break; case ModelType::VAE: usage = { {"text2img", false}, {"img2img", false}, {"inpainting", false}, {"controlnet", false}, {"lora", false}, {"vae", true}, {"recommended_usage", "Improved encoding/decoding quality"} }; break; case ModelType::EMBEDDING: usage = { {"text2img", true}, {"img2img", true}, {"inpainting", true}, {"controlnet", false}, {"lora", false}, {"vae", false}, {"embedding", true}, {"recommended_usage", "Concept control, style words"} }; break; case ModelType::TAESD: usage = { {"text2img", false}, {"img2img", false}, {"inpainting", false}, {"controlnet", false}, {"lora", false}, {"vae", true}, {"recommended_usage", "Real-time decoding"} }; break; case ModelType::ESRGAN: usage = { {"text2img", false}, {"img2img", false}, {"inpainting", false}, {"controlnet", false}, {"lora", false}, {"vae", false}, {"upscaling", true}, {"recommended_usage", "Image upscaling and quality enhancement"} }; break; default: break; } return usage; } std::string Server::getModelTypeFromDirectoryName(const std::string& dirName) { if (dirName == "stable-diffusion" || dirName == "checkpoints") { return "checkpoint"; } else if (dirName == "lora") { return "lora"; } else if (dirName == "controlnet") { return "controlnet"; } else if (dirName == "vae") { return "vae"; } else if (dirName == "taesd") { return "taesd"; } else if (dirName == "esrgan" || dirName 
== "upscaler") { return "esrgan"; } else if (dirName == "embeddings" || dirName == "textual-inversion") { return "embedding"; } else { return "unknown"; } } std::string Server::getDirectoryDescription(const std::string& dirName) { if (dirName == "stable-diffusion" || dirName == "checkpoints") { return "Main stable diffusion model files"; } else if (dirName == "lora") { return "LoRA adapter models for style transfer"; } else if (dirName == "controlnet") { return "ControlNet models for precise control"; } else if (dirName == "vae") { return "VAE models for improved encoding/decoding"; } else if (dirName == "taesd") { return "TAESD models for real-time decoding"; } else if (dirName == "esrgan" || dirName == "upscaler") { return "ESRGAN models for image upscaling"; } else if (dirName == "embeddings" || dirName == "textual-inversion") { return "Text embeddings for concept control"; } else { return "Unknown model directory"; } } nlohmann::json Server::getDirectoryContents(const std::string& dirPath) { nlohmann::json contents = nlohmann::json::array(); try { if (std::filesystem::exists(dirPath) && std::filesystem::is_directory(dirPath)) { for (const auto& entry : std::filesystem::directory_iterator(dirPath)) { if (entry.is_regular_file()) { nlohmann::json file = { {"name", entry.path().filename().string()}, {"path", entry.path().string()}, {"size", std::filesystem::file_size(entry.path())}, {"size_mb", std::filesystem::file_size(entry.path()) / (1024.0 * 1024.0)}, {"last_modified", std::chrono::duration_cast( std::filesystem::last_write_time(entry.path()).time_since_epoch()).count()} }; contents.push_back(file); } } } } catch (const std::exception& e) { // Return empty array if directory access fails } return contents; } nlohmann::json Server::getLargestModel(const std::map& allModels) { nlohmann::json largest = nlohmann::json::object(); size_t maxSize = 0; std::string largestName; for (const auto& pair : allModels) { if (pair.second.fileSize > maxSize) { maxSize = 
pair.second.fileSize; largestName = pair.second.name; } } if (!largestName.empty()) { largest = { {"name", largestName}, {"size", maxSize}, {"size_mb", maxSize / (1024.0 * 1024.0)}, {"type", ModelManager::modelTypeToString(allModels.at(largestName).type)} }; } return largest; } nlohmann::json Server::getSmallestModel(const std::map& allModels) { nlohmann::json smallest = nlohmann::json::object(); size_t minSize = SIZE_MAX; std::string smallestName; for (const auto& pair : allModels) { if (pair.second.fileSize < minSize) { minSize = pair.second.fileSize; smallestName = pair.second.name; } } if (!smallestName.empty()) { smallest = { {"name", smallestName}, {"size", minSize}, {"size_mb", minSize / (1024.0 * 1024.0)}, {"type", ModelManager::modelTypeToString(allModels.at(smallestName).type)} }; } return smallest; } nlohmann::json Server::validateModelFile(const std::string& modelPath, const std::string& modelType) { nlohmann::json validation = { {"is_valid", false}, {"errors", nlohmann::json::array()}, {"warnings", nlohmann::json::array()}, {"file_info", nlohmann::json::object()}, {"compatibility", nlohmann::json::object()}, {"recommendations", nlohmann::json::array()} }; try { if (!std::filesystem::exists(modelPath)) { validation["errors"].push_back("File does not exist"); return validation; } if (!std::filesystem::is_regular_file(modelPath)) { validation["errors"].push_back("Path is not a regular file"); return validation; } // Check file extension std::string extension = std::filesystem::path(modelPath).extension().string(); if (extension.empty()) { validation["errors"].push_back("Missing file extension"); return validation; } // Remove dot and convert to lowercase if (extension[0] == '.') { extension = extension.substr(1); } std::transform(extension.begin(), extension.end(), extension.begin(), ::tolower); // Validate extension based on model type ModelType type = ModelManager::stringToModelType(modelType); bool validExtension = false; switch (type) { case 
ModelType::CHECKPOINT: validExtension = (extension == "safetensors" || extension == "ckpt" || extension == "gguf"); break; case ModelType::LORA: validExtension = (extension == "safetensors" || extension == "ckpt"); break; case ModelType::CONTROLNET: validExtension = (extension == "safetensors" || extension == "pth"); break; case ModelType::VAE: validExtension = (extension == "safetensors" || extension == "pt" || extension == "ckpt" || extension == "gguf"); break; case ModelType::EMBEDDING: validExtension = (extension == "safetensors" || extension == "pt"); break; case ModelType::TAESD: validExtension = (extension == "safetensors" || extension == "pth" || extension == "gguf"); break; case ModelType::ESRGAN: validExtension = (extension == "pth" || extension == "pt"); break; default: break; } if (!validExtension) { validation["errors"].push_back("Invalid file extension for model type: " + extension); } // Check file size size_t fileSize = std::filesystem::file_size(modelPath); if (fileSize == 0) { validation["errors"].push_back("File is empty"); } else if (fileSize > 8ULL * 1024 * 1024 * 1024) { // 8GB validation["warnings"].push_back("Very large file may cause performance issues"); } // Build file info validation["file_info"] = { {"path", modelPath}, {"size", fileSize}, {"size_mb", fileSize / (1024.0 * 1024.0)}, {"extension", extension}, {"last_modified", std::chrono::duration_cast( std::filesystem::last_write_time(modelPath).time_since_epoch()).count()} }; // Check compatibility validation["compatibility"] = { {"extension_valid", validExtension}, {"size_appropriate", fileSize <= 4ULL * 1024 * 1024 * 1024}, // 4GB {"recommended_format", "safetensors"} }; // Add recommendations if (!validExtension) { validation["recommendations"].push_back("Convert to SafeTensors format for better security and performance"); } if (fileSize > 2ULL * 1024 * 1024 * 1024) { // 2GB validation["recommendations"].push_back("Consider using a smaller model for better performance"); } // If no 
errors found, mark as valid if (validation["errors"].empty()) { validation["is_valid"] = true; } } catch (const std::exception& e) { validation["errors"].push_back("Validation failed: " + std::string(e.what())); } return validation; } nlohmann::json Server::checkModelCompatibility(const ModelManager::ModelInfo& modelInfo, const std::string& systemInfo) { nlohmann::json compatibility = { {"is_compatible", true}, {"compatibility_score", 100}, {"issues", nlohmann::json::array()}, {"warnings", nlohmann::json::array()}, {"requirements", nlohmann::json::object()}, {"recommendations", nlohmann::json::array()}, {"system_info", nlohmann::json::object()} }; // Check system compatibility if (systemInfo == "auto") { compatibility["system_info"] = { {"cpu_cores", std::thread::hardware_concurrency()} }; } // Check model-specific compatibility issues if (modelInfo.type == ModelType::CHECKPOINT) { if (modelInfo.fileSize > 4ULL * 1024 * 1024 * 1024) { // 4GB compatibility["warnings"].push_back("Large checkpoint model may require significant memory"); compatibility["compatibility_score"] = 80; } if (modelInfo.fileSize < 500 * 1024 * 1024) { // 500MB compatibility["warnings"].push_back("Small checkpoint model may have limited capabilities"); compatibility["compatibility_score"] = 85; } } else if (modelInfo.type == ModelType::LORA) { if (modelInfo.fileSize > 500 * 1024 * 1024) { // 500MB compatibility["warnings"].push_back("Large LoRA may impact performance"); compatibility["compatibility_score"] = 75; } } return compatibility; } nlohmann::json Server::calculateSpecificRequirements(const std::string& modelType, const std::string& resolution, const std::string& batchSize) { (void)modelType; // Suppress unused parameter warning nlohmann::json specific = { {"memory_requirements", nlohmann::json::object()}, {"performance_impact", nlohmann::json::object()}, {"quality_expectations", nlohmann::json::object()} }; // Parse resolution int width = 512, height = 512; try { size_t xPos = 
resolution.find('x'); if (xPos != std::string::npos) { width = std::stoi(resolution.substr(0, xPos)); height = std::stoi(resolution.substr(xPos + 1)); } } catch (...) { // Use defaults if parsing fails } // Parse batch size int batch = 1; try { batch = std::stoi(batchSize); } catch (...) { // Use default if parsing fails } // Calculate memory requirements based on resolution and batch size_t pixels = width * height; size_t baseMemory = 1024 * 1024 * 1024; // 1GB base size_t resolutionMemory = (pixels * 4) / (512 * 512); // Scale based on 512x512 size_t batchMemory = (batch - 1) * baseMemory * 0.5; // Additional memory for batch specific["memory_requirements"] = { {"base_memory_mb", baseMemory / (1024 * 1024)}, {"resolution_memory_mb", resolutionMemory / (1024 * 1024)}, {"batch_memory_mb", batchMemory / (1024 * 1024)}, {"total_memory_mb", (baseMemory + resolutionMemory + batchMemory) / (1024 * 1024)} }; // Calculate performance impact double performanceFactor = 1.0; if (pixels > 512 * 512) { performanceFactor = 1.5; } if (batch > 1) { performanceFactor *= 1.2; } specific["performance_impact"] = { {"resolution_factor", pixels > 512 * 512 ? 1.5 : 1.0}, {"batch_factor", batch > 1 ? 
1.2 : 1.0}, {"overall_factor", performanceFactor} }; return specific; } // Enhanced model management endpoint implementations void Server::handleModelInfo(const httplib::Request& req, httplib::Response& res) { std::string requestId = generateRequestId(); try { if (!m_modelManager) { sendErrorResponse(res, "Model manager not available", 500, "MODEL_MANAGER_UNAVAILABLE", requestId); return; } // Extract model ID from URL path std::string modelId = req.matches[1].str(); if (modelId.empty()) { sendErrorResponse(res, "Missing model ID", 400, "MISSING_MODEL_ID", requestId); return; } // Get model information auto modelInfo = m_modelManager->getModelInfo(modelId); if (modelInfo.name.empty()) { sendErrorResponse(res, "Model not found", 404, "MODEL_NOT_FOUND", requestId); return; } // Build comprehensive model information nlohmann::json response = { {"model", { {"name", modelInfo.name}, {"path", modelInfo.path}, {"type", ModelManager::modelTypeToString(modelInfo.type)}, {"is_loaded", modelInfo.isLoaded}, {"file_size", modelInfo.fileSize}, {"file_size_mb", modelInfo.fileSize / (1024.0 * 1024.0)}, {"description", modelInfo.description}, {"metadata", modelInfo.metadata}, {"capabilities", getModelCapabilities(modelInfo.type)}, {"compatibility", getModelCompatibility(modelInfo)}, {"requirements", getModelRequirements(modelInfo.type)}, {"recommended_usage", getRecommendedUsage(modelInfo.type)}, {"last_modified", std::chrono::duration_cast( modelInfo.modifiedAt.time_since_epoch()).count()} }}, {"request_id", requestId} }; sendJsonResponse(res, response); } catch (const std::exception& e) { sendErrorResponse(res, std::string("Failed to get model info: ") + e.what(), 500, "MODEL_INFO_ERROR", requestId); } } void Server::handleLoadModelById(const httplib::Request& req, httplib::Response& res) { std::string requestId = generateRequestId(); try { if (!m_modelManager) { sendErrorResponse(res, "Model manager not available", 500, "MODEL_MANAGER_UNAVAILABLE", requestId); return; } // 
Extract model ID from URL path (could be hash or name) std::string modelIdentifier = req.matches[1].str(); if (modelIdentifier.empty()) { sendErrorResponse(res, "Missing model identifier", 400, "MISSING_MODEL_ID", requestId); return; } // Try to find by hash first (if it looks like a hash - 10+ hex chars) std::string modelId = modelIdentifier; if (modelIdentifier.length() >= 10 && std::all_of(modelIdentifier.begin(), modelIdentifier.end(), [](char c) { return std::isxdigit(c); })) { std::string foundName = m_modelManager->findModelByHash(modelIdentifier); if (!foundName.empty()) { modelId = foundName; std::cout << "Resolved hash " << modelIdentifier << " to model: " << modelId << std::endl; } } // Parse optional parameters from request body nlohmann::json requestJson; if (!req.body.empty()) { try { requestJson = nlohmann::json::parse(req.body); } catch (const nlohmann::json::parse_error& e) { sendErrorResponse(res, std::string("Invalid JSON: ") + e.what(), 400, "JSON_PARSE_ERROR", requestId); return; } } // Unload previous model if one is loaded std::string previousModel; { std::lock_guard lock(m_currentModelMutex); previousModel = m_currentlyLoadedModel; } if (!previousModel.empty() && previousModel != modelId) { std::cout << "Unloading previous model: " << previousModel << std::endl; m_modelManager->unloadModel(previousModel); } // Load model bool success = m_modelManager->loadModel(modelId); if (success) { // Update currently loaded model { std::lock_guard lock(m_currentModelMutex); m_currentlyLoadedModel = modelId; } auto modelInfo = m_modelManager->getModelInfo(modelId); nlohmann::json response = { {"status", "success"}, {"model", { {"name", modelInfo.name}, {"path", modelInfo.path}, {"type", ModelManager::modelTypeToString(modelInfo.type)}, {"is_loaded", modelInfo.isLoaded} }}, {"request_id", requestId} }; sendJsonResponse(res, response); } else { sendErrorResponse(res, "Failed to load model", 400, "MODEL_LOAD_FAILED", requestId); } } catch (const 
std::exception& e) { sendErrorResponse(res, std::string("Model load failed: ") + e.what(), 500, "MODEL_LOAD_ERROR", requestId); } } void Server::handleUnloadModelById(const httplib::Request& req, httplib::Response& res) { std::string requestId = generateRequestId(); try { if (!m_modelManager) { sendErrorResponse(res, "Model manager not available", 500, "MODEL_MANAGER_UNAVAILABLE", requestId); return; } // Extract model ID from URL path std::string modelId = req.matches[1].str(); if (modelId.empty()) { sendErrorResponse(res, "Missing model ID", 400, "MISSING_MODEL_ID", requestId); return; } // Unload model bool success = m_modelManager->unloadModel(modelId); if (success) { // Clear currently loaded model if it matches { std::lock_guard lock(m_currentModelMutex); if (m_currentlyLoadedModel == modelId) { m_currentlyLoadedModel = ""; } } nlohmann::json response = { {"status", "success"}, {"model", { {"name", modelId}, {"is_loaded", false} }}, {"request_id", requestId} }; sendJsonResponse(res, response); } else { sendErrorResponse(res, "Failed to unload model or model not found", 404, "MODEL_UNLOAD_FAILED", requestId); } } catch (const std::exception& e) { sendErrorResponse(res, std::string("Model unload failed: ") + e.what(), 500, "MODEL_UNLOAD_ERROR", requestId); } } void Server::handleModelTypes(const httplib::Request& /*req*/, httplib::Response& res) { std::string requestId = generateRequestId(); try { nlohmann::json types = { {"model_types", { { {"type", "checkpoint"}, {"description", "Main stable diffusion model files for text-to-image, image-to-image, and inpainting"}, {"extensions", {"safetensors", "ckpt", "gguf"}}, {"capabilities", {"text2img", "img2img", "inpainting", "controlnet", "lora", "vae"}}, {"recommended_for", "General purpose image generation"} }, { {"type", "lora"}, {"description", "LoRA adapter models for style transfer and character customization"}, {"extensions", {"safetensors", "ckpt"}}, {"capabilities", {"style_transfer", 
"character_customization"}}, {"requires", {"checkpoint"}}, {"recommended_for", "Style modification and character-specific generation"} }, { {"type", "controlnet"}, {"description", "ControlNet models for precise control over output composition"}, {"extensions", {"safetensors", "pth"}}, {"capabilities", {"precise_control", "composition_control"}}, {"requires", {"checkpoint"}}, {"recommended_for", "Precise control over image generation"} }, { {"type", "vae"}, {"description", "VAE models for improved encoding and decoding quality"}, {"extensions", {"safetensors", "pt", "ckpt", "gguf"}}, {"capabilities", {"encoding", "decoding", "quality_improvement"}}, {"requires", {"checkpoint"}}, {"recommended_for", "Improved image quality and encoding"} }, { {"type", "embedding"}, {"description", "Text embeddings for concept control and style words"}, {"extensions", {"safetensors", "pt"}}, {"capabilities", {"concept_control", "style_words"}}, {"requires", {"checkpoint"}}, {"recommended_for", "Concept control and specific styles"} }, { {"type", "taesd"}, {"description", "TAESD models for real-time decoding"}, {"extensions", {"safetensors", "pth", "gguf"}}, {"capabilities", {"real_time_decoding", "fast_preview"}}, {"requires", {"checkpoint"}}, {"recommended_for", "Real-time applications and fast previews"} }, { {"type", "esrgan"}, {"description", "ESRGAN models for image upscaling and enhancement"}, {"extensions", {"pth", "pt"}}, {"capabilities", {"upscaling", "enhancement", "quality_improvement"}}, {"recommended_for", "Image upscaling and quality enhancement"} } }}, {"request_id", requestId} }; sendJsonResponse(res, types); } catch (const std::exception& e) { sendErrorResponse(res, std::string("Failed to get model types: ") + e.what(), 500, "MODEL_TYPES_ERROR", requestId); } } void Server::handleModelDirectories(const httplib::Request& /*req*/, httplib::Response& res) { std::string requestId = generateRequestId(); try { if (!m_modelManager) { sendErrorResponse(res, "Model manager not 
available", 500, "MODEL_MANAGER_UNAVAILABLE", requestId); return; } std::string modelsDir = m_modelManager->getModelsDirectory(); nlohmann::json directories = nlohmann::json::array(); // Define expected model directories std::vector modelDirs = { "stable-diffusion", "checkpoints", "lora", "controlnet", "vae", "taesd", "esrgan", "embeddings" }; for (const auto& dirName : modelDirs) { std::string dirPath = modelsDir + "/" + dirName; std::string type = getModelTypeFromDirectoryName(dirName); std::string description = getDirectoryDescription(dirName); nlohmann::json dirInfo = { {"name", dirName}, {"path", dirPath}, {"type", type}, {"description", description}, {"exists", std::filesystem::exists(dirPath) && std::filesystem::is_directory(dirPath)}, {"contents", getDirectoryContents(dirPath)} }; directories.push_back(dirInfo); } nlohmann::json response = { {"models_directory", modelsDir}, {"directories", directories}, {"request_id", requestId} }; sendJsonResponse(res, response); } catch (const std::exception& e) { sendErrorResponse(res, std::string("Failed to get model directories: ") + e.what(), 500, "MODEL_DIRECTORIES_ERROR", requestId); } } void Server::handleRefreshModels(const httplib::Request& /*req*/, httplib::Response& res) { std::string requestId = generateRequestId(); try { if (!m_modelManager) { sendErrorResponse(res, "Model manager not available", 500, "MODEL_MANAGER_UNAVAILABLE", requestId); return; } // Force refresh of model cache bool success = m_modelManager->scanModelsDirectory(); if (success) { nlohmann::json response = { {"status", "success"}, {"message", "Model cache refreshed successfully"}, {"models_found", m_modelManager->getAvailableModelsCount()}, {"models_loaded", m_modelManager->getLoadedModelsCount()}, {"models_directory", m_modelManager->getModelsDirectory()}, {"request_id", requestId} }; sendJsonResponse(res, response); } else { sendErrorResponse(res, "Failed to refresh model cache", 500, "MODEL_REFRESH_FAILED", requestId); } } catch (const 
std::exception& e) { sendErrorResponse(res, std::string("Model refresh failed: ") + e.what(), 500, "MODEL_REFRESH_ERROR", requestId); } } void Server::handleHashModels(const httplib::Request& req, httplib::Response& res) { std::string requestId = generateRequestId(); try { if (!m_generationQueue || !m_modelManager) { sendErrorResponse(res, "Services not available", 500, "SERVICE_UNAVAILABLE", requestId); return; } // Parse request body nlohmann::json requestJson; if (!req.body.empty()) { requestJson = nlohmann::json::parse(req.body); } HashRequest hashReq; hashReq.id = requestId; hashReq.forceRehash = requestJson.value("force_rehash", false); if (requestJson.contains("models") && requestJson["models"].is_array()) { for (const auto& model : requestJson["models"]) { hashReq.modelNames.push_back(model.get()); } } // Enqueue hash request auto future = m_generationQueue->enqueueHashRequest(hashReq); nlohmann::json response = { {"request_id", requestId}, {"status", "queued"}, {"message", "Hash job queued successfully"}, {"models_to_hash", hashReq.modelNames.empty() ? 
"all_unhashed" : std::to_string(hashReq.modelNames.size())} }; sendJsonResponse(res, response, 202); } catch (const nlohmann::json::parse_error& e) { sendErrorResponse(res, std::string("Invalid JSON: ") + e.what(), 400, "JSON_PARSE_ERROR", requestId); } catch (const std::exception& e) { sendErrorResponse(res, std::string("Hash request failed: ") + e.what(), 500, "INTERNAL_ERROR", requestId); } } void Server::handleConvertModel(const httplib::Request& req, httplib::Response& res) { std::string requestId = generateRequestId(); try { if (!m_generationQueue || !m_modelManager) { sendErrorResponse(res, "Services not available", 500, "SERVICE_UNAVAILABLE", requestId); return; } // Parse request body nlohmann::json requestJson; try { requestJson = nlohmann::json::parse(req.body); } catch (const nlohmann::json::parse_error& e) { sendErrorResponse(res, std::string("Invalid JSON: ") + e.what(), 400, "JSON_PARSE_ERROR", requestId); return; } // Validate required fields if (!requestJson.contains("model_name")) { sendErrorResponse(res, "Missing required field: model_name", 400, "MISSING_FIELD", requestId); return; } if (!requestJson.contains("quantization_type")) { sendErrorResponse(res, "Missing required field: quantization_type", 400, "MISSING_FIELD", requestId); return; } std::string modelName = requestJson["model_name"].get(); std::string quantizationType = requestJson["quantization_type"].get(); // Validate quantization type const std::vector validTypes = {"f32", "f16", "q4_0", "q4_1", "q5_0", "q5_1", "q8_0", "q2_K", "q3_K", "q4_K"}; if (std::find(validTypes.begin(), validTypes.end(), quantizationType) == validTypes.end()) { sendErrorResponse(res, "Invalid quantization_type. 
Valid types: f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0, q2_K, q3_K, q4_K", 400, "INVALID_QUANTIZATION_TYPE", requestId); return; } // Get model info to find the full path auto modelInfo = m_modelManager->getModelInfo(modelName); if (modelInfo.name.empty()) { sendErrorResponse(res, "Model not found: " + modelName, 404, "MODEL_NOT_FOUND", requestId); return; } // Check if model is already GGUF if (modelInfo.fullPath.find(".gguf") != std::string::npos) { sendErrorResponse(res, "Model is already in GGUF format. Cannot convert GGUF to GGUF.", 400, "ALREADY_GGUF", requestId); return; } // Build output path std::string outputPath = requestJson.value("output_path", ""); if (outputPath.empty()) { // Generate default output path: model_name_quantization.gguf namespace fs = std::filesystem; fs::path inputPath(modelInfo.fullPath); std::string baseName = inputPath.stem().string(); std::string outputDir = inputPath.parent_path().string(); outputPath = outputDir + "/" + baseName + "_" + quantizationType + ".gguf"; } // Create conversion request ConversionRequest convReq; convReq.id = requestId; convReq.modelName = modelName; convReq.modelPath = modelInfo.fullPath; convReq.outputPath = outputPath; convReq.quantizationType = quantizationType; // Enqueue conversion request auto future = m_generationQueue->enqueueConversionRequest(convReq); nlohmann::json response = { {"request_id", requestId}, {"status", "queued"}, {"message", "Model conversion queued successfully"}, {"model_name", modelName}, {"input_path", modelInfo.fullPath}, {"output_path", outputPath}, {"quantization_type", quantizationType} }; sendJsonResponse(res, response, 202); } catch (const std::exception& e) { sendErrorResponse(res, std::string("Conversion request failed: ") + e.what(), 500, "INTERNAL_ERROR", requestId); } } void Server::handleModelStats(const httplib::Request& /*req*/, httplib::Response& res) { std::string requestId = generateRequestId(); try { if (!m_modelManager) { sendErrorResponse(res, "Model manager 
not available", 500, "MODEL_MANAGER_UNAVAILABLE", requestId); return; } auto allModels = m_modelManager->getAllModels(); nlohmann::json response = { {"statistics", { {"total_models", allModels.size()}, {"loaded_models", m_modelManager->getLoadedModelsCount()}, {"available_models", m_modelManager->getAvailableModelsCount()}, {"model_types", getModelTypeStatistics()}, {"largest_model", getLargestModel(allModels)}, {"smallest_model", getSmallestModel(allModels)} }}, {"request_id", requestId} }; sendJsonResponse(res, response); } catch (const std::exception& e) { sendErrorResponse(res, std::string("Failed to get model stats: ") + e.what(), 500, "MODEL_STATS_ERROR", requestId); } } void Server::handleBatchModels(const httplib::Request& req, httplib::Response& res) { std::string requestId = generateRequestId(); try { if (!m_modelManager) { sendErrorResponse(res, "Model manager not available", 500, "MODEL_MANAGER_UNAVAILABLE", requestId); return; } // Parse JSON request body nlohmann::json requestJson = nlohmann::json::parse(req.body); if (!requestJson.contains("operation") || !requestJson["operation"].is_string()) { sendErrorResponse(res, "Missing or invalid 'operation' field", 400, "INVALID_OPERATION", requestId); return; } if (!requestJson.contains("models") || !requestJson["models"].is_array()) { sendErrorResponse(res, "Missing or invalid 'models' field", 400, "INVALID_MODELS", requestId); return; } std::string operation = requestJson["operation"]; nlohmann::json models = requestJson["models"]; nlohmann::json results = nlohmann::json::array(); for (const auto& model : models) { if (!model.is_string()) { results.push_back({ {"model", model}, {"success", false}, {"error", "Invalid model name"} }); continue; } std::string modelName = model; bool success = false; std::string error = ""; if (operation == "load") { success = m_modelManager->loadModel(modelName); if (!success) error = "Failed to load model"; } else if (operation == "unload") { success = 
m_modelManager->unloadModel(modelName); if (!success) error = "Failed to unload model"; } else { error = "Unsupported operation"; } results.push_back({ {"model", modelName}, {"success", success}, {"error", error.empty() ? nlohmann::json(nullptr) : nlohmann::json(error)} }); } nlohmann::json response = { {"operation", operation}, {"results", results}, {"successful_count", std::count_if(results.begin(), results.end(), [](const nlohmann::json& result) { return result["success"].get(); })}, {"failed_count", std::count_if(results.begin(), results.end(), [](const nlohmann::json& result) { return !result["success"].get(); })}, {"request_id", requestId} }; sendJsonResponse(res, response); } catch (const nlohmann::json::parse_error& e) { sendErrorResponse(res, std::string("Invalid JSON: ") + e.what(), 400, "JSON_PARSE_ERROR", requestId); } catch (const std::exception& e) { sendErrorResponse(res, std::string("Batch operation failed: ") + e.what(), 500, "BATCH_OPERATION_ERROR", requestId); } } void Server::handleValidateModel(const httplib::Request& req, httplib::Response& res) { std::string requestId = generateRequestId(); try { // Parse JSON request body nlohmann::json requestJson = nlohmann::json::parse(req.body); if (!requestJson.contains("model_path") || !requestJson["model_path"].is_string()) { sendErrorResponse(res, "Missing or invalid 'model_path' field", 400, "INVALID_MODEL_PATH", requestId); return; } std::string modelPath = requestJson["model_path"]; std::string modelType = requestJson.value("model_type", "checkpoint"); // Validate model file nlohmann::json validation = validateModelFile(modelPath, modelType); nlohmann::json response = { {"validation", validation}, {"request_id", requestId} }; sendJsonResponse(res, response); } catch (const nlohmann::json::parse_error& e) { sendErrorResponse(res, std::string("Invalid JSON: ") + e.what(), 400, "JSON_PARSE_ERROR", requestId); } catch (const std::exception& e) { sendErrorResponse(res, std::string("Model validation 
failed: ") + e.what(), 500, "MODEL_VALIDATION_ERROR", requestId); } } void Server::handleCheckCompatibility(const httplib::Request& req, httplib::Response& res) { std::string requestId = generateRequestId(); try { if (!m_modelManager) { sendErrorResponse(res, "Model manager not available", 500, "MODEL_MANAGER_UNAVAILABLE", requestId); return; } // Parse JSON request body nlohmann::json requestJson = nlohmann::json::parse(req.body); if (!requestJson.contains("model_name") || !requestJson["model_name"].is_string()) { sendErrorResponse(res, "Missing or invalid 'model_name' field", 400, "INVALID_MODEL_NAME", requestId); return; } std::string modelName = requestJson["model_name"]; std::string systemInfo = requestJson.value("system_info", "auto"); // Get model information auto modelInfo = m_modelManager->getModelInfo(modelName); if (modelInfo.name.empty()) { sendErrorResponse(res, "Model not found", 404, "MODEL_NOT_FOUND", requestId); return; } // Check compatibility nlohmann::json compatibility = checkModelCompatibility(modelInfo, systemInfo); nlohmann::json response = { {"model", modelName}, {"compatibility", compatibility}, {"request_id", requestId} }; sendJsonResponse(res, response); } catch (const nlohmann::json::parse_error& e) { sendErrorResponse(res, std::string("Invalid JSON: ") + e.what(), 400, "JSON_PARSE_ERROR", requestId); } catch (const std::exception& e) { sendErrorResponse(res, std::string("Compatibility check failed: ") + e.what(), 500, "COMPATIBILITY_CHECK_ERROR", requestId); } } void Server::handleModelRequirements(const httplib::Request& req, httplib::Response& res) { std::string requestId = generateRequestId(); try { // Parse JSON request body nlohmann::json requestJson = nlohmann::json::parse(req.body); std::string modelType = requestJson.value("model_type", "checkpoint"); std::string resolution = requestJson.value("resolution", "512x512"); std::string batchSize = requestJson.value("batch_size", "1"); // Calculate specific requirements 
nlohmann::json requirements = calculateSpecificRequirements(modelType, resolution, batchSize); // Get general requirements for model type ModelType type = ModelManager::stringToModelType(modelType); nlohmann::json generalRequirements = getModelRequirements(type); nlohmann::json response = { {"model_type", modelType}, {"configuration", { {"resolution", resolution}, {"batch_size", batchSize} }}, {"specific_requirements", requirements}, {"general_requirements", generalRequirements}, {"request_id", requestId} }; sendJsonResponse(res, response); } catch (const nlohmann::json::parse_error& e) { sendErrorResponse(res, std::string("Invalid JSON: ") + e.what(), 400, "JSON_PARSE_ERROR", requestId); } catch (const std::exception& e) { sendErrorResponse(res, std::string("Requirements calculation failed: ") + e.what(), 500, "REQUIREMENTS_ERROR", requestId); } } void Server::serverThreadFunction(const std::string& host, int port) { try { std::cout << "Server thread starting, attempting to bind to " << host << ":" << port << std::endl; // Check if port is available before attempting to bind std::cout << "Checking if port " << port << " is available..." 
<< std::endl; // Try to create a test socket to check if port is in use int test_socket = socket(AF_INET, SOCK_STREAM, 0); if (test_socket >= 0) { // Set SO_REUSEADDR to avoid TIME_WAIT issues int opt = 1; if (setsockopt(test_socket, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt)) < 0) { std::cerr << "Warning: Failed to set SO_REUSEADDR on test socket: " << strerror(errno) << std::endl; } // Also set SO_REUSEPORT if available (for better concurrent binding handling) #ifdef SO_REUSEPORT int reuseport = 1; if (setsockopt(test_socket, SOL_SOCKET, SO_REUSEPORT, &reuseport, sizeof(reuseport)) < 0) { std::cerr << "Warning: Failed to set SO_REUSEPORT on test socket: " << strerror(errno) << std::endl; } #endif struct sockaddr_in addr; addr.sin_family = AF_INET; addr.sin_port = htons(port); addr.sin_addr.s_addr = INADDR_ANY; // Try to bind to the port if (bind(test_socket, (struct sockaddr*)&addr, sizeof(addr)) < 0) { close(test_socket); std::cerr << "ERROR: Port " << port << " is already in use! Cannot start server." << std::endl; std::cerr << "This could be due to:" << std::endl; std::cerr << "1. Another instance is already running on this port" << std::endl; std::cerr << "2. A previous instance crashed and the socket is in TIME_WAIT state" << std::endl; std::cerr << "3. The port is being used by another application" << std::endl; std::cerr << std::endl; std::cerr << "Solutions:" << std::endl; std::cerr << "- Wait 30-60 seconds for TIME_WAIT to expire (if server crashed)" << std::endl; std::cerr << "- Kill any existing processes: sudo lsof -ti:" << port << " | xargs kill -9" << std::endl; std::cerr << "- Use a different port with -p " << std::endl; m_isRunning.store(false); m_startupFailed.store(true); return; } close(test_socket); } std::cout << "Port " << port << " is available, proceeding with server startup..." << std::endl; std::cout << "Calling listen()..." 
<< std::endl; // We need to set m_isRunning after successful bind but before blocking // cpp-httplib doesn't provide a callback, so we set it optimistically // and clear it if listen() returns false m_isRunning.store(true); bool listenResult = m_httpServer->listen(host.c_str(), port); std::cout << "listen() returned: " << (listenResult ? "true" : "false") << std::endl; // If we reach here, server has stopped (either normally or due to error) m_isRunning.store(false); if (!listenResult) { std::cerr << "Server listen failed! This usually means port is in use or permission denied." << std::endl; } } catch (const std::exception& e) { std::cerr << "Exception in server thread: " << e.what() << std::endl; m_isRunning.store(false); } }