#include "auth_middleware.h"
#include "user_manager.h"

#include <gtest/gtest.h>

#include <chrono>
#include <filesystem>
#include <memory>
#include <string>
#include <thread>
#include <vector>

// NOTE(review): httplib::Server / httplib::Client are not included directly in this file;
// they are assumed to be reachable through auth_middleware.h — confirm, or include httplib
// explicitly.

/**
 * Unit tests for AuthMiddleware::requiresAuthentication() path classification.
 *
 * Every test builds its own AuthConfig + AuthMiddleware on top of a JWT-backed UserManager
 * whose data lives in a scratch directory that is wiped before and after each test.
 */
class AuthMiddlewareSecurityTest : public ::testing::Test {
protected:
    // Scratch directory for UserManager state.
    static constexpr const char* kTestDataDir = "./test-auth";

    void SetUp() override {
        // Start from an empty directory even if a previous run died before TearDown ran.
        std::filesystem::remove_all(kTestDataDir);
        userManager = std::make_shared<UserManager>(kTestDataDir, UserManager::AuthMethod::JWT);
        ASSERT_TRUE(userManager->initialize());
    }

    void TearDown() override {
        // Clean up test data
        std::filesystem::remove_all(kTestDataDir);
    }

    std::shared_ptr<UserManager> userManager;
};

TEST_F(AuthMiddlewareSecurityTest, DefaultPublicPathsWhenAuthDisabled) {
    // Test that when auth is disabled, no paths require authentication
    AuthConfig config;
    config.authMethod = AuthMethod::NONE;

    auto authMiddleware = std::make_unique<AuthMiddleware>(config, userManager);
    ASSERT_TRUE(authMiddleware->initialize());

    // All paths should be public when auth is disabled
    EXPECT_FALSE(authMiddleware->requiresAuthentication("/api/health"));
    EXPECT_FALSE(authMiddleware->requiresAuthentication("/api/status"));
    EXPECT_FALSE(authMiddleware->requiresAuthentication("/api/models"));
    EXPECT_FALSE(authMiddleware->requiresAuthentication("/api/generate"));
    EXPECT_FALSE(authMiddleware->requiresAuthentication("/api/queue/status"));
}

TEST_F(AuthMiddlewareSecurityTest, DefaultPublicPathsWhenAuthEnabled) {
    // Test that when auth is enabled, only essential endpoints are public
    AuthConfig config;
    config.authMethod = AuthMethod::JWT;

    auto authMiddleware = std::make_unique<AuthMiddleware>(config, userManager);
    ASSERT_TRUE(authMiddleware->initialize());

    // Only health and status should be public by default
    EXPECT_FALSE(authMiddleware->requiresAuthentication("/api/health"));
    EXPECT_FALSE(authMiddleware->requiresAuthentication("/api/status"));

    // All other endpoints should require authentication
    EXPECT_TRUE(authMiddleware->requiresAuthentication("/api/models"));
    EXPECT_TRUE(authMiddleware->requiresAuthentication("/api/models/types"));
    EXPECT_TRUE(authMiddleware->requiresAuthentication("/api/models/directories"));
    EXPECT_TRUE(authMiddleware->requiresAuthentication("/api/samplers"));
    EXPECT_TRUE(authMiddleware->requiresAuthentication("/api/schedulers"));
    EXPECT_TRUE(authMiddleware->requiresAuthentication("/api/parameters"));
    EXPECT_TRUE(authMiddleware->requiresAuthentication("/api/generate"));
    EXPECT_TRUE(authMiddleware->requiresAuthentication("/api/queue/status"));
    EXPECT_TRUE(authMiddleware->requiresAuthentication("/api/queue/job/123"));
}

