#include "auth_middleware.h"

#include <algorithm>
#include <chrono>
#include <cstddef>
#include <exception>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include <nlohmann/json.hpp>

// NOTE(review): the angle-bracket include list and every template argument in
// this file were reconstructed after the source text lost its <...> content;
// confirm against auth_middleware.h (notably the timestamp unit used in
// buildErrorBody, which could not be recovered).

using json = nlohmann::json;

namespace {

// Issuer claim passed to every JWTAuth instance created by this middleware.
constexpr const char* kJwtIssuer = "stable-diffusion-rest";

// Upper bound accepted by validateConfig for token lifetime (24 hours).
constexpr int kMaxJwtExpirationMinutes = 1440;

// Error code produced when an authenticated caller lacks access to a path.
constexpr const char* kInsufficientPermissionsCode = "INSUFFICIENT_PERMISSIONS";

// True for methods that validate bearer tokens: JWT itself, and OPTIONAL,
// which tries JWT before falling back to API keys / guest access.
bool methodUsesJwt(AuthMethod method) {
    return method == AuthMethod::JWT || method == AuthMethod::OPTIONAL;
}

// Builds a JWTAuth from the config's current secret and expiration.
std::unique_ptr<JWTAuth> makeJwtAuth(const AuthConfig& config) {
    return std::make_unique<JWTAuth>(config.jwtSecret, config.jwtExpirationMinutes, kJwtIssuer);
}

// Serialises the JSON error envelope shared by 401 and 403 responses.
std::string buildErrorBody(const std::string& message, const std::string& errorCode) {
    const json error = {
        {"error", {
            {"message", message},
            {"code", errorCode},
            {"timestamp", std::chrono::duration_cast<std::chrono::seconds>(
                std::chrono::system_clock::now().time_since_epoch()).count()}
        }}
    };
    return error.dump();
}

// Returns `text` without leading/trailing spaces and tabs.
std::string trimWhitespace(const std::string& text) {
    const std::size_t first = text.find_first_not_of(" \t");
    if (first == std::string::npos) {
        return {};
    }
    const std::size_t last = text.find_last_not_of(" \t");
    return text.substr(first, last - first + 1);
}

}  // namespace

AuthMiddleware::AuthMiddleware(const AuthConfig& config, std::shared_ptr<UserManager> userManager)
    : m_config(config)
    , m_userManager(std::move(userManager)) {
}

AuthMiddleware::~AuthMiddleware() = default;

/// Validates the configuration, builds the JWT validator when the configured
/// method needs one, and fills in default path lists.
/// @return false if validation fails or any std::exception is thrown.
bool AuthMiddleware::initialize() {
    try {
        if (!validateConfig(m_config)) {
            return false;
        }

        // OPTIONAL also tries JWT first; without a validator authenticateJwt()
        // would always fail with JWT_NOT_CONFIGURED in that mode.
        if (methodUsesJwt(m_config.authMethod)) {
            m_jwtAuth = makeJwtAuth(m_config);
        }

        initializeDefaultPaths();
        return true;
    } catch (const std::exception&) {
        return false;
    }
}

/// Authenticates `req` according to the configured method and checks that the
/// resulting identity may access `req.path`.
/// @return a context whose `authenticated` flag is false on failure, with
///         errorMessage/errorCode describing why. Never throws std::exception.
AuthContext AuthMiddleware::authenticate(const httplib::Request& req, httplib::Response& res) {
    AuthContext context;
    context.authenticated = false;

    try {
        // Authentication switched off entirely: everyone is a guest.
        if (isAuthenticationDisabled()) {
            context = createGuestContext();
            context.authenticated = true;
            return context;
        }

        // Public paths must be reachable without credentials regardless of
        // enableGuestAccess; gating them on guest access would reject e.g.
        // /api/health (with an empty error message) in JWT-only deployments.
        if (!requiresAuthentication(req.path)) {
            context = createGuestContext();
            context.authenticated = true;
            return context;
        }

        switch (m_config.authMethod) {
            case AuthMethod::JWT:
                context = authenticateJwt(req);
                break;

            case AuthMethod::API_KEY:
                context = authenticateApiKey(req);
                break;

            case AuthMethod::UNIX:
                context = authenticateUnix(req);
                break;

            case AuthMethod::OPTIONAL:
                // Try JWT first, then API key, then allow guest if enabled.
                context = authenticateJwt(req);
                if (!context.authenticated) {
                    context = authenticateApiKey(req);
                }
                if (!context.authenticated && m_config.enableGuestAccess) {
                    context = createGuestContext();
                    context.authenticated = true;
                }
                break;

            case AuthMethod::NONE:
            default:
                context = createGuestContext();
                context.authenticated = true;
                break;
        }

        // An authenticated identity may still lack the permissions this path needs.
        if (context.authenticated && !hasPathAccess(req.path, context.permissions)) {
            context.authenticated = false;
            context.errorMessage = "Insufficient permissions for this endpoint";
            context.errorCode = kInsufficientPermissionsCode;
        }

        logAuthAttempt(req, context, context.authenticated);
    } catch (const std::exception& e) {
        context.authenticated = false;
        context.errorMessage = "Authentication error: " + std::string(e.what());
        context.errorCode = "AUTH_ERROR";
    }

    return context;
}

