auth_middleware.cpp 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562
  1. #include "auth_middleware.h"
  2. #include <httplib.h>
  3. #include <nlohmann/json.hpp>
  4. #include <fstream>
  5. #include <sstream>
  6. #include <iomanip>
  7. #include <algorithm>
  8. #include <regex>
  9. using json = nlohmann::json;
  10. AuthMiddleware::AuthMiddleware(const AuthConfig& config,
  11. std::shared_ptr<UserManager> userManager)
  12. : m_config(config)
  13. , m_userManager(userManager)
  14. {
  15. }
  16. AuthMiddleware::~AuthMiddleware() = default;
  17. bool AuthMiddleware::initialize() {
  18. try {
  19. // Validate configuration
  20. if (!validateConfig(m_config)) {
  21. return false;
  22. }
  23. // Initialize JWT auth if needed
  24. if (m_config.authMethod == AuthMethod::JWT) {
  25. m_jwtAuth = std::make_unique<JWTAuth>(m_config.jwtSecret,
  26. m_config.jwtExpirationMinutes,
  27. "stable-diffusion-rest");
  28. }
  29. // Initialize default paths
  30. initializeDefaultPaths();
  31. return true;
  32. } catch (const std::exception& e) {
  33. return false;
  34. }
  35. }
  36. AuthContext AuthMiddleware::authenticate(const httplib::Request& req, httplib::Response& res) {
  37. AuthContext context;
  38. context.authenticated = false;
  39. try {
  40. // Check if authentication is completely disabled
  41. if (isAuthenticationDisabled()) {
  42. context = createGuestContext();
  43. context.authenticated = true;
  44. return context;
  45. }
  46. // Check if path requires authentication
  47. if (!requiresAuthentication(req.path)) {
  48. context = createGuestContext();
  49. context.authenticated = m_config.enableGuestAccess;
  50. return context;
  51. }
  52. // Try different authentication methods based on configuration
  53. switch (m_config.authMethod) {
  54. case AuthMethod::JWT:
  55. context = authenticateJwt(req);
  56. break;
  57. case AuthMethod::API_KEY:
  58. context = authenticateApiKey(req);
  59. break;
  60. case AuthMethod::UNIX:
  61. context = authenticateUnix(req);
  62. break;
  63. case AuthMethod::OPTIONAL:
  64. // Try JWT first, then API key, then allow guest
  65. context = authenticateJwt(req);
  66. if (!context.authenticated) {
  67. context = authenticateApiKey(req);
  68. }
  69. if (!context.authenticated && m_config.enableGuestAccess) {
  70. context = createGuestContext();
  71. context.authenticated = true;
  72. }
  73. break;
  74. case AuthMethod::NONE:
  75. default:
  76. context = createGuestContext();
  77. context.authenticated = true;
  78. break;
  79. }
  80. // Check if user has required permissions for this path
  81. if (context.authenticated && !hasPathAccess(req.path, context.permissions)) {
  82. context.authenticated = false;
  83. context.errorMessage = "Insufficient permissions for this endpoint";
  84. context.errorCode = "INSUFFICIENT_PERMISSIONS";
  85. }
  86. // Log authentication attempt
  87. logAuthAttempt(req, context, context.authenticated);
  88. } catch (const std::exception& e) {
  89. context.authenticated = false;
  90. context.errorMessage = "Authentication error: " + std::string(e.what());
  91. context.errorCode = "AUTH_ERROR";
  92. }
  93. return context;
  94. }
  95. bool AuthMiddleware::requiresAuthentication(const std::string& path) const {
  96. // Check if path is public
  97. if (pathMatchesPattern(path, m_config.publicPaths)) {
  98. return false;
  99. }
  100. // All other paths require authentication unless auth is completely disabled
  101. return !isAuthenticationDisabled();
  102. }
  103. bool AuthMiddleware::requiresAdminAccess(const std::string& path) const {
  104. return pathMatchesPattern(path, m_config.adminPaths);
  105. }
  106. bool AuthMiddleware::requiresUserAccess(const std::string& path) const {
  107. return pathMatchesPattern(path, m_config.userPaths);
  108. }
  109. bool AuthMiddleware::hasPathAccess(const std::string& path,
  110. const std::vector<std::string>& permissions) const {
  111. // Check admin paths
  112. if (requiresAdminAccess(path)) {
  113. return JWTAuth::hasPermission(permissions, UserManager::Permissions::ADMIN);
  114. }
  115. // Check user paths
  116. if (requiresUserAccess(path)) {
  117. return JWTAuth::hasAnyPermission(permissions, {
  118. UserManager::Permissions::USER_MANAGE,
  119. UserManager::Permissions::ADMIN
  120. });
  121. }
  122. // Default: allow access if authenticated
  123. return true;
  124. }
  125. AuthMiddleware::AuthHandler AuthMiddleware::createMiddleware(AuthHandler handler) {
  126. return [this, handler](const httplib::Request& req, httplib::Response& res, const AuthContext& context) {
  127. // Authenticate request
  128. AuthContext authContext = authenticate(req, res);
  129. // Check if authentication failed
  130. if (!authContext.authenticated) {
  131. sendAuthError(res, authContext.errorMessage, authContext.errorCode);
  132. return;
  133. }
  134. // Call the next handler
  135. handler(req, res, authContext);
  136. };
  137. }
  138. void AuthMiddleware::sendAuthError(httplib::Response& res,
  139. const std::string& message,
  140. const std::string& errorCode,
  141. int statusCode) {
  142. json error = {
  143. {"error", {
  144. {"message", message},
  145. {"code", errorCode},
  146. {"timestamp", std::chrono::duration_cast<std::chrono::seconds>(
  147. std::chrono::system_clock::now().time_since_epoch()).count()}
  148. }}
  149. };
  150. res.set_header("Content-Type", "application/json");
  151. res.set_header("WWW-Authenticate", "Bearer realm=\"" + m_config.authRealm + "\"");
  152. res.status = statusCode;
  153. res.body = error.dump();
  154. }
  155. void AuthMiddleware::sendAuthzError(httplib::Response& res,
  156. const std::string& message,
  157. const std::string& errorCode) {
  158. json error = {
  159. {"error", {
  160. {"message", message},
  161. {"code", errorCode},
  162. {"timestamp", std::chrono::duration_cast<std::chrono::seconds>(
  163. std::chrono::system_clock::now().time_since_epoch()).count()}
  164. }}
  165. };
  166. res.set_header("Content-Type", "application/json");
  167. res.status = 403;
  168. res.body = error.dump();
  169. }
  170. void AuthMiddleware::addPublicPath(const std::string& path) {
  171. m_config.publicPaths.push_back(path);
  172. }
  173. void AuthMiddleware::addAdminPath(const std::string& path) {
  174. m_config.adminPaths.push_back(path);
  175. }
  176. void AuthMiddleware::addUserPath(const std::string& path) {
  177. m_config.userPaths.push_back(path);
  178. }
  179. void AuthMiddleware::setJwtSecret(const std::string& secret) {
  180. m_config.jwtSecret = secret;
  181. if (m_jwtAuth) {
  182. m_jwtAuth->setIssuer("stable-diffusion-rest");
  183. }
  184. }
  185. std::string AuthMiddleware::getJwtSecret() const {
  186. return m_config.jwtSecret;
  187. }
  188. void AuthMiddleware::setAuthMethod(UserManager::AuthMethod method) {
  189. m_config.authMethod = static_cast<AuthMethod>(method);
  190. }
  191. UserManager::AuthMethod AuthMiddleware::getAuthMethod() const {
  192. return static_cast<UserManager::AuthMethod>(m_config.authMethod);
  193. }
  194. void AuthMiddleware::setGuestAccessEnabled(bool enable) {
  195. m_config.enableGuestAccess = enable;
  196. }
  197. bool AuthMiddleware::isGuestAccessEnabled() const {
  198. return m_config.enableGuestAccess;
  199. }
  200. AuthConfig AuthMiddleware::getConfig() const {
  201. return m_config;
  202. }
  203. void AuthMiddleware::updateConfig(const AuthConfig& config) {
  204. m_config = config;
  205. if (m_config.authMethod == AuthMethod::JWT) {
  206. m_jwtAuth = std::make_unique<JWTAuth>(m_config.jwtSecret,
  207. m_config.jwtExpirationMinutes,
  208. "stable-diffusion-rest");
  209. }
  210. }
  211. AuthContext AuthMiddleware::authenticateJwt(const httplib::Request& req) {
  212. AuthContext context;
  213. context.authenticated = false;
  214. if (!m_jwtAuth) {
  215. context.errorMessage = "JWT authentication not configured";
  216. context.errorCode = "JWT_NOT_CONFIGURED";
  217. return context;
  218. }
  219. // Extract token from header
  220. std::string token = extractToken(req, "Authorization");
  221. if (token.empty()) {
  222. context.errorMessage = "Missing authorization token";
  223. context.errorCode = "MISSING_TOKEN";
  224. return context;
  225. }
  226. // Validate token
  227. auto result = m_jwtAuth->validateToken(token);
  228. if (!result.success) {
  229. context.errorMessage = result.errorMessage;
  230. context.errorCode = result.errorCode;
  231. return context;
  232. }
  233. // Token is valid
  234. context.authenticated = true;
  235. context.userId = result.userId;
  236. context.username = result.username;
  237. context.role = result.role;
  238. context.permissions = result.permissions;
  239. context.authMethod = "JWT";
  240. return context;
  241. }
  242. AuthContext AuthMiddleware::authenticateApiKey(const httplib::Request& req) {
  243. AuthContext context;
  244. context.authenticated = false;
  245. if (!m_userManager) {
  246. context.errorMessage = "User manager not available";
  247. context.errorCode = "USER_MANAGER_UNAVAILABLE";
  248. return context;
  249. }
  250. // Extract API key from header
  251. std::string apiKey = extractToken(req, "X-API-Key");
  252. if (apiKey.empty()) {
  253. context.errorMessage = "Missing API key";
  254. context.errorCode = "MISSING_API_KEY";
  255. return context;
  256. }
  257. // Validate API key
  258. auto result = m_userManager->authenticateApiKey(apiKey);
  259. if (!result.success) {
  260. context.errorMessage = result.errorMessage;
  261. context.errorCode = result.errorCode;
  262. return context;
  263. }
  264. // API key is valid
  265. context.authenticated = true;
  266. context.userId = result.userId;
  267. context.username = result.username;
  268. context.role = result.role;
  269. context.permissions = result.permissions;
  270. context.authMethod = "API_KEY";
  271. return context;
  272. }
  273. AuthContext AuthMiddleware::authenticateUnix(const httplib::Request& req) {
  274. AuthContext context;
  275. context.authenticated = false;
  276. if (!m_userManager || !m_userManager->isUnixAuthEnabled()) {
  277. context.errorMessage = "Unix authentication not available";
  278. context.errorCode = "UNIX_AUTH_UNAVAILABLE";
  279. return context;
  280. }
  281. // For Unix auth, we need to get username from request
  282. // This could be from a header or client certificate
  283. std::string username = req.get_header_value("X-Unix-User");
  284. if (username.empty()) {
  285. context.errorMessage = "Missing Unix username";
  286. context.errorCode = "MISSING_UNIX_USER";
  287. return context;
  288. }
  289. // Authenticate Unix user
  290. auto result = m_userManager->authenticateUnix(username);
  291. if (!result.success) {
  292. context.errorMessage = result.errorMessage;
  293. context.errorCode = result.errorCode;
  294. return context;
  295. }
  296. // Unix authentication successful
  297. context.authenticated = true;
  298. context.userId = result.userId;
  299. context.username = result.username;
  300. context.role = result.role;
  301. context.permissions = result.permissions;
  302. context.authMethod = "UNIX";
  303. return context;
  304. }
  305. std::string AuthMiddleware::extractToken(const httplib::Request& req, const std::string& headerName) const {
  306. std::string authHeader = req.get_header_value(headerName);
  307. if (headerName == "Authorization") {
  308. return JWTAuth::extractTokenFromHeader(authHeader);
  309. }
  310. return authHeader;
  311. }
  312. AuthContext AuthMiddleware::createGuestContext() const {
  313. AuthContext context;
  314. context.authenticated = false;
  315. context.userId = "guest";
  316. context.username = "guest";
  317. context.role = "guest";
  318. context.permissions = UserManager::getDefaultPermissions(UserManager::UserRole::GUEST);
  319. context.authMethod = "none";
  320. return context;
  321. }
  322. bool AuthMiddleware::pathMatchesPattern(const std::string& path,
  323. const std::vector<std::string>& patterns) {
  324. for (const auto& pattern : patterns) {
  325. // Simple exact match for now
  326. if (path == pattern) {
  327. return true;
  328. }
  329. // Check for prefix match (pattern ends with *)
  330. if (pattern.length() > 1 && pattern.back() == '*') {
  331. std::string prefix = pattern.substr(0, pattern.length() - 1);
  332. if (path.length() >= prefix.length() && path.substr(0, prefix.length()) == prefix) {
  333. return true;
  334. }
  335. }
  336. }
  337. return false;
  338. }
  339. std::vector<std::string> AuthMiddleware::getRequiredPermissions(const std::string& path) const {
  340. if (requiresAdminAccess(path)) {
  341. return {UserManager::Permissions::ADMIN};
  342. }
  343. if (requiresUserAccess(path)) {
  344. return {UserManager::Permissions::READ};
  345. }
  346. return {};
  347. }
  348. void AuthMiddleware::logAuthAttempt(const httplib::Request& req,
  349. const AuthContext& context,
  350. bool success) const {
  351. // In a real implementation, this would log to a file or logging system
  352. std::string clientIp = getClientIp(req);
  353. std::string userAgent = getUserAgent(req);
  354. if (success) {
  355. // Log successful authentication
  356. } else {
  357. // Log failed authentication attempt
  358. }
  359. }
  360. std::string AuthMiddleware::getClientIp(const httplib::Request& req) {
  361. // Check various headers for client IP
  362. std::string ip = req.get_header_value("X-Forwarded-For");
  363. if (ip.empty()) {
  364. ip = req.get_header_value("X-Real-IP");
  365. }
  366. if (ip.empty()) {
  367. ip = req.get_header_value("X-Client-IP");
  368. }
  369. if (ip.empty()) {
  370. ip = req.remote_addr;
  371. }
  372. return ip;
  373. }
  374. std::string AuthMiddleware::getUserAgent(const httplib::Request& req) {
  375. return req.get_header_value("User-Agent");
  376. }
  377. bool AuthMiddleware::validateConfig(const AuthConfig& config) {
  378. // Validate JWT configuration
  379. if (config.authMethod == AuthMethod::JWT) {
  380. if (config.jwtSecret.empty()) {
  381. // Will be auto-generated
  382. }
  383. if (config.jwtExpirationMinutes <= 0 || config.jwtExpirationMinutes > 1440) {
  384. return false; // Max 24 hours
  385. }
  386. }
  387. // Validate realm
  388. if (config.authRealm.empty()) {
  389. return false;
  390. }
  391. return true;
  392. }
  393. void AuthMiddleware::initializeDefaultPaths() {
  394. // Add default public paths
  395. if (m_config.publicPaths.empty()) {
  396. m_config.publicPaths = {
  397. "/api/health",
  398. "/api/status",
  399. "/api/samplers",
  400. "/api/schedulers",
  401. "/api/parameters",
  402. "/api/models",
  403. "/api/models/types",
  404. "/api/models/directories"
  405. };
  406. }
  407. // Add default admin paths
  408. if (m_config.adminPaths.empty()) {
  409. m_config.adminPaths = {
  410. "/api/users",
  411. "/api/auth/users",
  412. "/api/system/restart"
  413. };
  414. }
  415. // Add default user paths
  416. if (m_config.userPaths.empty()) {
  417. m_config.userPaths = {
  418. "/api/generate",
  419. "/api/queue",
  420. "/api/models/load",
  421. "/api/models/unload",
  422. "/api/auth/profile",
  423. "/api/auth/api-keys"
  424. };
  425. }
  426. }
  427. bool AuthMiddleware::isAuthenticationDisabled() const {
  428. return m_config.authMethod == AuthMethod::NONE;
  429. }
  430. // Factory functions
  431. namespace AuthMiddlewareFactory {
  432. std::unique_ptr<AuthMiddleware> createDefault(std::shared_ptr<UserManager> userManager,
  433. const std::string& dataDir) {
  434. AuthConfig config;
  435. config.authMethod = AuthMethod::NONE;
  436. config.enableGuestAccess = true;
  437. config.jwtSecret = "";
  438. config.jwtExpirationMinutes = 60;
  439. config.authRealm = "stable-diffusion-rest";
  440. return std::make_unique<AuthMiddleware>(config, userManager);
  441. }
  442. std::unique_ptr<AuthMiddleware> createJwtOnly(std::shared_ptr<UserManager> userManager,
  443. const std::string& jwtSecret,
  444. int jwtExpirationMinutes) {
  445. AuthConfig config;
  446. config.authMethod = AuthMethod::JWT;
  447. config.enableGuestAccess = false;
  448. config.jwtSecret = jwtSecret;
  449. config.jwtExpirationMinutes = jwtExpirationMinutes;
  450. config.authRealm = "stable-diffusion-rest";
  451. config.enableUnixAuth = false;
  452. return std::make_unique<AuthMiddleware>(config, userManager);
  453. }
  454. std::unique_ptr<AuthMiddleware> createApiKeyOnly(std::shared_ptr<UserManager> userManager) {
  455. AuthConfig config;
  456. config.authMethod = AuthMethod::API_KEY;
  457. config.enableGuestAccess = false;
  458. config.authRealm = "stable-diffusion-rest";
  459. config.enableUnixAuth = false;
  460. return std::make_unique<AuthMiddleware>(config, userManager);
  461. }
  462. std::unique_ptr<AuthMiddleware> createMultiMethod(std::shared_ptr<UserManager> userManager,
  463. const AuthConfig& config) {
  464. return std::make_unique<AuthMiddleware>(config, userManager);
  465. }
  466. std::unique_ptr<AuthMiddleware> createDevelopment() {
  467. AuthConfig config;
  468. config.authMethod = AuthMethod::NONE;
  469. config.enableGuestAccess = true;
  470. config.authRealm = "stable-diffusion-rest";
  471. return std::make_unique<AuthMiddleware>(config, nullptr);
  472. }
  473. } // namespace AuthMiddlewareFactory