#include "auth_middleware.h"

// NOTE(review): the angle-bracket header names and every template argument
// list were missing from the copy of this file that was reviewed. They have
// been reconstructed from how the names are used below - confirm against
// version control before merging.
#include <httplib.h>

#include <algorithm>
#include <chrono>
#include <ctime>
#include <exception>
#include <iomanip>
#include <memory>
#include <regex>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

#include <nlohmann/json.hpp>

#include "logger.h"

namespace {

/// Media type used for JSON request detection and JSON error responses.
constexpr const char* kJsonMediaType = "application/json";

/// Returns true when `text` begins with `prefix`.
bool hasPrefix(const std::string& text, const std::string& prefix) {
    return text.compare(0, prefix.size(), prefix) == 0;
}

/// Returns true for the UI entry point ("/ui") and everything below "/ui/".
bool isUiRoute(const std::string& path) {
    return path == "/ui" || hasPrefix(path, "/ui/");
}

/// Builds the `{"error": {message, code, timestamp}}` payload shared by the
/// authentication and authorization error responses.
nlohmann::json makeErrorPayload(const std::string& message, const std::string& errorCode) {
    // NOTE(review): the duration type of this cast was lost in the reviewed
    // copy; seconds since the Unix epoch is assumed - confirm.
    const auto timestamp = std::chrono::duration_cast<std::chrono::seconds>(
                               std::chrono::system_clock::now().time_since_epoch())
                               .count();
    return {{"error", {{"message", message}, {"code", errorCode}, {"timestamp", timestamp}}}};
}

/// Reads "username" and "password" from a JSON request body.
///
/// Nothing is touched unless the request's Content-Type mentions
/// application/json. Malformed JSON (or fields of the wrong type) is ignored
/// so the caller can fall back to other credential sources.
void readJsonCredentials(const httplib::Request& req, std::string& username, std::string& password) {
    const std::string contentType = req.get_header_value("Content-Type");
    if (contentType.find(kJsonMediaType) == std::string::npos) {
        return;
    }
    try {
        const nlohmann::json body = nlohmann::json::parse(req.body);
        username                  = body.value("username", "");
        password                  = body.value("password", "");
    } catch (const nlohmann::json::exception&) {
        // Invalid JSON: leave whatever was already extracted untouched.
    }
}

/// Copies the identity fields of a successful backend result into `context`
/// and marks it authenticated with the given method label.
template <typename AuthResult>
void adoptIdentity(AuthContext& context, const AuthResult& result, const char* method) {
    context.authenticated = true;
    context.userId        = result.userId;
    context.username      = result.username;
    context.role          = result.role;
    context.permissions   = result.permissions;
    context.authMethod    = method;
}

}  // namespace

/// Stores the configuration and the (possibly null) user manager.
/// Call initialize() before handling requests.
AuthMiddleware::AuthMiddleware(AuthConfig config, std::shared_ptr<UserManager> userManager)
    : m_config(std::move(config)), m_userManager(std::move(userManager)) {
}

AuthMiddleware::~AuthMiddleware() = default;

/// Validates the configuration, creates the JWT validator when the configured
/// method is JWT, and fills in the default path lists.
///
/// @return true on success; false when the configuration is invalid or an
///         exception was thrown during setup (the exception text is logged).
bool AuthMiddleware::initialize() {
    try {
        if (!validateConfig(m_config)) {
            return false;
        }

        // NOTE(review): the validator is only created for AuthMethod::JWT, so
        // in OPTIONAL mode authenticateJwt() reports JWT_AUTH_UNAVAILABLE
        // unless m_jwtAuth is assigned somewhere outside this file - confirm
        // whether that is intended.
        if (m_config.authMethod == AuthMethod::JWT) {
            m_jwtAuth = std::make_unique<JWTAuth>(m_config.jwtSecret,
                                                  m_config.jwtExpirationMinutes,
                                                  "stable-diffusion-rest");
        }

        initializeDefaultPaths();
        return true;
    } catch (const std::exception& e) {
        // Previously swallowed silently; surface the reason for operators.
        Logger::getInstance().warning(std::string("AuthMiddleware initialization failed: ") + e.what());
        return false;
    }
}