/// True unless `path` is listed as public or authentication is disabled.
bool AuthMiddleware::requiresAuthentication(const std::string& path) const {
    if (pathMatchesPattern(path, m_config.publicPaths)) {
        return false;
    }
    return !isAuthenticationDisabled();
}

bool AuthMiddleware::requiresAdminAccess(const std::string& path) const {
    return pathMatchesPattern(path, m_config.adminPaths);
}

bool AuthMiddleware::requiresUserAccess(const std::string& path) const {
    return pathMatchesPattern(path, m_config.userPaths);
}

/// Decides whether `permissions` grant access to `path`.
/// Admin paths need ADMIN; user paths need USER_MANAGE or ADMIN; every other
/// path is open to any authenticated caller.
/// NOTE(review): getRequiredPermissions() reports READ for user paths, while
/// this check demands USER_MANAGE or ADMIN (so e.g. /api/generate and
/// /api/auth/profile need user-management rights). One of the two is likely
/// wrong; confirm the intended policy.
bool AuthMiddleware::hasPathAccess(const std::string& path,
                                   const std::vector<std::string>& permissions) const {
    if (requiresAdminAccess(path)) {
        return JWTAuth::hasPermission(permissions, UserManager::Permissions::ADMIN);
    }

    if (requiresUserAccess(path)) {
        return JWTAuth::hasAnyPermission(permissions, {
            UserManager::Permissions::USER_MANAGE,
            UserManager::Permissions::ADMIN
        });
    }

    return true;
}

/// Wraps `handler` so it only runs after successful authentication.
/// Failed authentication yields a 401 via sendAuthError; authenticated callers
/// lacking permission get a 403 via sendAuthzError. The incoming context
/// argument is ignored; a fresh one is computed per request.
/// The returned callable captures `this`: this AuthMiddleware must outlive
/// every route it is registered on.
AuthMiddleware::AuthHandler AuthMiddleware::createMiddleware(AuthHandler handler) {
    return [this, handler = std::move(handler)](const httplib::Request& req,
                                                httplib::Response& res,
                                                const AuthContext& /*context*/) {
        AuthContext authContext = authenticate(req, res);

        if (!authContext.authenticated) {
            if (authContext.errorCode == kInsufficientPermissionsCode) {
                sendAuthzError(res, authContext.errorMessage, authContext.errorCode);
            } else {
                sendAuthError(res, authContext.errorMessage, authContext.errorCode);
            }
            return;
        }

        handler(req, res, authContext);
    };
}

/// Writes a JSON authentication error with a WWW-Authenticate challenge.
void AuthMiddleware::sendAuthError(httplib::Response& res,
                                   const std::string& message,
                                   const std::string& errorCode,
                                   int statusCode) {
    res.set_header("Content-Type", "application/json");
    res.set_header("WWW-Authenticate", "Bearer realm=\"" + m_config.authRealm + "\"");
    res.status = statusCode;
    res.body = buildErrorBody(message, errorCode);
}

/// Writes a JSON 403 authorization error (no authentication challenge).
void AuthMiddleware::sendAuthzError(httplib::Response& res,
                                    const std::string& message,
                                    const std::string& errorCode) {
    res.set_header("Content-Type", "application/json");
    res.status = 403;
    res.body = buildErrorBody(message, errorCode);
}

void AuthMiddleware::addPublicPath(const std::string& path) {
    m_config.publicPaths.push_back(path);
}

