| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685 |
#include "auth_middleware.h"

#include <cctype>
#include <chrono>
#include <ctime>
#include <iomanip>
#include <regex>
#include <sstream>
#include <utility>

#include <httplib.h>
#include <nlohmann/json.hpp>

#include "logger.h"
- AuthMiddleware::AuthMiddleware(AuthConfig config, std::shared_ptr<UserManager> userManager)
- : m_config(std::move(config)), m_userManager(std::move(userManager)) {
- }
- AuthMiddleware::~AuthMiddleware() = default;
- bool AuthMiddleware::initialize() {
- try {
- // Validate configuration
- if (!validateConfig(m_config)) {
- return false;
- }
- // Initialize JWT auth if needed
- if (m_config.authMethod == AuthMethod::JWT) {
- m_jwtAuth = std::make_unique<JWTAuth>(m_config.jwtSecret,
- m_config.jwtExpirationMinutes,
- m_config.authRealm,
- m_config.jwtAudience);
- }
- // Initialize default paths
- initializeDefaultPaths();
- return true;
- } catch (const std::exception& e) {
- return false;
- }
- }
- AuthContext AuthMiddleware::authenticate(const httplib::Request& req, httplib::Response& /*res*/) {
- AuthContext context;
- context.authenticated = false;
- try {
- // Check if authentication is completely disabled
- if (isAuthenticationDisabled()) {
- context = createGuestContext();
- context.authenticated = true;
- return context;
- }
- // Check if path requires authentication
- if (!requiresAuthentication(req.path)) {
- context = createGuestContext();
- // Only allow guest access if authentication is completely disabled or guest access is explicitly enabled
- context.authenticated = (isAuthenticationDisabled() || m_config.enableGuestAccess);
- return context;
- }
- // Try different authentication methods based on configuration
- switch (m_config.authMethod) {
- case AuthMethod::JWT:
- context = authenticateJwt(req);
- break;
- case AuthMethod::API_KEY:
- context = authenticateApiKey(req);
- break;
- case AuthMethod::UNIX:
- context = authenticateUnix(req);
- break;
- case AuthMethod::PAM:
- context = authenticatePam(req);
- break;
- case AuthMethod::OPTIONAL:
- // Try JWT first, then API key, then allow guest
- context = authenticateJwt(req);
- if (!context.authenticated) {
- context = authenticateApiKey(req);
- }
- if (!context.authenticated && m_config.enableGuestAccess) {
- context = createGuestContext();
- context.authenticated = true;
- }
- break;
- case AuthMethod::NONE:
- default:
- context = createGuestContext();
- context.authenticated = true;
- break;
- }
- // Check if user has required permissions for this path
- if (context.authenticated && !hasPathAccess(req.path, context.permissions)) {
- context.authenticated = false;
- context.errorMessage = "Insufficient permissions for this endpoint";
- context.errorCode = "INSUFFICIENT_PERMISSIONS";
- }
- // Log authentication attempt
- logAuthAttempt(req, context, context.authenticated);
- } catch (const std::exception& e) {
- context.authenticated = false;
- context.errorMessage = "Authentication error: " + std::string(e.what());
- context.errorCode = "AUTH_ERROR";
- }
- return context;
- }
- bool AuthMiddleware::requiresAuthentication(const std::string& path) const {
- // First check if authentication is completely disabled
- if (isAuthenticationDisabled()) {
- return false;
- }
- // Authentication is enabled, check if path is explicitly public
- // Only paths in publicPaths are accessible without authentication
- if (pathMatchesPattern(path, m_config.publicPaths)) {
- return false;
- }
- // All other paths require authentication when auth is enabled
- return true;
- }
- bool AuthMiddleware::requiresAdminAccess(const std::string& path) const {
- return pathMatchesPattern(path, m_config.adminPaths);
- }
- bool AuthMiddleware::requiresUserAccess(const std::string& path) const {
- return pathMatchesPattern(path, m_config.userPaths);
- }
- bool AuthMiddleware::hasPathAccess(const std::string& path,
- const std::vector<std::string>& permissions) const {
- // Check admin paths
- if (requiresAdminAccess(path)) {
- return JWTAuth::hasPermission(permissions, UserManager::Permissions::ADMIN);
- }
- // Check user paths
- if (requiresUserAccess(path)) {
- return JWTAuth::hasAnyPermission(permissions, {UserManager::Permissions::USER_MANAGE,
- UserManager::Permissions::ADMIN});
- }
- // Default: allow access if authenticated
- return true;
- }
- AuthMiddleware::AuthHandler AuthMiddleware::createMiddleware(AuthHandler handler) {
- return [this, handler](const httplib::Request& req, httplib::Response& res, const AuthContext& /*context*/) {
- // Authenticate request
- AuthContext authContext = authenticate(req, res);
- // Check if authentication failed
- if (!authContext.authenticated) {
- sendAuthError(res, authContext.errorMessage, authContext.errorCode);
- return;
- }
- // Call the next handler
- handler(req, res, authContext);
- };
- }
- void AuthMiddleware::sendAuthError(httplib::Response& res,
- const std::string& message,
- const std::string& errorCode,
- int statusCode) {
- nlohmann::json error = {
- {"error", {{"message", message}, {"code", errorCode}, {"timestamp", std::chrono::duration_cast<std::chrono::seconds>(std::chrono::system_clock::now().time_since_epoch()).count()}}}};
- res.set_header("Content-Type", "application/nlohmann::json");
- res.set_header("WWW-Authenticate", "Bearer realm=\"" + m_config.authRealm + "\"");
- res.status = statusCode;
- res.body = error.dump();
- }
- void AuthMiddleware::sendAuthzError(httplib::Response& res,
- const std::string& message,
- const std::string& errorCode) {
- nlohmann::json error = {
- {"error", {{"message", message}, {"code", errorCode}, {"timestamp", std::chrono::duration_cast<std::chrono::seconds>(std::chrono::system_clock::now().time_since_epoch()).count()}}}};
- res.set_header("Content-Type", "application/nlohmann::json");
- res.status = 403;
- res.body = error.dump();
- }
- void AuthMiddleware::addPublicPath(const std::string& path) {
- m_config.publicPaths.push_back(path);
- }
- void AuthMiddleware::addAdminPath(const std::string& path) {
- m_config.adminPaths.push_back(path);
- }
- void AuthMiddleware::addUserPath(const std::string& path) {
- m_config.userPaths.push_back(path);
- }
- AuthContext AuthMiddleware::authenticateApiKey(const httplib::Request& req) {
- AuthContext context;
- context.authenticated = false;
- if (!m_userManager) {
- context.errorMessage = "User manager not available";
- context.errorCode = "USER_MANAGER_UNAVAILABLE";
- return context;
- }
- // Extract API key from header
- std::string apiKey = extractToken(req, "X-API-Key");
- if (apiKey.empty()) {
- context.errorMessage = "Missing API key";
- context.errorCode = "MISSING_API_KEY";
- return context;
- }
- // Validate API key
- auto result = m_userManager->authenticateApiKey(apiKey);
- if (!result.success) {
- context.errorMessage = result.errorMessage;
- context.errorCode = result.errorCode;
- return context;
- }
- // API key is valid
- context.authenticated = true;
- context.userId = result.userId;
- context.username = result.username;
- context.role = result.role;
- context.permissions = result.permissions;
- context.authMethod = "API_KEY";
- return context;
- }
- AuthContext AuthMiddleware::authenticateUnix(const httplib::Request& req) {
- AuthContext context;
- context.authenticated = false;
- if (!m_userManager) {
- context.errorMessage = "User manager not available";
- context.errorCode = "USER_MANAGER_UNAVAILABLE";
- return context;
- }
- // Check if Unix authentication is the configured method
- if (m_config.authMethod != AuthMethod::UNIX) {
- context.errorMessage = "Unix authentication not available";
- context.errorCode = "UNIX_AUTH_UNAVAILABLE";
- return context;
- }
- // For Unix auth, we need to get username and password from request
- std::string username;
- std::string password;
- // Try to extract from JSON body (for login API)
- std::string contentType = req.get_header_value("Content-Type");
- if (contentType.find("application/nlohmann::json") != std::string::npos) {
- try {
- nlohmann::json body = nlohmann::json::parse(req.body);
- username = body.value("username", "");
- password = body.value("password", "");
- } catch (const nlohmann::json::exception& e) {
- // Invalid JSON, continue with other methods
- }
- }
- // If no credentials in body, check headers
- if (username.empty()) {
- username = req.get_header_value("X-Unix-User");
- // Also check Authorization header for Bearer token (for UI requests after login)
- if (username.empty()) {
- std::string authHeader = req.get_header_value("Authorization");
- if (!authHeader.empty() && authHeader.find("Bearer ") == 0) {
- std::string token = authHeader.substr(7); // Remove "Bearer "
- // Check if this is a Unix token
- if (token.find("unix_token_") == 0) {
- // Extract username from token
- size_t lastUnderscore = token.find_last_of('_');
- if (lastUnderscore != std::string::npos) {
- username = token.substr(lastUnderscore + 1);
- }
- }
- }
- }
- }
- if (username.empty()) {
- // Check if this is a request for the login page or API endpoints
- // For UI requests, we'll let the UI handler show the login page
- // For API requests, we need to return an error
- std::string path = req.path;
- if (path.find("/ui/") == 0 || path == "/ui") {
- // This is a UI request, let it proceed to show the login page
- context = createGuestContext();
- context.authenticated = false; // Ensure it's false to trigger login page
- return context;
- } else {
- // This is an API request, return error
- context.errorMessage = "Missing Unix username";
- context.errorCode = "MISSING_UNIX_USER";
- return context;
- }
- }
- // Authenticate Unix user (with or without password depending on PAM availability)
- auto result = m_userManager->authenticateUnix(username, password);
- if (!result.success) {
- context.errorMessage = result.errorMessage;
- context.errorCode = result.errorCode;
- return context;
- }
- // Unix authentication successful
- context.authenticated = true;
- context.userId = result.userId;
- context.username = result.username;
- context.role = result.role;
- context.permissions = result.permissions;
- context.authMethod = "UNIX";
- return context;
- }
- AuthContext AuthMiddleware::authenticatePam(const httplib::Request& req) {
- AuthContext context;
- context.authenticated = false;
- if (!m_userManager) {
- context.errorMessage = "User manager not available";
- context.errorCode = "USER_MANAGER_UNAVAILABLE";
- return context;
- }
- // Check if PAM authentication is the configured method
- if (m_config.authMethod != AuthMethod::PAM) {
- context.errorMessage = "PAM authentication not available";
- context.errorCode = "PAM_AUTH_UNAVAILABLE";
- return context;
- }
- // For PAM auth, we need to get username and password from request
- // This could be from a JSON body for login requests
- std::string username;
- std::string password;
- // Try to extract from JSON body (for login API)
- std::string contentType = req.get_header_value("Content-Type");
- if (contentType.find("application/nlohmann::json") != std::string::npos) {
- try {
- nlohmann::json body = nlohmann::json::parse(req.body);
- username = body.value("username", "");
- password = body.value("password", "");
- } catch (const nlohmann::json::exception& e) {
- // Invalid JSON
- }
- }
- // If no credentials in body, check Authorization header for basic auth
- if (username.empty() || password.empty()) {
- std::string authHeader = req.get_header_value("Authorization");
- if (!authHeader.empty() && authHeader.find("Basic ") == 0) {
- // Decode basic auth
- std::string basicAuth = authHeader.substr(6); // Remove "Basic "
- // Note: In a real implementation, you'd decode base64 here
- // For now, we'll expect the credentials to be in the JSON body
- }
- }
- if (username.empty() || password.empty()) {
- // Check if this is a request for the login page or API endpoints
- // For UI requests, we'll let the UI handler show the login page
- // For API requests, we need to return an error
- std::string path = req.path;
- if (path.find("/ui/") == 0 || path == "/ui") {
- // This is a UI request, let it proceed to show the login page
- context = createGuestContext();
- context.authenticated = false; // Ensure it's false to trigger login page
- return context;
- } else {
- // This is an API request, return error
- context.errorMessage = "Missing PAM credentials";
- context.errorCode = "MISSING_PAM_CREDENTIALS";
- return context;
- }
- }
- // Authenticate PAM user
- auto result = m_userManager->authenticatePam(username, password);
- if (!result.success) {
- context.errorMessage = result.errorMessage;
- context.errorCode = result.errorCode;
- return context;
- }
- // PAM authentication successful
- context.authenticated = true;
- context.userId = result.userId;
- context.username = result.username;
- context.role = result.role;
- context.permissions = result.permissions;
- context.authMethod = "PAM";
- return context;
- }
- std::string AuthMiddleware::extractToken(const httplib::Request& req, const std::string& headerName) const {
- std::string authHeader = req.get_header_value(headerName);
- if (headerName == "Authorization") {
- return JWTAuth::extractTokenFromHeader(authHeader);
- }
- return authHeader;
- }
- AuthContext AuthMiddleware::createGuestContext() const {
- AuthContext context;
- context.authenticated = false;
- context.userId = "guest";
- context.username = "guest";
- context.role = "guest";
- context.permissions = UserManager::getDefaultPermissions(UserManager::UserRole::GUEST);
- context.authMethod = "none";
- return context;
- }
- void AuthMiddleware::logAuthAttempt(const httplib::Request& req,
- const AuthContext& context,
- bool success) const {
- auto now = std::chrono::system_clock::now();
- auto time_t_now = std::chrono::system_clock::to_time_t(now);
- std::stringstream timestamp_ss;
- timestamp_ss << std::put_time(std::gmtime(&time_t_now), "%Y-%m-%dT%H:%M:%S");
- auto ms = std::chrono::duration_cast<std::chrono::milliseconds>(now.time_since_epoch()) % 1000;
- timestamp_ss << "." << std::setfill('0') << std::setw(3) << ms.count() << "Z";
- std::string clientIp = getClientIp(req);
- std::string userAgent = getUserAgent(req);
- std::string username = context.authenticated ? context.username : "unknown";
- std::string status = success ? "success" : "failure";
- std::string message = "Authentication " + status + " - " +
- "timestamp=" + timestamp_ss.str() + ", " +
- "ip=" + (clientIp.empty() ? "unknown" : clientIp) + ", " +
- "username=" + (username.empty() ? "unknown" : username) + ", " +
- "path=" + req.path + ", " +
- "user-agent=" + (userAgent.empty() ? "unknown" : userAgent);
- if (success) {
- Logger::getInstance().info(message);
- } else {
- Logger::getInstance().warning(message);
- }
- }
- std::string AuthMiddleware::getClientIp(const httplib::Request& req) {
- // Check various headers for client IP
- std::string ip = req.get_header_value("X-Forwarded-For");
- if (ip.empty()) {
- ip = req.get_header_value("X-Real-IP");
- }
- if (ip.empty()) {
- ip = req.get_header_value("X-Client-IP");
- }
- if (ip.empty()) {
- ip = req.remote_addr;
- }
- return ip;
- }
- std::string AuthMiddleware::getUserAgent(const httplib::Request& req) {
- return req.get_header_value("User-Agent");
- }
- bool AuthMiddleware::validateConfig(AuthConfig config) {
- // Validate JWT configuration
- if (config.authMethod == AuthMethod::JWT) {
- if (config.jwtSecret.empty()) {
- // Will be auto-generated
- }
- if (config.jwtExpirationMinutes <= 0 || config.jwtExpirationMinutes > 1440) {
- return false; // Max 24 hours
- }
- }
- // Validate realm
- if (config.authRealm.empty()) {
- return false;
- }
- return true;
- }
- void AuthMiddleware::initializeDefaultPaths() {
- // Parse custom public paths if provided
- if (!m_config.customPublicPaths.empty()) {
- // Split comma-separated paths
- std::stringstream ss(m_config.customPublicPaths);
- std::string path;
- while (std::getline(ss, path, ',')) {
- // Trim whitespace
- path.erase(0, path.find_first_not_of(" \t"));
- path.erase(path.find_last_not_of(" \t") + 1);
- if (!path.empty()) {
- // Ensure path starts with /
- if (path[0] != '/') {
- path = "/" + path;
- }
- m_config.publicPaths.push_back(path);
- }
- }
- }
- // Add default public paths - only truly public endpoints when auth is enabled
- if (m_config.publicPaths.empty()) {
- m_config.publicPaths = {
- "/api/health",
- "/api/status"
- // Note: Model discovery endpoints removed from public paths
- // These now require authentication when auth is enabled
- };
- }
- // Add default admin paths
- if (m_config.adminPaths.empty()) {
- m_config.adminPaths = {
- "/api/users",
- "/api/auth/users",
- "/api/system/restart"};
- }
- // Add default user paths
- if (m_config.userPaths.empty()) {
- m_config.userPaths = {
- "/api/generate",
- "/api/queue",
- "/api/models/load",
- "/api/models/unload",
- "/api/auth/profile",
- "/api/auth/api-keys",
- // Model discovery endpoints moved to user paths
- "/api/models",
- "/api/models/types",
- "/api/models/directories",
- "/api/samplers",
- "/api/schedulers",
- "/api/parameters"};
- }
- }
- bool AuthMiddleware::isAuthenticationDisabled() const {
- return m_config.authMethod == AuthMethod::NONE;
- }
- AuthContext AuthMiddleware::authenticateJwt(const httplib::Request& req) {
- AuthContext context;
- context.authenticated = false;
- if (!m_jwtAuth) {
- context.errorMessage = "JWT authentication not available";
- context.errorCode = "JWT_AUTH_UNAVAILABLE";
- return context;
- }
- // Extract JWT token from Authorization header
- std::string token = extractToken(req, "Authorization");
- if (token.empty()) {
- context.errorMessage = "Missing JWT token";
- context.errorCode = "MISSING_JWT_TOKEN";
- return context;
- }
- // Validate JWT token
- auto result = m_jwtAuth->validateToken(token);
- if (!result.success) {
- context.errorMessage = result.errorMessage;
- context.errorCode = "INVALID_JWT_TOKEN";
- return context;
- }
- // JWT is valid
- context.authenticated = true;
- context.userId = result.userId;
- context.username = result.username;
- context.role = result.role;
- context.permissions = result.permissions;
- context.authMethod = "JWT";
- return context;
- }
- std::vector<std::string> AuthMiddleware::getRequiredPermissions(const std::string& path) const {
- std::vector<std::string> permissions;
- if (requiresAdminAccess(path)) {
- permissions.push_back("admin");
- } else if (requiresUserAccess(path)) {
- permissions.push_back("user");
- }
- return permissions;
- }
- bool AuthMiddleware::pathMatchesPattern(const std::string& path, const std::vector<std::string>& patterns) {
- for (const auto& pattern : patterns) {
- if (pattern == path) {
- return true;
- }
- // Check if pattern is a prefix
- if (pattern.back() == '/' && path.find(pattern) == 0) {
- return true;
- }
- // Simple wildcard matching
- if (pattern.find('*') != std::string::npos) {
- std::regex regexPattern(pattern, std::regex_constants::icase);
- if (std::regex_match(path, regexPattern)) {
- return true;
- }
- }
- }
- return false;
- }
- AuthConfig AuthMiddleware::getConfig() const {
- return m_config;
- }
- void AuthMiddleware::updateConfig(AuthConfig config) {
- m_config = std::move(config);
- }
- // Factory functions
- namespace AuthMiddlewareFactory {
- std::unique_ptr<AuthMiddleware> createDefault(std::shared_ptr<UserManager> userManager,
- const std::string& /*dataDir*/) {
- AuthConfig config;
- config.authMethod = AuthMethod::NONE;
- config.enableGuestAccess = true;
- config.jwtSecret = "";
- config.jwtExpirationMinutes = 60;
- config.authRealm = "stable-diffusion-rest";
- return std::make_unique<AuthMiddleware>(config, userManager);
- }
- std::unique_ptr<AuthMiddleware> createMultiMethod(std::shared_ptr<UserManager> userManager,
- AuthConfig config) {
- return std::make_unique<AuthMiddleware>(std::move(config), userManager);
- }
- std::unique_ptr<AuthMiddleware> createJwtOnly(std::shared_ptr<UserManager> userManager,
- const std::string& jwtSecret,
- int jwtExpirationMinutes) {
- AuthConfig config;
- config.authMethod = AuthMethod::JWT;
- config.enableGuestAccess = false;
- config.jwtSecret = jwtSecret;
- config.jwtExpirationMinutes = jwtExpirationMinutes;
- config.authRealm = "stable-diffusion-rest";
- return std::make_unique<AuthMiddleware>(config, userManager);
- }
- std::unique_ptr<AuthMiddleware> createApiKeyOnly(std::shared_ptr<UserManager> userManager) {
- AuthConfig config;
- config.authMethod = AuthMethod::API_KEY;
- config.enableGuestAccess = false;
- config.authRealm = "stable-diffusion-rest";
- return std::make_unique<AuthMiddleware>(config, userManager);
- }
- } // namespace AuthMiddlewareFactory
|