/// Authenticates a request according to the configured method and checks
/// that the resulting identity may access the requested path.
///
/// @param req Incoming request (path, headers and body are inspected).
/// @return A context whose `authenticated` flag tells whether the request may
///         proceed; on failure `errorMessage` / `errorCode` describe why.
///         Exceptions derived from std::exception are converted into an
///         AUTH_ERROR context rather than propagated.
AuthContext AuthMiddleware::authenticate(const httplib::Request& req, httplib::Response& /*res*/) {
    AuthContext context;
    context.authenticated = false;

    try {
        // Authentication switched off entirely: everybody is an accepted guest.
        if (isAuthenticationDisabled()) {
            context               = createGuestContext();
            context.authenticated = true;
            return context;
        }

        // Public path: guests are only let through when guest access is
        // enabled (authentication is known to be enabled at this point).
        if (!requiresAuthentication(req.path)) {
            context               = createGuestContext();
            context.authenticated = m_config.enableGuestAccess;
            return context;
        }

        switch (m_config.authMethod) {
            case AuthMethod::JWT:
                context = authenticateJwt(req);
                break;
            case AuthMethod::API_KEY:
                context = authenticateApiKey(req);
                break;
            case AuthMethod::UNIX:
                context = authenticateUnix(req);
                break;
            case AuthMethod::PAM:
                context = authenticatePam(req);
                break;
            case AuthMethod::OPTIONAL:
                // Try JWT first, then API key, then fall back to guest.
                context = authenticateJwt(req);
                if (!context.authenticated) {
                    context = authenticateApiKey(req);
                }
                if (!context.authenticated && m_config.enableGuestAccess) {
                    context               = createGuestContext();
                    context.authenticated = true;
                }
                break;
            case AuthMethod::NONE:
            default:
                context               = createGuestContext();
                context.authenticated = true;
                break;
        }

        // An authenticated identity still needs the permissions of the path.
        if (context.authenticated && !hasPathAccess(req.path, context.permissions)) {
            context.authenticated = false;
            context.errorMessage  = "Insufficient permissions for this endpoint";
            context.errorCode     = "INSUFFICIENT_PERMISSIONS";
        }

        logAuthAttempt(req, context, context.authenticated);
    } catch (const std::exception& e) {
        context.authenticated = false;
        context.errorMessage  = "Authentication error: " + std::string(e.what());
        context.errorCode     = "AUTH_ERROR";
    }

    return context;
}

/// Returns true when `path` may only be served to an authenticated caller:
/// authentication is enabled and the path is not listed in publicPaths.
bool AuthMiddleware::requiresAuthentication(const std::string& path) const {
    if (isAuthenticationDisabled()) {
        return false;
    }
    return !pathMatchesPattern(path, m_config.publicPaths);
}

/// Returns true when `path` matches one of the configured admin paths.
bool AuthMiddleware::requiresAdminAccess(const std::string& path) const {
    return pathMatchesPattern(path, m_config.adminPaths);
}

/// Returns true when `path` matches one of the configured user paths.
bool AuthMiddleware::requiresUserAccess(const std::string& path) const {
    return pathMatchesPattern(path, m_config.userPaths);
}

/// Checks whether a permission set is sufficient for `path`.
///
/// Admin paths need the ADMIN permission; user paths need USER_MANAGE or
/// ADMIN; every other path is open to any authenticated caller.
bool AuthMiddleware::hasPathAccess(const std::string& path, const std::vector<std::string>& permissions) const {
    if (requiresAdminAccess(path)) {
        return JWTAuth::hasPermission(permissions, UserManager::Permissions::ADMIN);
    }

    if (requiresUserAccess(path)) {
        return JWTAuth::hasAnyPermission(permissions,
                                         {UserManager::Permissions::USER_MANAGE, UserManager::Permissions::ADMIN});
    }

    return true;
}

/// Wraps `handler` so it only runs for requests that pass authenticate().
///
/// The returned callable captures `this`; it must not be invoked after this
/// AuthMiddleware has been destroyed. The context argument passed to the
/// wrapper is ignored and replaced by the freshly computed one.
AuthMiddleware::AuthHandler AuthMiddleware::createMiddleware(AuthHandler handler) {
    return [this, handler](const httplib::Request& req, httplib::Response& res, const AuthContext& /*context*/) {
        const AuthContext authContext = authenticate(req, res);

        if (!authContext.authenticated) {
            // NOTE(review): INSUFFICIENT_PERMISSIONS failures also take this
            // route and therefore get the authentication status code plus a
            // WWW-Authenticate challenge; sendAuthzError (403) may be the
            // better fit - left unchanged to keep the HTTP contract stable.
            sendAuthError(res, authContext.errorMessage, authContext.errorCode);
            return;
        }

        handler(req, res, authContext);
    };
}