TEST_F(AuthMiddlewareSecurityTest, CustomPublicPathsConfiguration) {
    // Test custom public paths configuration
    AuthConfig config;
    config.authMethod = AuthMethod::JWT;
    config.customPublicPaths = "/api/health,/api/status,/api/models,/api/samplers";

    auto authMiddleware = std::make_unique<AuthMiddleware>(config, userManager);
    ASSERT_TRUE(authMiddleware->initialize());

    // Configured public paths should be accessible
    EXPECT_FALSE(authMiddleware->requiresAuthentication("/api/health"));
    EXPECT_FALSE(authMiddleware->requiresAuthentication("/api/status"));
    EXPECT_FALSE(authMiddleware->requiresAuthentication("/api/models"));
    EXPECT_FALSE(authMiddleware->requiresAuthentication("/api/samplers"));

    // Non-configured paths should require authentication
    EXPECT_TRUE(authMiddleware->requiresAuthentication("/api/models/types"));
    EXPECT_TRUE(authMiddleware->requiresAuthentication("/api/schedulers"));
    EXPECT_TRUE(authMiddleware->requiresAuthentication("/api/generate"));
}

TEST_F(AuthMiddlewareSecurityTest, CustomPublicPathsWithSpaces) {
    // Test custom public paths with spaces
    AuthConfig config;
    config.authMethod = AuthMethod::JWT;
    config.customPublicPaths = " /api/health , /api/status , /api/models ";

    auto authMiddleware = std::make_unique<AuthMiddleware>(config, userManager);
    ASSERT_TRUE(authMiddleware->initialize());

    // Spaces should be trimmed properly
    EXPECT_FALSE(authMiddleware->requiresAuthentication("/api/health"));
    EXPECT_FALSE(authMiddleware->requiresAuthentication("/api/status"));
    EXPECT_FALSE(authMiddleware->requiresAuthentication("/api/models"));

    // Non-configured paths should require authentication
    EXPECT_TRUE(authMiddleware->requiresAuthentication("/api/samplers"));
}

TEST_F(AuthMiddlewareSecurityTest, CustomPublicPathsAutoPrefix) {
    // Test that configured paths without a leading slash get auto-prefixed
    AuthConfig config;
    config.authMethod = AuthMethod::JWT;
    config.customPublicPaths = "api/health,api/status";

    auto authMiddleware = std::make_unique<AuthMiddleware>(config, userManager);
    ASSERT_TRUE(authMiddleware->initialize());

    // Entries configured without a leading slash must still match the normal,
    // slash-prefixed request paths
    EXPECT_FALSE(authMiddleware->requiresAuthentication("/api/health"));
    EXPECT_FALSE(authMiddleware->requiresAuthentication("/api/status"));
}

TEST_F(AuthMiddlewareSecurityTest, ModelDiscoveryEndpointsProtected) {
    // Test that all model discovery endpoints require authentication when enabled
    AuthConfig config;
    config.authMethod = AuthMethod::JWT;

    auto authMiddleware = std::make_unique<AuthMiddleware>(config, userManager);
    ASSERT_TRUE(authMiddleware->initialize());

    // All model discovery endpoints should require authentication
    EXPECT_TRUE(authMiddleware->requiresAuthentication("/api/models"));
    EXPECT_TRUE(authMiddleware->requiresAuthentication("/api/models/types"));
    EXPECT_TRUE(authMiddleware->requiresAuthentication("/api/models/directories"));
    EXPECT_TRUE(authMiddleware->requiresAuthentication("/api/models/stats"));
    EXPECT_TRUE(authMiddleware->requiresAuthentication("/api/models/123"));
}

TEST_F(AuthMiddlewareSecurityTest, GenerationEndpointsProtected) {
    // Test that all generation endpoints require authentication when enabled
    AuthConfig config;
    config.authMethod = AuthMethod::JWT;

    auto authMiddleware = std::make_unique<AuthMiddleware>(config, userManager);
    ASSERT_TRUE(authMiddleware->initialize());

    // All generation endpoints should require authentication
    EXPECT_TRUE(authMiddleware->requiresAuthentication("/api/generate"));
    EXPECT_TRUE(authMiddleware->requiresAuthentication("/api/generate/text2img"));
    EXPECT_TRUE(authMiddleware->requiresAuthentication("/api/generate/img2img"));
    EXPECT_TRUE(authMiddleware->requiresAuthentication("/api/generate/controlnet"));
    EXPECT_TRUE(authMiddleware->requiresAuthentication("/api/generate/upscale"));
}