void AuthMiddleware::addAdminPath(const std::string& path) {
    m_config.adminPaths.push_back(path);
}

void AuthMiddleware::addUserPath(const std::string& path) {
    m_config.userPaths.push_back(path);
}

/// Replaces the JWT signing secret. The existing JWTAuth was constructed with
/// the old secret, so it is rebuilt. NOTE(review): any state held inside the
/// previous JWTAuth instance is discarded; confirm that is acceptable.
void AuthMiddleware::setJwtSecret(const std::string& secret) {
    m_config.jwtSecret = secret;
    if (m_jwtAuth) {
        m_jwtAuth = makeJwtAuth(m_config);
    }
}

std::string AuthMiddleware::getJwtSecret() const {
    return m_config.jwtSecret;
}

/// Switches the authentication method. If the new method uses JWT and no
/// validator exists yet, one is created so JWT checks do not fail as
/// JWT_NOT_CONFIGURED.
void AuthMiddleware::setAuthMethod(UserManager::AuthMethod method) {
    m_config.authMethod = static_cast<AuthMethod>(method);
    if (methodUsesJwt(m_config.authMethod) && !m_jwtAuth) {
        m_jwtAuth = makeJwtAuth(m_config);
    }
}

UserManager::AuthMethod AuthMiddleware::getAuthMethod() const {
    return static_cast<UserManager::AuthMethod>(m_config.authMethod);
}

void AuthMiddleware::setGuestAccessEnabled(bool enable) {
    m_config.enableGuestAccess = enable;
}

bool AuthMiddleware::isGuestAccessEnabled() const {
    return m_config.enableGuestAccess;
}

AuthConfig AuthMiddleware::getConfig() const {
    return m_config;
}

/// Replaces the whole configuration, rebuilding the JWT validator when the new
/// method uses one and re-applying default path lists for any left empty (as
/// initialize() does). Without the defaults, an empty adminPaths would make
/// hasPathAccess() allow any authenticated caller on admin endpoints.
void AuthMiddleware::updateConfig(const AuthConfig& config) {
    m_config = config;
    if (methodUsesJwt(m_config.authMethod)) {
        m_jwtAuth = makeJwtAuth(m_config);
    }
    initializeDefaultPaths();
}

/// Validates the bearer token from the Authorization header.
AuthContext AuthMiddleware::authenticateJwt(const httplib::Request& req) {
    AuthContext context;
    context.authenticated = false;

    if (!m_jwtAuth) {
        context.errorMessage = "JWT authentication not configured";
        context.errorCode = "JWT_NOT_CONFIGURED";
        return context;
    }

    std::string token = extractToken(req, "Authorization");
    if (token.empty()) {
        context.errorMessage = "Missing authorization token";
        context.errorCode = "MISSING_TOKEN";
        return context;
    }

    auto result = m_jwtAuth->validateToken(token);
    if (!result.success) {
        context.errorMessage = result.errorMessage;
        context.errorCode = result.errorCode;
        return context;
    }

    context.authenticated = true;
    context.userId = result.userId;
    context.username = result.username;
    context.role = result.role;
    context.permissions = result.permissions;
    context.authMethod = "JWT";
    return context;
}

/// Validates the X-API-Key header through the user manager.
AuthContext AuthMiddleware::authenticateApiKey(const httplib::Request& req) {
    AuthContext context;
    context.authenticated = false;

    if (!m_userManager) {
        context.errorMessage = "User manager not available";
        context.errorCode = "USER_MANAGER_UNAVAILABLE";
        return context;
    }

    std::string apiKey = extractToken(req, "X-API-Key");
    if (apiKey.empty()) {
        context.errorMessage = "Missing API key";
        context.errorCode = "MISSING_API_KEY";
        return context;
    }

    auto result = m_userManager->authenticateApiKey(apiKey);
    if (!result.success) {
        context.errorMessage = result.errorMessage;
        context.errorCode = result.errorCode;
        return context;
    }

    context.authenticated = true;
    context.userId = result.userId;
    context.username = result.username;
    context.role = result.role;
    context.permissions = result.permissions;
    context.authMethod = "API_KEY";
    return context;
}

