// test_auth_security.cpp
#include "auth_middleware.h"
#include "user_manager.h"

#include <chrono>
#include <filesystem>
#include <memory>
#include <string>
#include <thread>
#include <vector>

#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include <httplib.h>
  7. class AuthMiddlewareSecurityTest : public ::testing::Test {
  8. protected:
  9. void SetUp() override {
  10. userManager = std::make_shared<UserManager>("./test-auth", UserManager::AuthMethod::JWT);
  11. ASSERT_TRUE(userManager->initialize());
  12. }
  13. void TearDown() override {
  14. // Clean up test data
  15. std::filesystem::remove_all("./test-auth");
  16. }
  17. std::shared_ptr<UserManager> userManager;
  18. };
  19. TEST_F(AuthMiddlewareSecurityTest, DefaultPublicPathsWhenAuthDisabled) {
  20. // Test that when auth is disabled, no paths require authentication
  21. AuthConfig config;
  22. config.authMethod = AuthMethod::NONE;
  23. auto authMiddleware = std::make_unique<AuthMiddleware>(config, userManager);
  24. ASSERT_TRUE(authMiddleware->initialize());
  25. // All paths should be public when auth is disabled
  26. EXPECT_FALSE(authMiddleware->requiresAuthentication("/api/health"));
  27. EXPECT_FALSE(authMiddleware->requiresAuthentication("/api/status"));
  28. EXPECT_FALSE(authMiddleware->requiresAuthentication("/api/models"));
  29. EXPECT_FALSE(authMiddleware->requiresAuthentication("/api/generate"));
  30. EXPECT_FALSE(authMiddleware->requiresAuthentication("/api/queue/status"));
  31. }
  32. TEST_F(AuthMiddlewareSecurityTest, DefaultPublicPathsWhenAuthEnabled) {
  33. // Test that when auth is enabled, only essential endpoints are public
  34. AuthConfig config;
  35. config.authMethod = AuthMethod::JWT;
  36. auto authMiddleware = std::make_unique<AuthMiddleware>(config, userManager);
  37. ASSERT_TRUE(authMiddleware->initialize());
  38. // Only health and status should be public by default
  39. EXPECT_FALSE(authMiddleware->requiresAuthentication("/api/health"));
  40. EXPECT_FALSE(authMiddleware->requiresAuthentication("/api/status"));
  41. // All other endpoints should require authentication
  42. EXPECT_TRUE(authMiddleware->requiresAuthentication("/api/models"));
  43. EXPECT_TRUE(authMiddleware->requiresAuthentication("/api/models/types"));
  44. EXPECT_TRUE(authMiddleware->requiresAuthentication("/api/models/directories"));
  45. EXPECT_TRUE(authMiddleware->requiresAuthentication("/api/samplers"));
  46. EXPECT_TRUE(authMiddleware->requiresAuthentication("/api/schedulers"));
  47. EXPECT_TRUE(authMiddleware->requiresAuthentication("/api/parameters"));
  48. EXPECT_TRUE(authMiddleware->requiresAuthentication("/api/generate"));
  49. EXPECT_TRUE(authMiddleware->requiresAuthentication("/api/queue/status"));
  50. EXPECT_TRUE(authMiddleware->requiresAuthentication("/api/queue/job/123"));
  51. }
  52. TEST_F(AuthMiddlewareSecurityTest, CustomPublicPathsConfiguration) {
  53. // Test custom public paths configuration
  54. AuthConfig config;
  55. config.authMethod = AuthMethod::JWT;
  56. config.customPublicPaths = "/api/health,/api/status,/api/models,/api/samplers";
  57. auto authMiddleware = std::make_unique<AuthMiddleware>(config, userManager);
  58. ASSERT_TRUE(authMiddleware->initialize());
  59. // Configured public paths should be accessible
  60. EXPECT_FALSE(authMiddleware->requiresAuthentication("/api/health"));
  61. EXPECT_FALSE(authMiddleware->requiresAuthentication("/api/status"));
  62. EXPECT_FALSE(authMiddleware->requiresAuthentication("/api/models"));
  63. EXPECT_FALSE(authMiddleware->requiresAuthentication("/api/samplers"));
  64. // Non-configured paths should require authentication
  65. EXPECT_TRUE(authMiddleware->requiresAuthentication("/api/models/types"));
  66. EXPECT_TRUE(authMiddleware->requiresAuthentication("/api/schedulers"));
  67. EXPECT_TRUE(authMiddleware->requiresAuthentication("/api/generate"));
  68. }
  69. TEST_F(AuthMiddlewareSecurityTest, CustomPublicPathsWithSpaces) {
  70. // Test custom public paths with spaces
  71. AuthConfig config;
  72. config.authMethod = AuthMethod::JWT;
  73. config.customPublicPaths = " /api/health , /api/status , /api/models ";
  74. auto authMiddleware = std::make_unique<AuthMiddleware>(config, userManager);
  75. ASSERT_TRUE(authMiddleware->initialize());
  76. // Spaces should be trimmed properly
  77. EXPECT_FALSE(authMiddleware->requiresAuthentication("/api/health"));
  78. EXPECT_FALSE(authMiddleware->requiresAuthentication("/api/status"));
  79. EXPECT_FALSE(authMiddleware->requiresAuthentication("/api/models"));
  80. // Non-configured paths should require authentication
  81. EXPECT_TRUE(authMiddleware->requiresAuthentication("/api/samplers"));
  82. }
  83. TEST_F(AuthMiddlewareSecurityTest, CustomPublicPathsAutoPrefix) {
  84. // Test that paths without leading slash get auto-prefixed
  85. AuthConfig config;
  86. config.authMethod = AuthMethod::JWT;
  87. config.customPublicPaths = "api/health,api/status";
  88. auto authMiddleware = std::make_unique<AuthMiddleware>(config, userManager);
  89. ASSERT_TRUE(authMiddleware->initialize());
  90. // Paths should be accessible with or without leading slash
  91. EXPECT_FALSE(authMiddleware->requiresAuthentication("/api/health"));
  92. EXPECT_FALSE(authMiddleware->requiresAuthentication("/api/status"));
  93. }
  94. TEST_F(AuthMiddlewareSecurityTest, ModelDiscoveryEndpointsProtected) {
  95. // Test that all model discovery endpoints require authentication when enabled
  96. AuthConfig config;
  97. config.authMethod = AuthMethod::JWT;
  98. auto authMiddleware = std::make_unique<AuthMiddleware>(config, userManager);
  99. ASSERT_TRUE(authMiddleware->initialize());
  100. // All model discovery endpoints should require authentication
  101. EXPECT_TRUE(authMiddleware->requiresAuthentication("/api/models"));
  102. EXPECT_TRUE(authMiddleware->requiresAuthentication("/api/models/types"));
  103. EXPECT_TRUE(authMiddleware->requiresAuthentication("/api/models/directories"));
  104. EXPECT_TRUE(authMiddleware->requiresAuthentication("/api/models/stats"));
  105. EXPECT_TRUE(authMiddleware->requiresAuthentication("/api/models/123"));
  106. }
  107. TEST_F(AuthMiddlewareSecurityTest, GenerationEndpointsProtected) {
  108. // Test that all generation endpoints require authentication when enabled
  109. AuthConfig config;
  110. config.authMethod = AuthMethod::JWT;
  111. auto authMiddleware = std::make_unique<AuthMiddleware>(config, userManager);
  112. ASSERT_TRUE(authMiddleware->initialize());
  113. // All generation endpoints should require authentication
  114. EXPECT_TRUE(authMiddleware->requiresAuthentication("/api/generate"));
  115. EXPECT_TRUE(authMiddleware->requiresAuthentication("/api/generate/text2img"));
  116. EXPECT_TRUE(authMiddleware->requiresAuthentication("/api/generate/img2img"));
  117. EXPECT_TRUE(authMiddleware->requiresAuthentication("/api/generate/controlnet"));
  118. EXPECT_TRUE(authMiddleware->requiresAuthentication("/api/generate/upscale"));
  119. }
  120. TEST_F(AuthMiddlewareSecurityTest, QueueEndpointsProtected) {
  121. // Test that queue endpoints require authentication when enabled
  122. AuthConfig config;
  123. config.authMethod = AuthMethod::JWT;
  124. auto authMiddleware = std::make_unique<AuthMiddleware>(config, userManager);
  125. ASSERT_TRUE(authMiddleware->initialize());
  126. // Queue endpoints should require authentication
  127. EXPECT_TRUE(authMiddleware->requiresAuthentication("/api/queue/status"));
  128. EXPECT_TRUE(authMiddleware->requiresAuthentication("/api/queue/job/123"));
  129. EXPECT_TRUE(authMiddleware->requiresAuthentication("/api/queue/cancel"));
  130. EXPECT_TRUE(authMiddleware->requiresAuthentication("/api/queue/clear"));
  131. }
  132. TEST_F(AuthMiddlewareSecurityTest, AuthenticationMethodsConsistency) {
  133. // Test that authentication enforcement is consistent across different auth methods
  134. std::vector<AuthMethod> authMethods = {
  135. AuthMethod::JWT,
  136. AuthMethod::API_KEY,
  137. AuthMethod::UNIX,
  138. AuthMethod::PAM
  139. };
  140. for (auto authMethod : authMethods) {
  141. AuthConfig config;
  142. config.authMethod = authMethod;
  143. auto authMiddleware = std::make_unique<AuthMiddleware>(config, userManager);
  144. ASSERT_TRUE(authMiddleware->initialize());
  145. // Health and status should always be public
  146. EXPECT_FALSE(authMiddleware->requiresAuthentication("/api/health"));
  147. EXPECT_FALSE(authMiddleware->requiresAuthentication("/api/status"));
  148. // Model discovery should require authentication
  149. EXPECT_TRUE(authMiddleware->requiresAuthentication("/api/models"));
  150. EXPECT_TRUE(authMiddleware->requiresAuthentication("/api/samplers"));
  151. }
  152. }
  153. TEST_F(AuthMiddlewareSecurityTest, OptionalAuthWithGuestAccess) {
  154. // Test optional authentication with guest access enabled
  155. AuthConfig config;
  156. config.authMethod = AuthMethod::OPTIONAL;
  157. config.enableGuestAccess = true;
  158. auto authMiddleware = std::make_unique<AuthMiddleware>(config, userManager);
  159. ASSERT_TRUE(authMiddleware->initialize());
  160. // With guest access enabled, public paths should be accessible
  161. EXPECT_FALSE(authMiddleware->requiresAuthentication("/api/health"));
  162. EXPECT_FALSE(authMiddleware->requiresAuthentication("/api/status"));
  163. // But protected paths should still require authentication
  164. EXPECT_TRUE(authMiddleware->requiresAuthentication("/api/models"));
  165. EXPECT_TRUE(authMiddleware->requiresAuthentication("/api/generate"));
  166. }
  167. TEST_F(AuthMiddlewareSecurityTest, OptionalAuthWithoutGuestAccess) {
  168. // Test optional authentication without guest access
  169. AuthConfig config;
  170. config.authMethod = AuthMethod::OPTIONAL;
  171. config.enableGuestAccess = false;
  172. auto authMiddleware = std::make_unique<AuthMiddleware>(config, userManager);
  173. ASSERT_TRUE(authMiddleware->initialize());
  174. // Without guest access, only public paths should be accessible
  175. EXPECT_FALSE(authMiddleware->requiresAuthentication("/api/health"));
  176. EXPECT_FALSE(authMiddleware->requiresAuthentication("/api/status"));
  177. // All other paths should require authentication
  178. EXPECT_TRUE(authMiddleware->requiresAuthentication("/api/models"));
  179. EXPECT_TRUE(authMiddleware->requiresAuthentication("/api/generate"));
  180. }
  181. // Integration test with actual HTTP requests
  182. class AuthMiddlewareHttpTest : public ::testing::Test {
  183. protected:
  184. void SetUp() override {
  185. userManager = std::make_shared<UserManager>("./test-auth-http", UserManager::AuthMethod::JWT);
  186. ASSERT_TRUE(userManager->initialize());
  187. AuthConfig config;
  188. config.authMethod = AuthMethod::JWT;
  189. config.jwtSecret = "test-secret";
  190. authMiddleware = std::make_unique<AuthMiddleware>(config, userManager);
  191. ASSERT_TRUE(authMiddleware->initialize());
  192. server = std::make_unique<httplib::Server>();
  193. // Set up test endpoints
  194. server->Get("/api/health", [](const httplib::Request&, httplib::Response& res) {
  195. res.set_content("{\"status\":\"healthy\"}", "application/json");
  196. });
  197. server->Get("/api/models", [this](const httplib::Request& req, httplib::Response& res) {
  198. auto authContext = authMiddleware->authenticate(req, res);
  199. if (!authContext.authenticated) {
  200. authMiddleware->sendAuthError(res, "Authentication required", "AUTH_REQUIRED");
  201. return;
  202. }
  203. res.set_content("{\"models\":[]}", "application/json");
  204. });
  205. // Start server in background
  206. serverThread = std::thread([this]() {
  207. server->listen("localhost", 0); // Use port 0 to get random port
  208. });
  209. // Wait for server to start
  210. std::this_thread::sleep_for(std::chrono::milliseconds(100));
  211. }
  212. void TearDown() override {
  213. if (server) {
  214. server->stop();
  215. }
  216. if (serverThread.joinable()) {
  217. serverThread.join();
  218. }
  219. std::filesystem::remove_all("./test-auth-http");
  220. }
  221. std::shared_ptr<UserManager> userManager;
  222. std::unique_ptr<AuthMiddleware> authMiddleware;
  223. std::unique_ptr<httplib::Server> server;
  224. std::thread serverThread;
  225. };
  226. TEST_F(AuthMiddlewareHttpTest, PublicEndpointAccessible) {
  227. // Test that public endpoints are accessible without authentication
  228. httplib::Client client("localhost", 8080);
  229. auto res = client.Get("/api/health");
  230. EXPECT_EQ(res->status, 200);
  231. EXPECT_NE(res->body.find("healthy"), std::string::npos);
  232. }
  233. TEST_F(AuthMiddlewareHttpTest, ProtectedEndpointRequiresAuth) {
  234. // Test that protected endpoints return 401 without authentication
  235. httplib::Client client("localhost", 8080);
  236. auto res = client.Get("/api/models");
  237. EXPECT_EQ(res->status, 401);
  238. EXPECT_NE(res->body.find("Authentication required"), std::string::npos);
  239. }
  240. int main(int argc, char** argv) {
  241. ::testing::InitGoogleTest(&argc, argv);
  242. return RUN_ALL_TESTS();
  243. }