/// Writes a JSON authentication error with a Bearer challenge for the
/// configured realm.
///
/// @param statusCode HTTP status stored on the response.
void AuthMiddleware::sendAuthError(httplib::Response& res,
                                   const std::string& message,
                                   const std::string& errorCode,
                                   int statusCode) {
    // Was "application/nlohmann::json", which is not a valid media type.
    res.set_header("Content-Type", kJsonMediaType);
    res.set_header("WWW-Authenticate", "Bearer realm=\"" + m_config.authRealm + "\"");
    res.status = statusCode;
    res.body   = makeErrorPayload(message, errorCode).dump();
}

/// Writes a JSON authorization error with HTTP status 403.
void AuthMiddleware::sendAuthzError(httplib::Response& res,
                                    const std::string& message,
                                    const std::string& errorCode) {
    // Was "application/nlohmann::json", which is not a valid media type.
    res.set_header("Content-Type", kJsonMediaType);
    res.status = 403;
    res.body   = makeErrorPayload(message, errorCode).dump();
}

/// Appends a path (or pattern) to the list served without authentication.
void AuthMiddleware::addPublicPath(const std::string& path) {
    m_config.publicPaths.push_back(path);
}

/// Appends a path (or pattern) to the list that requires admin permission.
void AuthMiddleware::addAdminPath(const std::string& path) {
    m_config.adminPaths.push_back(path);
}

/// Appends a path (or pattern) to the list that requires user permission.
void AuthMiddleware::addUserPath(const std::string& path) {
    m_config.userPaths.push_back(path);
}

/// Authenticates a request by the key in its "X-API-Key" header.
///
/// @return An authenticated context on success; otherwise a context carrying
///         USER_MANAGER_UNAVAILABLE, MISSING_API_KEY or the user manager's
///         own error.
AuthContext AuthMiddleware::authenticateApiKey(const httplib::Request& req) {
    AuthContext context;
    context.authenticated = false;

    if (!m_userManager) {
        context.errorMessage = "User manager not available";
        context.errorCode    = "USER_MANAGER_UNAVAILABLE";
        return context;
    }

    const std::string apiKey = extractToken(req, "X-API-Key");
    if (apiKey.empty()) {
        context.errorMessage = "Missing API key";
        context.errorCode    = "MISSING_API_KEY";
        return context;
    }

    const auto result = m_userManager->authenticateApiKey(apiKey);
    if (!result.success) {
        context.errorMessage = result.errorMessage;
        context.errorCode    = result.errorCode;
        return context;
    }

    adoptIdentity(context, result, "API_KEY");
    return context;
}

/// Authenticates a request against the Unix user backend.
///
/// The username is taken, in order, from the JSON body ("username" /
/// "password"), the "X-Unix-User" header, or the tail of a
/// "Bearer unix_token_..." Authorization header. UI routes without a
/// username get an unauthenticated guest context (no error code) so the UI
/// can show its login page; other routes get MISSING_UNIX_USER.
AuthContext AuthMiddleware::authenticateUnix(const httplib::Request& req) {
    AuthContext context;
    context.authenticated = false;

    if (!m_userManager) {
        context.errorMessage = "User manager not available";
        context.errorCode    = "USER_MANAGER_UNAVAILABLE";
        return context;
    }

    if (m_config.authMethod != AuthMethod::UNIX) {
        context.errorMessage = "Unix authentication not available";
        context.errorCode    = "UNIX_AUTH_UNAVAILABLE";
        return context;
    }

    std::string username;
    std::string password;

    // Login API: credentials in a JSON body.
    readJsonCredentials(req, username, password);

    // No username in the body: try the dedicated header.
    if (username.empty()) {
        username = req.get_header_value("X-Unix-User");
    }

    // Still nothing: UI requests after login send "Bearer unix_token_...".
    if (username.empty()) {
        const std::string authHeader = req.get_header_value("Authorization");
        if (hasPrefix(authHeader, "Bearer ")) {
            const std::string token = authHeader.substr(7);  // strip "Bearer "
            if (hasPrefix(token, "unix_token_")) {
                // NOTE(review): the token is not verified in this function;
                // the username is simply the text after its last underscore
                // (so a username containing '_' is truncated) and is passed
                // on with an empty password. Confirm that
                // UserManager::authenticateUnix or another layer validates
                // the token.
                const std::size_t lastUnderscore = token.find_last_of('_');
                if (lastUnderscore != std::string::npos) {
                    username = token.substr(lastUnderscore + 1);
                }
            }
        }
    }

    if (username.empty()) {
        if (isUiRoute(req.path)) {
            // UI request: stay unauthenticated so the login page is shown.
            context               = createGuestContext();
            context.authenticated = false;
            return context;
        }
        context.errorMessage = "Missing Unix username";
        context.errorCode    = "MISSING_UNIX_USER";
        return context;
    }

    // The password may be empty here; the user manager decides whether that
    // is acceptable.
    const auto result = m_userManager->authenticateUnix(username, password);
    if (!result.success) {
        context.errorMessage = result.errorMessage;
        context.errorCode    = result.errorCode;
        return context;
    }

    adoptIdentity(context, result, "UNIX");
    return context;
}

