| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301 |
#include "auth_middleware.h"
#include "user_manager.h"

#include <chrono>
#include <filesystem>
#include <memory>
#include <string>
#include <system_error>
#include <thread>
#include <vector>

#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include <httplib.h>
- class AuthMiddlewareSecurityTest : public ::testing::Test {
- protected:
- void SetUp() override {
- userManager = std::make_shared<UserManager>("./test-auth", UserManager::AuthMethod::JWT);
- ASSERT_TRUE(userManager->initialize());
- }
- void TearDown() override {
- // Clean up test data
- std::filesystem::remove_all("./test-auth");
- }
- std::shared_ptr<UserManager> userManager;
- };
- TEST_F(AuthMiddlewareSecurityTest, DefaultPublicPathsWhenAuthDisabled) {
- // Test that when auth is disabled, no paths require authentication
- AuthConfig config;
- config.authMethod = AuthMethod::NONE;
- auto authMiddleware = std::make_unique<AuthMiddleware>(config, userManager);
- ASSERT_TRUE(authMiddleware->initialize());
- // All paths should be public when auth is disabled
- EXPECT_FALSE(authMiddleware->requiresAuthentication("/api/health"));
- EXPECT_FALSE(authMiddleware->requiresAuthentication("/api/status"));
- EXPECT_FALSE(authMiddleware->requiresAuthentication("/api/models"));
- EXPECT_FALSE(authMiddleware->requiresAuthentication("/api/generate"));
- EXPECT_FALSE(authMiddleware->requiresAuthentication("/api/queue/status"));
- }
- TEST_F(AuthMiddlewareSecurityTest, DefaultPublicPathsWhenAuthEnabled) {
- // Test that when auth is enabled, only essential endpoints are public
- AuthConfig config;
- config.authMethod = AuthMethod::JWT;
- auto authMiddleware = std::make_unique<AuthMiddleware>(config, userManager);
- ASSERT_TRUE(authMiddleware->initialize());
- // Only health and status should be public by default
- EXPECT_FALSE(authMiddleware->requiresAuthentication("/api/health"));
- EXPECT_FALSE(authMiddleware->requiresAuthentication("/api/status"));
- // All other endpoints should require authentication
- EXPECT_TRUE(authMiddleware->requiresAuthentication("/api/models"));
- EXPECT_TRUE(authMiddleware->requiresAuthentication("/api/models/types"));
- EXPECT_TRUE(authMiddleware->requiresAuthentication("/api/models/directories"));
- EXPECT_TRUE(authMiddleware->requiresAuthentication("/api/samplers"));
- EXPECT_TRUE(authMiddleware->requiresAuthentication("/api/schedulers"));
- EXPECT_TRUE(authMiddleware->requiresAuthentication("/api/parameters"));
- EXPECT_TRUE(authMiddleware->requiresAuthentication("/api/generate"));
- EXPECT_TRUE(authMiddleware->requiresAuthentication("/api/queue/status"));
- EXPECT_TRUE(authMiddleware->requiresAuthentication("/api/queue/job/123"));
- }
- TEST_F(AuthMiddlewareSecurityTest, CustomPublicPathsConfiguration) {
- // Test custom public paths configuration
- AuthConfig config;
- config.authMethod = AuthMethod::JWT;
- config.customPublicPaths = "/api/health,/api/status,/api/models,/api/samplers";
- auto authMiddleware = std::make_unique<AuthMiddleware>(config, userManager);
- ASSERT_TRUE(authMiddleware->initialize());
- // Configured public paths should be accessible
- EXPECT_FALSE(authMiddleware->requiresAuthentication("/api/health"));
- EXPECT_FALSE(authMiddleware->requiresAuthentication("/api/status"));
- EXPECT_FALSE(authMiddleware->requiresAuthentication("/api/models"));
- EXPECT_FALSE(authMiddleware->requiresAuthentication("/api/samplers"));
- // Non-configured paths should require authentication
- EXPECT_TRUE(authMiddleware->requiresAuthentication("/api/models/types"));
- EXPECT_TRUE(authMiddleware->requiresAuthentication("/api/schedulers"));
- EXPECT_TRUE(authMiddleware->requiresAuthentication("/api/generate"));
- }
- TEST_F(AuthMiddlewareSecurityTest, CustomPublicPathsWithSpaces) {
- // Test custom public paths with spaces
- AuthConfig config;
- config.authMethod = AuthMethod::JWT;
- config.customPublicPaths = " /api/health , /api/status , /api/models ";
- auto authMiddleware = std::make_unique<AuthMiddleware>(config, userManager);
- ASSERT_TRUE(authMiddleware->initialize());
- // Spaces should be trimmed properly
- EXPECT_FALSE(authMiddleware->requiresAuthentication("/api/health"));
- EXPECT_FALSE(authMiddleware->requiresAuthentication("/api/status"));
- EXPECT_FALSE(authMiddleware->requiresAuthentication("/api/models"));
- // Non-configured paths should require authentication
- EXPECT_TRUE(authMiddleware->requiresAuthentication("/api/samplers"));
- }
- TEST_F(AuthMiddlewareSecurityTest, CustomPublicPathsAutoPrefix) {
- // Test that paths without leading slash get auto-prefixed
- AuthConfig config;
- config.authMethod = AuthMethod::JWT;
- config.customPublicPaths = "api/health,api/status";
- auto authMiddleware = std::make_unique<AuthMiddleware>(config, userManager);
- ASSERT_TRUE(authMiddleware->initialize());
- // Paths should be accessible with or without leading slash
- EXPECT_FALSE(authMiddleware->requiresAuthentication("/api/health"));
- EXPECT_FALSE(authMiddleware->requiresAuthentication("/api/status"));
- }
- TEST_F(AuthMiddlewareSecurityTest, ModelDiscoveryEndpointsProtected) {
- // Test that all model discovery endpoints require authentication when enabled
- AuthConfig config;
- config.authMethod = AuthMethod::JWT;
- auto authMiddleware = std::make_unique<AuthMiddleware>(config, userManager);
- ASSERT_TRUE(authMiddleware->initialize());
- // All model discovery endpoints should require authentication
- EXPECT_TRUE(authMiddleware->requiresAuthentication("/api/models"));
- EXPECT_TRUE(authMiddleware->requiresAuthentication("/api/models/types"));
- EXPECT_TRUE(authMiddleware->requiresAuthentication("/api/models/directories"));
- EXPECT_TRUE(authMiddleware->requiresAuthentication("/api/models/stats"));
- EXPECT_TRUE(authMiddleware->requiresAuthentication("/api/models/123"));
- }
- TEST_F(AuthMiddlewareSecurityTest, GenerationEndpointsProtected) {
- // Test that all generation endpoints require authentication when enabled
- AuthConfig config;
- config.authMethod = AuthMethod::JWT;
- auto authMiddleware = std::make_unique<AuthMiddleware>(config, userManager);
- ASSERT_TRUE(authMiddleware->initialize());
- // All generation endpoints should require authentication
- EXPECT_TRUE(authMiddleware->requiresAuthentication("/api/generate"));
- EXPECT_TRUE(authMiddleware->requiresAuthentication("/api/generate/text2img"));
- EXPECT_TRUE(authMiddleware->requiresAuthentication("/api/generate/img2img"));
- EXPECT_TRUE(authMiddleware->requiresAuthentication("/api/generate/controlnet"));
- EXPECT_TRUE(authMiddleware->requiresAuthentication("/api/generate/upscale"));
- }
- TEST_F(AuthMiddlewareSecurityTest, QueueEndpointsProtected) {
- // Test that queue endpoints require authentication when enabled
- AuthConfig config;
- config.authMethod = AuthMethod::JWT;
- auto authMiddleware = std::make_unique<AuthMiddleware>(config, userManager);
- ASSERT_TRUE(authMiddleware->initialize());
- // Queue endpoints should require authentication
- EXPECT_TRUE(authMiddleware->requiresAuthentication("/api/queue/status"));
- EXPECT_TRUE(authMiddleware->requiresAuthentication("/api/queue/job/123"));
- EXPECT_TRUE(authMiddleware->requiresAuthentication("/api/queue/cancel"));
- EXPECT_TRUE(authMiddleware->requiresAuthentication("/api/queue/clear"));
- }
- TEST_F(AuthMiddlewareSecurityTest, AuthenticationMethodsConsistency) {
- // Test that authentication enforcement is consistent across different auth methods
- std::vector<AuthMethod> authMethods = {
- AuthMethod::JWT,
- AuthMethod::API_KEY,
- AuthMethod::UNIX,
- AuthMethod::PAM
- };
- for (auto authMethod : authMethods) {
- AuthConfig config;
- config.authMethod = authMethod;
- auto authMiddleware = std::make_unique<AuthMiddleware>(config, userManager);
- ASSERT_TRUE(authMiddleware->initialize());
- // Health and status should always be public
- EXPECT_FALSE(authMiddleware->requiresAuthentication("/api/health"));
- EXPECT_FALSE(authMiddleware->requiresAuthentication("/api/status"));
- // Model discovery should require authentication
- EXPECT_TRUE(authMiddleware->requiresAuthentication("/api/models"));
- EXPECT_TRUE(authMiddleware->requiresAuthentication("/api/samplers"));
- }
- }
- TEST_F(AuthMiddlewareSecurityTest, OptionalAuthWithGuestAccess) {
- // Test optional authentication with guest access enabled
- AuthConfig config;
- config.authMethod = AuthMethod::OPTIONAL;
- config.enableGuestAccess = true;
- auto authMiddleware = std::make_unique<AuthMiddleware>(config, userManager);
- ASSERT_TRUE(authMiddleware->initialize());
- // With guest access enabled, public paths should be accessible
- EXPECT_FALSE(authMiddleware->requiresAuthentication("/api/health"));
- EXPECT_FALSE(authMiddleware->requiresAuthentication("/api/status"));
- // But protected paths should still require authentication
- EXPECT_TRUE(authMiddleware->requiresAuthentication("/api/models"));
- EXPECT_TRUE(authMiddleware->requiresAuthentication("/api/generate"));
- }
- TEST_F(AuthMiddlewareSecurityTest, OptionalAuthWithoutGuestAccess) {
- // Test optional authentication without guest access
- AuthConfig config;
- config.authMethod = AuthMethod::OPTIONAL;
- config.enableGuestAccess = false;
- auto authMiddleware = std::make_unique<AuthMiddleware>(config, userManager);
- ASSERT_TRUE(authMiddleware->initialize());
- // Without guest access, only public paths should be accessible
- EXPECT_FALSE(authMiddleware->requiresAuthentication("/api/health"));
- EXPECT_FALSE(authMiddleware->requiresAuthentication("/api/status"));
- // All other paths should require authentication
- EXPECT_TRUE(authMiddleware->requiresAuthentication("/api/models"));
- EXPECT_TRUE(authMiddleware->requiresAuthentication("/api/generate"));
- }
- // Integration test with actual HTTP requests
- class AuthMiddlewareHttpTest : public ::testing::Test {
- protected:
- void SetUp() override {
- userManager = std::make_shared<UserManager>("./test-auth-http", UserManager::AuthMethod::JWT);
- ASSERT_TRUE(userManager->initialize());
- AuthConfig config;
- config.authMethod = AuthMethod::JWT;
- config.jwtSecret = "test-secret";
- authMiddleware = std::make_unique<AuthMiddleware>(config, userManager);
- ASSERT_TRUE(authMiddleware->initialize());
- server = std::make_unique<httplib::Server>();
- // Set up test endpoints
- server->Get("/api/health", [](const httplib::Request&, httplib::Response& res) {
- res.set_content("{\"status\":\"healthy\"}", "application/json");
- });
- server->Get("/api/models", [this](const httplib::Request& req, httplib::Response& res) {
- auto authContext = authMiddleware->authenticate(req, res);
- if (!authContext.authenticated) {
- authMiddleware->sendAuthError(res, "Authentication required", "AUTH_REQUIRED");
- return;
- }
- res.set_content("{\"models\":[]}", "application/json");
- });
- // Start server in background
- serverThread = std::thread([this]() {
- server->listen("localhost", 0); // Use port 0 to get random port
- });
- // Wait for server to start
- std::this_thread::sleep_for(std::chrono::milliseconds(100));
- }
- void TearDown() override {
- if (server) {
- server->stop();
- }
- if (serverThread.joinable()) {
- serverThread.join();
- }
- std::filesystem::remove_all("./test-auth-http");
- }
- std::shared_ptr<UserManager> userManager;
- std::unique_ptr<AuthMiddleware> authMiddleware;
- std::unique_ptr<httplib::Server> server;
- std::thread serverThread;
- };
- TEST_F(AuthMiddlewareHttpTest, PublicEndpointAccessible) {
- // Test that public endpoints are accessible without authentication
- httplib::Client client("localhost", 8080);
- auto res = client.Get("/api/health");
- EXPECT_EQ(res->status, 200);
- EXPECT_NE(res->body.find("healthy"), std::string::npos);
- }
- TEST_F(AuthMiddlewareHttpTest, ProtectedEndpointRequiresAuth) {
- // Test that protected endpoints return 401 without authentication
- httplib::Client client("localhost", 8080);
- auto res = client.Get("/api/models");
- EXPECT_EQ(res->status, 401);
- EXPECT_NE(res->body.find("Authentication required"), std::string::npos);
- }
- int main(int argc, char** argv) {
- ::testing::InitGoogleTest(&argc, argv);
- return RUN_ALL_TESTS();
- }
|