TEST_F(AuthMiddlewareSecurityTest, QueueEndpointsProtected) {
    // Test that queue endpoints require authentication when enabled
    AuthConfig config;
    config.authMethod = AuthMethod::JWT;

    auto authMiddleware = std::make_unique<AuthMiddleware>(config, userManager);
    ASSERT_TRUE(authMiddleware->initialize());

    // Queue endpoints should require authentication
    EXPECT_TRUE(authMiddleware->requiresAuthentication("/api/queue/status"));
    EXPECT_TRUE(authMiddleware->requiresAuthentication("/api/queue/job/123"));
    EXPECT_TRUE(authMiddleware->requiresAuthentication("/api/queue/cancel"));
    EXPECT_TRUE(authMiddleware->requiresAuthentication("/api/queue/clear"));
}

TEST_F(AuthMiddlewareSecurityTest, AuthenticationMethodsConsistency) {
    // Test that authentication enforcement is consistent across different auth methods
    const std::vector<AuthMethod> authMethods = {
        AuthMethod::JWT,
        AuthMethod::API_KEY,
        AuthMethod::UNIX,
        AuthMethod::PAM
    };

    for (const auto authMethod : authMethods) {
        // Tag any failure inside this iteration with the auth method being checked.
        SCOPED_TRACE(static_cast<int>(authMethod));

        AuthConfig config;
        config.authMethod = authMethod;

        auto authMiddleware = std::make_unique<AuthMiddleware>(config, userManager);
        ASSERT_TRUE(authMiddleware->initialize());

        // Health and status should always be public
        EXPECT_FALSE(authMiddleware->requiresAuthentication("/api/health"));
        EXPECT_FALSE(authMiddleware->requiresAuthentication("/api/status"));

        // Model discovery should require authentication
        EXPECT_TRUE(authMiddleware->requiresAuthentication("/api/models"));
        EXPECT_TRUE(authMiddleware->requiresAuthentication("/api/samplers"));
    }
}

TEST_F(AuthMiddlewareSecurityTest, OptionalAuthWithGuestAccess) {
    // Test optional authentication with guest access enabled
    AuthConfig config;
    config.authMethod = AuthMethod::OPTIONAL;
    config.enableGuestAccess = true;

    auto authMiddleware = std::make_unique<AuthMiddleware>(config, userManager);
    ASSERT_TRUE(authMiddleware->initialize());

    // With guest access enabled, public paths should be accessible
    EXPECT_FALSE(authMiddleware->requiresAuthentication("/api/health"));
    EXPECT_FALSE(authMiddleware->requiresAuthentication("/api/status"));

    // But protected paths should still require authentication
    EXPECT_TRUE(authMiddleware->requiresAuthentication("/api/models"));
    EXPECT_TRUE(authMiddleware->requiresAuthentication("/api/generate"));
}

TEST_F(AuthMiddlewareSecurityTest, OptionalAuthWithoutGuestAccess) {
    // Test optional authentication without guest access
    AuthConfig config;
    config.authMethod = AuthMethod::OPTIONAL;
    config.enableGuestAccess = false;

    auto authMiddleware = std::make_unique<AuthMiddleware>(config, userManager);
    ASSERT_TRUE(authMiddleware->initialize());

    // Without guest access, only public paths should be accessible
    EXPECT_FALSE(authMiddleware->requiresAuthentication("/api/health"));
    EXPECT_FALSE(authMiddleware->requiresAuthentication("/api/status"));

    // All other paths should require authentication
    EXPECT_TRUE(authMiddleware->requiresAuthentication("/api/models"));
    EXPECT_TRUE(authMiddleware->requiresAuthentication("/api/generate"));
}

/**
 * Integration tests that drive AuthMiddleware through real HTTP requests.
 *
 * SetUp starts an httplib::Server on an OS-assigned ephemeral port (stored in `port`) on a
 * background thread, with:
 *   - GET /api/health: unauthenticated, always returns {"status":"healthy"}.
 *   - GET /api/models: calls authMiddleware->authenticate() and answers with
 *     sendAuthError(...) when the request is not authenticated.
 * TearDown stops the server, joins the thread and removes the scratch directory.
 */