/// Authenticates the user named in the X-Unix-User header.
/// NOTE(security): the username comes straight from a request header and only
/// the username is passed to UserManager::authenticateUnix, so any client that
/// can reach the server can claim any Unix user unless a trusted proxy sets or
/// strips this header. Confirm the deployment guarantees that.
AuthContext AuthMiddleware::authenticateUnix(const httplib::Request& req) {
    AuthContext context;
    context.authenticated = false;

    if (!m_userManager || !m_userManager->isUnixAuthEnabled()) {
        context.errorMessage = "Unix authentication not available";
        context.errorCode = "UNIX_AUTH_UNAVAILABLE";
        return context;
    }

    std::string username = req.get_header_value("X-Unix-User");
    if (username.empty()) {
        context.errorMessage = "Missing Unix username";
        context.errorCode = "MISSING_UNIX_USER";
        return context;
    }

    auto result = m_userManager->authenticateUnix(username);
    if (!result.success) {
        context.errorMessage = result.errorMessage;
        context.errorCode = result.errorCode;
        return context;
    }

    context.authenticated = true;
    context.userId = result.userId;
    context.username = result.username;
    context.role = result.role;
    context.permissions = result.permissions;
    context.authMethod = "UNIX";
    return context;
}

/// Reads `headerName`; for "Authorization" the token is extracted from the
/// header value via JWTAuth::extractTokenFromHeader, otherwise the raw value is returned.
std::string AuthMiddleware::extractToken(const httplib::Request& req, const std::string& headerName) const {
    std::string headerValue = req.get_header_value(headerName);
    if (headerName == "Authorization") {
        return JWTAuth::extractTokenFromHeader(headerValue);
    }
    return headerValue;
}

/// Builds an (unauthenticated) guest identity with GUEST default permissions.
AuthContext AuthMiddleware::createGuestContext() const {
    AuthContext context;
    context.authenticated = false;
    context.userId = "guest";
    context.username = "guest";
    context.role = "guest";
    context.permissions = UserManager::getDefaultPermissions(UserManager::UserRole::GUEST);
    context.authMethod = "none";
    return context;
}

/// True if `path` equals some pattern, or starts with the prefix of a pattern
/// ending in '*'. A lone "*" is not a wildcard; it only matches literally.
bool AuthMiddleware::pathMatchesPattern(const std::string& path, const std::vector<std::string>& patterns) {
    return std::any_of(patterns.begin(), patterns.end(), [&path](const std::string& pattern) {
        if (path == pattern) {
            return true;
        }
        if (pattern.length() > 1 && pattern.back() == '*') {
            const std::size_t prefixLength = pattern.length() - 1;
            // compare() clips at path's end, so a shorter path never matches.
            return path.compare(0, prefixLength, pattern, 0, prefixLength) == 0;
        }
        return false;
    });
}

/// Permissions nominally required for `path` (empty when none are required).
std::vector<std::string> AuthMiddleware::getRequiredPermissions(const std::string& path) const {
    if (requiresAdminAccess(path)) {
        return {UserManager::Permissions::ADMIN};
    }
    if (requiresUserAccess(path)) {
        return {UserManager::Permissions::READ};
    }
    return {};
}

/// Hook for auditing authentication attempts; currently records nothing.
/// TODO: wire into the project's logging system (presumably with
/// getClientIp(req), getUserAgent(req), the context's identity and `success`).
void AuthMiddleware::logAuthAttempt(const httplib::Request& req, const AuthContext& context, bool success) const {
    static_cast<void>(req);
    static_cast<void>(context);
    static_cast<void>(success);
}

/// Best-effort client address: first X-Forwarded-For hop, then X-Real-IP,
/// X-Client-IP, and finally the socket's remote address.
/// NOTE(security): these headers are client-controllable unless a trusted proxy
/// overwrites them; do not use the result for access-control decisions.
std::string AuthMiddleware::getClientIp(const httplib::Request& req) {
    const std::string forwarded = req.get_header_value("X-Forwarded-For");
    if (!forwarded.empty()) {
        // Format is "client, proxy1, proxy2"; the originating client comes first.
        const std::string client = trimWhitespace(forwarded.substr(0, forwarded.find(',')));
        if (!client.empty()) {
            return client;
        }
    }

    std::string ip = req.get_header_value("X-Real-IP");
    if (ip.empty()) {
        ip = req.get_header_value("X-Client-IP");
    }
    if (ip.empty()) {
        ip = req.remote_addr;
    }
    return ip;
}