/// Authenticates a request against the PAM backend using the "username" and
/// "password" fields of a JSON body.
///
/// UI routes without complete credentials get an unauthenticated guest
/// context (no error code) so the UI can show its login page; other routes
/// get MISSING_PAM_CREDENTIALS.
AuthContext AuthMiddleware::authenticatePam(const httplib::Request& req) {
    AuthContext context;
    context.authenticated = false;

    if (!m_userManager) {
        context.errorMessage = "User manager not available";
        context.errorCode    = "USER_MANAGER_UNAVAILABLE";
        return context;
    }

    if (m_config.authMethod != AuthMethod::PAM) {
        context.errorMessage = "PAM authentication not available";
        context.errorCode    = "PAM_AUTH_UNAVAILABLE";
        return context;
    }

    std::string username;
    std::string password;

    // Login API: credentials in a JSON body.
    readJsonCredentials(req, username, password);

    // TODO: "Authorization: Basic ..." credentials are not decoded yet; only
    // the JSON body is honoured.

    if (username.empty() || password.empty()) {
        if (isUiRoute(req.path)) {
            // UI request: stay unauthenticated so the login page is shown.
            context               = createGuestContext();
            context.authenticated = false;
            return context;
        }
        context.errorMessage = "Missing PAM credentials";
        context.errorCode    = "MISSING_PAM_CREDENTIALS";
        return context;
    }

    const auto result = m_userManager->authenticatePam(username, password);
    if (!result.success) {
        context.errorMessage = result.errorMessage;
        context.errorCode    = result.errorCode;
        return context;
    }

    adoptIdentity(context, result, "PAM");
    return context;
}
std::string AuthMiddleware::extractToken(const httplib::Request& req, const std::string& headerName) const { std::string authHeader = req.get_header_value(headerName); if (headerName == "Authorization") { return JWTAuth::extractTokenFromHeader(authHeader); } return authHeader; } AuthContext AuthMiddleware::createGuestContext() const { AuthContext context; context.authenticated = false; context.userId = "guest"; context.username = "guest"; context.role = "guest"; context.permissions = UserManager::getDefaultPermissions(UserManager::UserRole::GUEST); context.authMethod = "none"; return context; } void AuthMiddleware::logAuthAttempt(const httplib::Request& req, const AuthContext& context, bool success) const { auto now = std::chrono::system_clock::now(); auto time_t_now = std::chrono::system_clock::to_time_t(now); std::stringstream timestamp_ss; timestamp_ss << std::put_time(std::gmtime(&time_t_now), "%Y-%m-%dT%H:%M:%S"); auto ms = std::chrono::duration_cast(now.time_since_epoch()) % 1000; timestamp_ss << "." << std::setfill('0') << std::setw(3) << ms.count() << "Z"; std::string clientIp = getClientIp(req); std::string userAgent = getUserAgent(req); std::string username = context.authenticated ? context.username : "unknown"; std::string status = success ? "success" : "failure"; std::string message = "Authentication " + status + " - " + "timestamp=" + timestamp_ss.str() + ", " + "ip=" + (clientIp.empty() ? "unknown" : clientIp) + ", " + "username=" + (username.empty() ? "unknown" : username) + ", " + "path=" + req.path + ", " + "user-agent=" + (userAgent.empty() ? 
"unknown" : userAgent); if (success) { Logger::getInstance().info(message); } else { Logger::getInstance().warning(message); } } std::string AuthMiddleware::getClientIp(const httplib::Request& req) { // Check various headers for client IP std::string ip = req.get_header_value("X-Forwarded-For"); if (ip.empty()) { ip = req.get_header_value("X-Real-IP"); } if (ip.empty()) { ip = req.get_header_value("X-Client-IP"); } if (ip.empty()) { ip = req.remote_addr; } return ip; } std::string AuthMiddleware::getUserAgent(const httplib::Request& req) { return req.get_header_value("User-Agent"); } bool AuthMiddleware::validateConfig(AuthConfig config) { // Validate JWT configuration if (config.authMethod == AuthMethod::JWT) { if (config.jwtSecret.empty()) { // Will be auto-generated } if (config.jwtExpirationMinutes <= 0 || config.jwtExpirationMinutes > 1440) { return false; // Max 24 hours } } // Validate realm if (config.authRealm.empty()) { return false; } return true; } void AuthMiddleware::initializeDefaultPaths() { // Parse custom public paths if provided if (!m_config.customPublicPaths.empty()) { // Split comma-separated paths std::stringstream ss(m_config.customPublicPaths); std::string path; while (std::getline(ss, path, ',')) { // Trim whitespace path.erase(0, path.find_first_not_of(" \t")); path.erase(path.find_last_not_of(" \t") + 1); if (!path.empty()) { // Ensure path starts with / if (path[0] != '/') { path = "/" + path; } m_config.publicPaths.push_back(path); } } } // Add default public paths - only truly public endpoints when auth is enabled if (m_config.publicPaths.empty()) { m_config.publicPaths = { "/api/health", "/api/status" // Note: Model discovery endpoints removed from public paths // These now require authentication when auth is enabled }; } // Add default admin paths if (m_config.adminPaths.empty()) { m_config.adminPaths = { "/api/users", "/api/auth/users", "/api/system/restart"}; } // Add default user paths if (m_config.userPaths.empty()) { 
m_config.userPaths = { "/api/generate", "/api/queue", "/api/models/load", "/api/models/unload", "/api/auth/profile", "/api/auth/api-keys", // Model discovery endpoints moved to user paths "/api/models", "/api/models/types", "/api/models/directories", "/api/samplers", "/api/schedulers", "/api/parameters"}; } } bool AuthMiddleware::isAuthenticationDisabled() const { return m_config.authMethod == AuthMethod::NONE; } AuthContext AuthMiddleware::authenticateJwt(const httplib::Request& req) { AuthContext context; context.authenticated = false; if (!m_jwtAuth) { context.errorMessage = "JWT authentication not available"; context.errorCode = "JWT_AUTH_UNAVAILABLE"; return context; } // Extract JWT token from Authorization header std::string token = extractToken(req, "Authorization"); if (token.empty()) { context.errorMessage = "Missing JWT token"; context.errorCode = "MISSING_JWT_TOKEN"; return context; } // Validate JWT token auto result = m_jwtAuth->validateToken(token); if (!result.success) { context.errorMessage = result.errorMessage; context.errorCode = "INVALID_JWT_TOKEN"; return context; } // JWT is valid context.authenticated = true; context.userId = result.userId; context.username = result.username; context.role = result.role; context.permissions = result.permissions; context.authMethod = "JWT"; return context; } std::vector AuthMiddleware::getRequiredPermissions(const std::string& path) const { std::vector permissions; if (requiresAdminAccess(path)) { permissions.push_back("admin"); } else if (requiresUserAccess(path)) { permissions.push_back("user"); } return permissions; } bool AuthMiddleware::pathMatchesPattern(const std::string& path, const std::vector& patterns) { for (const auto& pattern : patterns) { if (pattern == path) { return true; } // Check if pattern is a prefix if (pattern.back() == '/' && path.find(pattern) == 0) { return true; } // Simple wildcard matching if (pattern.find('*') != std::string::npos) { std::regex regexPattern(pattern, 
std::regex_constants::icase); if (std::regex_match(path, regexPattern)) { return true; } } } return false; } AuthConfig AuthMiddleware::getConfig() const { return m_config; } void AuthMiddleware::updateConfig(AuthConfig config) { m_config = std::move(config); } // Factory functions namespace AuthMiddlewareFactory { std::unique_ptr createDefault(std::shared_ptr userManager, const std::string& /*dataDir*/) { AuthConfig config; config.authMethod = AuthMethod::NONE; config.enableGuestAccess = true; config.jwtSecret = ""; config.jwtExpirationMinutes = 60; config.authRealm = "stable-diffusion-rest"; return std::make_unique(config, userManager); } std::unique_ptr createMultiMethod(std::shared_ptr userManager, AuthConfig config) { return std::make_unique(std::move(config), userManager); } std::unique_ptr createJwtOnly(std::shared_ptr userManager, const std::string& jwtSecret, int jwtExpirationMinutes) { AuthConfig config; config.authMethod = AuthMethod::JWT; config.enableGuestAccess = false; config.jwtSecret = jwtSecret; config.jwtExpirationMinutes = jwtExpirationMinutes; config.authRealm = "stable-diffusion-rest"; return std::make_unique(config, userManager); } std::unique_ptr createApiKeyOnly(std::shared_ptr userManager) { AuthConfig config; config.authMethod = AuthMethod::API_KEY; config.enableGuestAccess = false; config.authRealm = "stable-diffusion-rest"; return std::make_unique(config, userManager); } } // namespace AuthMiddlewareFactory