| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562 |
#include "auth_middleware.h"

#include <algorithm>
#include <chrono>
#include <fstream>
#include <iomanip>
#include <regex>
#include <sstream>
#include <utility>

#include <httplib.h>
#include <nlohmann/json.hpp>

using json = nlohmann::json;
- AuthMiddleware::AuthMiddleware(const AuthConfig& config,
- std::shared_ptr<UserManager> userManager)
- : m_config(config)
- , m_userManager(userManager)
- {
- }
- AuthMiddleware::~AuthMiddleware() = default;
- bool AuthMiddleware::initialize() {
- try {
- // Validate configuration
- if (!validateConfig(m_config)) {
- return false;
- }
- // Initialize JWT auth if needed
- if (m_config.authMethod == AuthMethod::JWT) {
- m_jwtAuth = std::make_unique<JWTAuth>(m_config.jwtSecret,
- m_config.jwtExpirationMinutes,
- "stable-diffusion-rest");
- }
- // Initialize default paths
- initializeDefaultPaths();
- return true;
- } catch (const std::exception& e) {
- return false;
- }
- }
- AuthContext AuthMiddleware::authenticate(const httplib::Request& req, httplib::Response& res) {
- AuthContext context;
- context.authenticated = false;
- try {
- // Check if authentication is completely disabled
- if (isAuthenticationDisabled()) {
- context = createGuestContext();
- context.authenticated = true;
- return context;
- }
- // Check if path requires authentication
- if (!requiresAuthentication(req.path)) {
- context = createGuestContext();
- context.authenticated = m_config.enableGuestAccess;
- return context;
- }
- // Try different authentication methods based on configuration
- switch (m_config.authMethod) {
- case AuthMethod::JWT:
- context = authenticateJwt(req);
- break;
- case AuthMethod::API_KEY:
- context = authenticateApiKey(req);
- break;
- case AuthMethod::UNIX:
- context = authenticateUnix(req);
- break;
- case AuthMethod::OPTIONAL:
- // Try JWT first, then API key, then allow guest
- context = authenticateJwt(req);
- if (!context.authenticated) {
- context = authenticateApiKey(req);
- }
- if (!context.authenticated && m_config.enableGuestAccess) {
- context = createGuestContext();
- context.authenticated = true;
- }
- break;
- case AuthMethod::NONE:
- default:
- context = createGuestContext();
- context.authenticated = true;
- break;
- }
- // Check if user has required permissions for this path
- if (context.authenticated && !hasPathAccess(req.path, context.permissions)) {
- context.authenticated = false;
- context.errorMessage = "Insufficient permissions for this endpoint";
- context.errorCode = "INSUFFICIENT_PERMISSIONS";
- }
- // Log authentication attempt
- logAuthAttempt(req, context, context.authenticated);
- } catch (const std::exception& e) {
- context.authenticated = false;
- context.errorMessage = "Authentication error: " + std::string(e.what());
- context.errorCode = "AUTH_ERROR";
- }
- return context;
- }
- bool AuthMiddleware::requiresAuthentication(const std::string& path) const {
- // Check if path is public
- if (pathMatchesPattern(path, m_config.publicPaths)) {
- return false;
- }
- // All other paths require authentication unless auth is completely disabled
- return !isAuthenticationDisabled();
- }
- bool AuthMiddleware::requiresAdminAccess(const std::string& path) const {
- return pathMatchesPattern(path, m_config.adminPaths);
- }
- bool AuthMiddleware::requiresUserAccess(const std::string& path) const {
- return pathMatchesPattern(path, m_config.userPaths);
- }
- bool AuthMiddleware::hasPathAccess(const std::string& path,
- const std::vector<std::string>& permissions) const {
- // Check admin paths
- if (requiresAdminAccess(path)) {
- return JWTAuth::hasPermission(permissions, UserManager::Permissions::ADMIN);
- }
- // Check user paths
- if (requiresUserAccess(path)) {
- return JWTAuth::hasAnyPermission(permissions, {
- UserManager::Permissions::USER_MANAGE,
- UserManager::Permissions::ADMIN
- });
- }
- // Default: allow access if authenticated
- return true;
- }
- AuthMiddleware::AuthHandler AuthMiddleware::createMiddleware(AuthHandler handler) {
- return [this, handler](const httplib::Request& req, httplib::Response& res, const AuthContext& context) {
- // Authenticate request
- AuthContext authContext = authenticate(req, res);
- // Check if authentication failed
- if (!authContext.authenticated) {
- sendAuthError(res, authContext.errorMessage, authContext.errorCode);
- return;
- }
- // Call the next handler
- handler(req, res, authContext);
- };
- }
- void AuthMiddleware::sendAuthError(httplib::Response& res,
- const std::string& message,
- const std::string& errorCode,
- int statusCode) {
- json error = {
- {"error", {
- {"message", message},
- {"code", errorCode},
- {"timestamp", std::chrono::duration_cast<std::chrono::seconds>(
- std::chrono::system_clock::now().time_since_epoch()).count()}
- }}
- };
- res.set_header("Content-Type", "application/json");
- res.set_header("WWW-Authenticate", "Bearer realm=\"" + m_config.authRealm + "\"");
- res.status = statusCode;
- res.body = error.dump();
- }
- void AuthMiddleware::sendAuthzError(httplib::Response& res,
- const std::string& message,
- const std::string& errorCode) {
- json error = {
- {"error", {
- {"message", message},
- {"code", errorCode},
- {"timestamp", std::chrono::duration_cast<std::chrono::seconds>(
- std::chrono::system_clock::now().time_since_epoch()).count()}
- }}
- };
- res.set_header("Content-Type", "application/json");
- res.status = 403;
- res.body = error.dump();
- }
- void AuthMiddleware::addPublicPath(const std::string& path) {
- m_config.publicPaths.push_back(path);
- }
- void AuthMiddleware::addAdminPath(const std::string& path) {
- m_config.adminPaths.push_back(path);
- }
- void AuthMiddleware::addUserPath(const std::string& path) {
- m_config.userPaths.push_back(path);
- }
- void AuthMiddleware::setJwtSecret(const std::string& secret) {
- m_config.jwtSecret = secret;
- if (m_jwtAuth) {
- m_jwtAuth->setIssuer("stable-diffusion-rest");
- }
- }
- std::string AuthMiddleware::getJwtSecret() const {
- return m_config.jwtSecret;
- }
- void AuthMiddleware::setAuthMethod(UserManager::AuthMethod method) {
- m_config.authMethod = static_cast<AuthMethod>(method);
- }
- UserManager::AuthMethod AuthMiddleware::getAuthMethod() const {
- return static_cast<UserManager::AuthMethod>(m_config.authMethod);
- }
- void AuthMiddleware::setGuestAccessEnabled(bool enable) {
- m_config.enableGuestAccess = enable;
- }
- bool AuthMiddleware::isGuestAccessEnabled() const {
- return m_config.enableGuestAccess;
- }
- AuthConfig AuthMiddleware::getConfig() const {
- return m_config;
- }
- void AuthMiddleware::updateConfig(const AuthConfig& config) {
- m_config = config;
- if (m_config.authMethod == AuthMethod::JWT) {
- m_jwtAuth = std::make_unique<JWTAuth>(m_config.jwtSecret,
- m_config.jwtExpirationMinutes,
- "stable-diffusion-rest");
- }
- }
- AuthContext AuthMiddleware::authenticateJwt(const httplib::Request& req) {
- AuthContext context;
- context.authenticated = false;
- if (!m_jwtAuth) {
- context.errorMessage = "JWT authentication not configured";
- context.errorCode = "JWT_NOT_CONFIGURED";
- return context;
- }
- // Extract token from header
- std::string token = extractToken(req, "Authorization");
- if (token.empty()) {
- context.errorMessage = "Missing authorization token";
- context.errorCode = "MISSING_TOKEN";
- return context;
- }
- // Validate token
- auto result = m_jwtAuth->validateToken(token);
- if (!result.success) {
- context.errorMessage = result.errorMessage;
- context.errorCode = result.errorCode;
- return context;
- }
- // Token is valid
- context.authenticated = true;
- context.userId = result.userId;
- context.username = result.username;
- context.role = result.role;
- context.permissions = result.permissions;
- context.authMethod = "JWT";
- return context;
- }
- AuthContext AuthMiddleware::authenticateApiKey(const httplib::Request& req) {
- AuthContext context;
- context.authenticated = false;
- if (!m_userManager) {
- context.errorMessage = "User manager not available";
- context.errorCode = "USER_MANAGER_UNAVAILABLE";
- return context;
- }
- // Extract API key from header
- std::string apiKey = extractToken(req, "X-API-Key");
- if (apiKey.empty()) {
- context.errorMessage = "Missing API key";
- context.errorCode = "MISSING_API_KEY";
- return context;
- }
- // Validate API key
- auto result = m_userManager->authenticateApiKey(apiKey);
- if (!result.success) {
- context.errorMessage = result.errorMessage;
- context.errorCode = result.errorCode;
- return context;
- }
- // API key is valid
- context.authenticated = true;
- context.userId = result.userId;
- context.username = result.username;
- context.role = result.role;
- context.permissions = result.permissions;
- context.authMethod = "API_KEY";
- return context;
- }
- AuthContext AuthMiddleware::authenticateUnix(const httplib::Request& req) {
- AuthContext context;
- context.authenticated = false;
- if (!m_userManager || !m_userManager->isUnixAuthEnabled()) {
- context.errorMessage = "Unix authentication not available";
- context.errorCode = "UNIX_AUTH_UNAVAILABLE";
- return context;
- }
- // For Unix auth, we need to get username from request
- // This could be from a header or client certificate
- std::string username = req.get_header_value("X-Unix-User");
- if (username.empty()) {
- context.errorMessage = "Missing Unix username";
- context.errorCode = "MISSING_UNIX_USER";
- return context;
- }
- // Authenticate Unix user
- auto result = m_userManager->authenticateUnix(username);
- if (!result.success) {
- context.errorMessage = result.errorMessage;
- context.errorCode = result.errorCode;
- return context;
- }
- // Unix authentication successful
- context.authenticated = true;
- context.userId = result.userId;
- context.username = result.username;
- context.role = result.role;
- context.permissions = result.permissions;
- context.authMethod = "UNIX";
- return context;
- }
- std::string AuthMiddleware::extractToken(const httplib::Request& req, const std::string& headerName) const {
- std::string authHeader = req.get_header_value(headerName);
- if (headerName == "Authorization") {
- return JWTAuth::extractTokenFromHeader(authHeader);
- }
- return authHeader;
- }
- AuthContext AuthMiddleware::createGuestContext() const {
- AuthContext context;
- context.authenticated = false;
- context.userId = "guest";
- context.username = "guest";
- context.role = "guest";
- context.permissions = UserManager::getDefaultPermissions(UserManager::UserRole::GUEST);
- context.authMethod = "none";
- return context;
- }
- bool AuthMiddleware::pathMatchesPattern(const std::string& path,
- const std::vector<std::string>& patterns) {
- for (const auto& pattern : patterns) {
- // Simple exact match for now
- if (path == pattern) {
- return true;
- }
- // Check for prefix match (pattern ends with *)
- if (pattern.length() > 1 && pattern.back() == '*') {
- std::string prefix = pattern.substr(0, pattern.length() - 1);
- if (path.length() >= prefix.length() && path.substr(0, prefix.length()) == prefix) {
- return true;
- }
- }
- }
- return false;
- }
- std::vector<std::string> AuthMiddleware::getRequiredPermissions(const std::string& path) const {
- if (requiresAdminAccess(path)) {
- return {UserManager::Permissions::ADMIN};
- }
- if (requiresUserAccess(path)) {
- return {UserManager::Permissions::READ};
- }
- return {};
- }
- void AuthMiddleware::logAuthAttempt(const httplib::Request& req,
- const AuthContext& context,
- bool success) const {
- // In a real implementation, this would log to a file or logging system
- std::string clientIp = getClientIp(req);
- std::string userAgent = getUserAgent(req);
- if (success) {
- // Log successful authentication
- } else {
- // Log failed authentication attempt
- }
- }
- std::string AuthMiddleware::getClientIp(const httplib::Request& req) {
- // Check various headers for client IP
- std::string ip = req.get_header_value("X-Forwarded-For");
- if (ip.empty()) {
- ip = req.get_header_value("X-Real-IP");
- }
- if (ip.empty()) {
- ip = req.get_header_value("X-Client-IP");
- }
- if (ip.empty()) {
- ip = req.remote_addr;
- }
- return ip;
- }
- std::string AuthMiddleware::getUserAgent(const httplib::Request& req) {
- return req.get_header_value("User-Agent");
- }
- bool AuthMiddleware::validateConfig(const AuthConfig& config) {
- // Validate JWT configuration
- if (config.authMethod == AuthMethod::JWT) {
- if (config.jwtSecret.empty()) {
- // Will be auto-generated
- }
- if (config.jwtExpirationMinutes <= 0 || config.jwtExpirationMinutes > 1440) {
- return false; // Max 24 hours
- }
- }
- // Validate realm
- if (config.authRealm.empty()) {
- return false;
- }
- return true;
- }
- void AuthMiddleware::initializeDefaultPaths() {
- // Add default public paths
- if (m_config.publicPaths.empty()) {
- m_config.publicPaths = {
- "/api/health",
- "/api/status",
- "/api/samplers",
- "/api/schedulers",
- "/api/parameters",
- "/api/models",
- "/api/models/types",
- "/api/models/directories"
- };
- }
- // Add default admin paths
- if (m_config.adminPaths.empty()) {
- m_config.adminPaths = {
- "/api/users",
- "/api/auth/users",
- "/api/system/restart"
- };
- }
- // Add default user paths
- if (m_config.userPaths.empty()) {
- m_config.userPaths = {
- "/api/generate",
- "/api/queue",
- "/api/models/load",
- "/api/models/unload",
- "/api/auth/profile",
- "/api/auth/api-keys"
- };
- }
- }
- bool AuthMiddleware::isAuthenticationDisabled() const {
- return m_config.authMethod == AuthMethod::NONE;
- }
- // Factory functions
- namespace AuthMiddlewareFactory {
- std::unique_ptr<AuthMiddleware> createDefault(std::shared_ptr<UserManager> userManager,
- const std::string& dataDir) {
- AuthConfig config;
- config.authMethod = AuthMethod::NONE;
- config.enableGuestAccess = true;
- config.jwtSecret = "";
- config.jwtExpirationMinutes = 60;
- config.authRealm = "stable-diffusion-rest";
- return std::make_unique<AuthMiddleware>(config, userManager);
- }
- std::unique_ptr<AuthMiddleware> createJwtOnly(std::shared_ptr<UserManager> userManager,
- const std::string& jwtSecret,
- int jwtExpirationMinutes) {
- AuthConfig config;
- config.authMethod = AuthMethod::JWT;
- config.enableGuestAccess = false;
- config.jwtSecret = jwtSecret;
- config.jwtExpirationMinutes = jwtExpirationMinutes;
- config.authRealm = "stable-diffusion-rest";
- config.enableUnixAuth = false;
- return std::make_unique<AuthMiddleware>(config, userManager);
- }
- std::unique_ptr<AuthMiddleware> createApiKeyOnly(std::shared_ptr<UserManager> userManager) {
- AuthConfig config;
- config.authMethod = AuthMethod::API_KEY;
- config.enableGuestAccess = false;
- config.authRealm = "stable-diffusion-rest";
- config.enableUnixAuth = false;
- return std::make_unique<AuthMiddleware>(config, userManager);
- }
- std::unique_ptr<AuthMiddleware> createMultiMethod(std::shared_ptr<UserManager> userManager,
- const AuthConfig& config) {
- return std::make_unique<AuthMiddleware>(config, userManager);
- }
- std::unique_ptr<AuthMiddleware> createDevelopment() {
- AuthConfig config;
- config.authMethod = AuthMethod::NONE;
- config.enableGuestAccess = true;
- config.authRealm = "stable-diffusion-rest";
- return std::make_unique<AuthMiddleware>(config, nullptr);
- }
- } // namespace AuthMiddlewareFactory
|