class AuthMiddlewareHttpTest : public ::testing::Test {
protected:
    // Scratch directory for UserManager state.
    static constexpr const char* kTestDataDir = "./test-auth-http";
    // Loopback address used by both the server and the clients.
    static constexpr const char* kHost = "127.0.0.1";

    void SetUp() override {
        // Start from an empty directory even if a previous run died before TearDown ran.
        std::filesystem::remove_all(kTestDataDir);
        userManager = std::make_shared<UserManager>(kTestDataDir, UserManager::AuthMethod::JWT);
        ASSERT_TRUE(userManager->initialize());

        AuthConfig config;
        config.authMethod = AuthMethod::JWT;
        config.jwtSecret = "test-secret";

        authMiddleware = std::make_unique<AuthMiddleware>(config, userManager);
        ASSERT_TRUE(authMiddleware->initialize());

        server = std::make_unique<httplib::Server>();

        // Set up test endpoints
        server->Get("/api/health", [](const httplib::Request&, httplib::Response& res) {
            res.set_content("{\"status\":\"healthy\"}", "application/json");
        });

        server->Get("/api/models", [this](const httplib::Request& req, httplib::Response& res) {
            auto authContext = authMiddleware->authenticate(req, res);
            if (!authContext.authenticated) {
                authMiddleware->sendAuthError(res, "Authentication required", "AUTH_REQUIRED");
                return;
            }
            res.set_content("{\"models\":[]}", "application/json");
        });

        // Bind synchronously so the clients know the exact ephemeral port to dial;
        // only the accept loop runs on the background thread.
        port = server->bind_to_any_port(kHost);
        ASSERT_GT(port, 0) << "failed to bind the test HTTP server";

        serverThread = std::thread([this]() { server->listen_after_bind(); });

        // Bounded wait for the accept loop to start. NOTE: in cpp-httplib, stop() only shuts
        // the listening socket down once is_running() is true, so TearDown must not race
        // ahead of the server thread.
        const auto deadline = std::chrono::steady_clock::now() + std::chrono::seconds(5);
        while (!server->is_running() && std::chrono::steady_clock::now() < deadline) {
            std::this_thread::sleep_for(std::chrono::milliseconds(10));
        }
        ASSERT_TRUE(server->is_running()) << "test HTTP server did not start in time";
    }

    void TearDown() override {
        if (server) {
            server->stop();
        }
        if (serverThread.joinable()) {
            serverThread.join();
        }
        std::filesystem::remove_all(kTestDataDir);
    }

    std::shared_ptr<UserManager> userManager;
    std::unique_ptr<AuthMiddleware> authMiddleware;
    std::unique_ptr<httplib::Server> server;
    std::thread serverThread;
    int port = -1;  // Ephemeral port returned by bind_to_any_port(); -1 until bound.
};

TEST_F(AuthMiddlewareHttpTest, PublicEndpointAccessible) {
    // Test that public endpoints are accessible without authentication
    httplib::Client client(kHost, port);

    auto res = client.Get("/api/health");
    // Guard the dereferences below: a failed request yields an empty result.
    ASSERT_TRUE(res) << "request to /api/health failed";
    EXPECT_EQ(res->status, 200);
    EXPECT_NE(res->body.find("healthy"), std::string::npos);
}

TEST_F(AuthMiddlewareHttpTest, ProtectedEndpointRequiresAuth) {
    // Test that protected endpoints return 401 without authentication
    httplib::Client client(kHost, port);

    auto res = client.Get("/api/models");
    // Guard the dereferences below: a failed request yields an empty result.
    ASSERT_TRUE(res) << "request to /api/models failed";
    EXPECT_EQ(res->status, 401);
    EXPECT_NE(res->body.find("Authentication required"), std::string::npos);
}

int main(int argc, char** argv) {
    ::testing::InitGoogleTest(&argc, argv);
    return RUN_ALL_TESTS();
}