std::string AuthMiddleware::getUserAgent(const httplib::Request& req) {
    return req.get_header_value("User-Agent");
}

/// Rejects configs with an out-of-range JWT lifetime (for methods that use JWT)
/// or an empty realm. An empty jwtSecret is accepted; NOTE(review): the original
/// comment says it "will be auto-generated", presumably by JWTAuth; confirm.
bool AuthMiddleware::validateConfig(const AuthConfig& config) {
    if (methodUsesJwt(config.authMethod) &&
        (config.jwtExpirationMinutes <= 0 || config.jwtExpirationMinutes > kMaxJwtExpirationMinutes)) {
        return false;
    }
    return !config.authRealm.empty();
}

/// Fills each of the public/admin/user path lists with defaults if it is empty.
/// Matching is exact unless a pattern ends in '*' (see pathMatchesPattern), so
/// e.g. "/api/users" does not cover "/api/users/<id>".
/// NOTE(review): consider "/api/users/*"-style entries if such sub-routes exist.
void AuthMiddleware::initializeDefaultPaths() {
    if (m_config.publicPaths.empty()) {
        m_config.publicPaths = {
            "/api/health",
            "/api/status",
            "/api/samplers",
            "/api/schedulers",
            "/api/parameters",
            "/api/models",
            "/api/models/types",
            "/api/models/directories"
        };
    }

    if (m_config.adminPaths.empty()) {
        m_config.adminPaths = {
            "/api/users",
            "/api/auth/users",
            "/api/system/restart"
        };
    }

    if (m_config.userPaths.empty()) {
        m_config.userPaths = {
            "/api/generate",
            "/api/queue",
            "/api/models/load",
            "/api/models/unload",
            "/api/auth/profile",
            "/api/auth/api-keys"
        };
    }
}

bool AuthMiddleware::isAuthenticationDisabled() const {
    return m_config.authMethod == AuthMethod::NONE;
}

// Factory functions
namespace AuthMiddlewareFactory {

/// Authentication disabled, guests allowed. `dataDir` is currently unused.
std::unique_ptr<AuthMiddleware> createDefault(std::shared_ptr<UserManager> userManager, const std::string& dataDir) {
    static_cast<void>(dataDir);
    AuthConfig config;
    config.authMethod = AuthMethod::NONE;
    config.enableGuestAccess = true;
    config.jwtSecret = "";
    config.jwtExpirationMinutes = 60;
    config.authRealm = "stable-diffusion-rest";
    return std::make_unique<AuthMiddleware>(config, std::move(userManager));
}

/// JWT bearer tokens only; no guest or Unix access.
std::unique_ptr<AuthMiddleware> createJwtOnly(std::shared_ptr<UserManager> userManager,
                                              const std::string& jwtSecret,
                                              int jwtExpirationMinutes) {
    AuthConfig config;
    config.authMethod = AuthMethod::JWT;
    config.enableGuestAccess = false;
    config.jwtSecret = jwtSecret;
    config.jwtExpirationMinutes = jwtExpirationMinutes;
    config.authRealm = "stable-diffusion-rest";
    config.enableUnixAuth = false;
    return std::make_unique<AuthMiddleware>(config, std::move(userManager));
}

/// X-API-Key only; no guest or Unix access.
std::unique_ptr<AuthMiddleware> createApiKeyOnly(std::shared_ptr<UserManager> userManager) {
    AuthConfig config;
    config.authMethod = AuthMethod::API_KEY;
    config.enableGuestAccess = false;
    config.authRealm = "stable-diffusion-rest";
    config.enableUnixAuth = false;
    return std::make_unique<AuthMiddleware>(config, std::move(userManager));
}

/// Uses a caller-supplied configuration as-is.
std::unique_ptr<AuthMiddleware> createMultiMethod(std::shared_ptr<UserManager> userManager, const AuthConfig& config) {
    return std::make_unique<AuthMiddleware>(config, std::move(userManager));
}

/// Authentication disabled, guests allowed, no user manager.
std::unique_ptr<AuthMiddleware> createDevelopment() {
    AuthConfig config;
    config.authMethod = AuthMethod::NONE;
    config.enableGuestAccess = true;
    config.authRealm = "stable-diffusion-rest";
    return std::make_unique<AuthMiddleware>(config, nullptr);
}

}  // namespace AuthMiddlewareFactory