|
@@ -3,6 +3,8 @@
|
|
|
#include "model_manager.h"
|
|
#include "model_manager.h"
|
|
|
#include "generation_queue.h"
|
|
#include "generation_queue.h"
|
|
|
#include "utils.h"
|
|
#include "utils.h"
|
|
|
|
|
+#include "auth_middleware.h"
|
|
|
|
|
+#include "user_manager.h"
|
|
|
#include <httplib.h>
|
|
#include <httplib.h>
|
|
|
#include <nlohmann/json.hpp>
|
|
#include <nlohmann/json.hpp>
|
|
|
#include <iostream>
|
|
#include <iostream>
|
|
@@ -32,6 +34,8 @@ Server::Server(ModelManager* modelManager, GenerationQueue* generationQueue, con
|
|
|
, m_port(8080)
|
|
, m_port(8080)
|
|
|
, m_outputDir(outputDir)
|
|
, m_outputDir(outputDir)
|
|
|
, m_uiDir(uiDir)
|
|
, m_uiDir(uiDir)
|
|
|
|
|
+ , m_userManager(nullptr)
|
|
|
|
|
+ , m_authMiddleware(nullptr)
|
|
|
{
|
|
{
|
|
|
m_httpServer = std::make_unique<httplib::Server>();
|
|
m_httpServer = std::make_unique<httplib::Server>();
|
|
|
}
|
|
}
|
|
@@ -147,150 +151,172 @@ void Server::waitForStop() {
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
void Server::registerEndpoints() {
|
|
void Server::registerEndpoints() {
|
|
|
- // Health check endpoint
|
|
|
|
|
|
|
+ // Register authentication endpoints first (before applying middleware)
|
|
|
|
|
+ registerAuthEndpoints();
|
|
|
|
|
+
|
|
|
|
|
+ // Health check endpoint (public)
|
|
|
m_httpServer->Get("/api/health", [this](const httplib::Request& req, httplib::Response& res) {
|
|
m_httpServer->Get("/api/health", [this](const httplib::Request& req, httplib::Response& res) {
|
|
|
handleHealthCheck(req, res);
|
|
handleHealthCheck(req, res);
|
|
|
});
|
|
});
|
|
|
|
|
|
|
|
- // API status endpoint
|
|
|
|
|
|
|
+ // API status endpoint (public)
|
|
|
m_httpServer->Get("/api/status", [this](const httplib::Request& req, httplib::Response& res) {
|
|
m_httpServer->Get("/api/status", [this](const httplib::Request& req, httplib::Response& res) {
|
|
|
handleApiStatus(req, res);
|
|
handleApiStatus(req, res);
|
|
|
});
|
|
});
|
|
|
|
|
|
|
|
- // Specialized generation endpoints
|
|
|
|
|
- m_httpServer->Post("/api/generate/text2img", [this](const httplib::Request& req, httplib::Response& res) {
|
|
|
|
|
|
|
+ // Apply authentication middleware to protected endpoints
|
|
|
|
|
+ auto withAuth = [this](std::function<void(const httplib::Request&, httplib::Response&)> handler) {
|
|
|
|
|
+ return [this, handler](const httplib::Request& req, httplib::Response& res) {
|
|
|
|
|
+ if (m_authMiddleware) {
|
|
|
|
|
+ AuthContext authContext = m_authMiddleware->authenticate(req, res);
|
|
|
|
|
+ if (!authContext.authenticated) {
|
|
|
|
|
+ m_authMiddleware->sendAuthError(res, authContext.errorMessage, authContext.errorCode);
|
|
|
|
|
+ return;
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+ handler(req, res);
|
|
|
|
|
+ };
|
|
|
|
|
+ };
|
|
|
|
|
+
|
|
|
|
|
+ // Specialized generation endpoints (protected)
|
|
|
|
|
+ m_httpServer->Post("/api/generate/text2img", withAuth([this](const httplib::Request& req, httplib::Response& res) {
|
|
|
handleText2Img(req, res);
|
|
handleText2Img(req, res);
|
|
|
- });
|
|
|
|
|
|
|
+ }));
|
|
|
|
|
|
|
|
- m_httpServer->Post("/api/generate/img2img", [this](const httplib::Request& req, httplib::Response& res) {
|
|
|
|
|
|
|
+ m_httpServer->Post("/api/generate/img2img", withAuth([this](const httplib::Request& req, httplib::Response& res) {
|
|
|
handleImg2Img(req, res);
|
|
handleImg2Img(req, res);
|
|
|
- });
|
|
|
|
|
|
|
+ }));
|
|
|
|
|
|
|
|
- m_httpServer->Post("/api/generate/controlnet", [this](const httplib::Request& req, httplib::Response& res) {
|
|
|
|
|
|
|
+ m_httpServer->Post("/api/generate/controlnet", withAuth([this](const httplib::Request& req, httplib::Response& res) {
|
|
|
handleControlNet(req, res);
|
|
handleControlNet(req, res);
|
|
|
- });
|
|
|
|
|
|
|
+ }));
|
|
|
|
|
|
|
|
- m_httpServer->Post("/api/generate/upscale", [this](const httplib::Request& req, httplib::Response& res) {
|
|
|
|
|
|
|
+ m_httpServer->Post("/api/generate/upscale", withAuth([this](const httplib::Request& req, httplib::Response& res) {
|
|
|
handleUpscale(req, res);
|
|
handleUpscale(req, res);
|
|
|
- });
|
|
|
|
|
|
|
+ }));
|
|
|
|
|
+
|
|
|
|
|
+ m_httpServer->Post("/api/generate/inpainting", withAuth([this](const httplib::Request& req, httplib::Response& res) {
|
|
|
|
|
+ handleInpainting(req, res);
|
|
|
|
|
+ }));
|
|
|
|
|
|
|
|
- // Utility endpoints
|
|
|
|
|
- m_httpServer->Get("/api/samplers", [this](const httplib::Request& req, httplib::Response& res) {
|
|
|
|
|
|
|
+ // Utility endpoints (now protected - require authentication)
|
|
|
|
|
+ m_httpServer->Get("/api/samplers", withAuth([this](const httplib::Request& req, httplib::Response& res) {
|
|
|
handleSamplers(req, res);
|
|
handleSamplers(req, res);
|
|
|
- });
|
|
|
|
|
|
|
+ }));
|
|
|
|
|
|
|
|
- m_httpServer->Get("/api/schedulers", [this](const httplib::Request& req, httplib::Response& res) {
|
|
|
|
|
|
|
+ m_httpServer->Get("/api/schedulers", withAuth([this](const httplib::Request& req, httplib::Response& res) {
|
|
|
handleSchedulers(req, res);
|
|
handleSchedulers(req, res);
|
|
|
- });
|
|
|
|
|
|
|
+ }));
|
|
|
|
|
|
|
|
- m_httpServer->Get("/api/parameters", [this](const httplib::Request& req, httplib::Response& res) {
|
|
|
|
|
|
|
+ m_httpServer->Get("/api/parameters", withAuth([this](const httplib::Request& req, httplib::Response& res) {
|
|
|
handleParameters(req, res);
|
|
handleParameters(req, res);
|
|
|
- });
|
|
|
|
|
|
|
+ }));
|
|
|
|
|
|
|
|
m_httpServer->Post("/api/validate", [this](const httplib::Request& req, httplib::Response& res) {
|
|
m_httpServer->Post("/api/validate", [this](const httplib::Request& req, httplib::Response& res) {
|
|
|
handleValidate(req, res);
|
|
handleValidate(req, res);
|
|
|
});
|
|
});
|
|
|
|
|
|
|
|
- m_httpServer->Post("/api/estimate", [this](const httplib::Request& req, httplib::Response& res) {
|
|
|
|
|
|
|
+ m_httpServer->Post("/api/estimate", withAuth([this](const httplib::Request& req, httplib::Response& res) {
|
|
|
handleEstimate(req, res);
|
|
handleEstimate(req, res);
|
|
|
- });
|
|
|
|
|
|
|
+ }));
|
|
|
|
|
|
|
|
- m_httpServer->Get("/api/config", [this](const httplib::Request& req, httplib::Response& res) {
|
|
|
|
|
|
|
+ m_httpServer->Get("/api/config", withAuth([this](const httplib::Request& req, httplib::Response& res) {
|
|
|
handleConfig(req, res);
|
|
handleConfig(req, res);
|
|
|
- });
|
|
|
|
|
|
|
+ }));
|
|
|
|
|
|
|
|
- m_httpServer->Get("/api/system", [this](const httplib::Request& req, httplib::Response& res) {
|
|
|
|
|
|
|
+ m_httpServer->Get("/api/system", withAuth([this](const httplib::Request& req, httplib::Response& res) {
|
|
|
handleSystem(req, res);
|
|
handleSystem(req, res);
|
|
|
- });
|
|
|
|
|
|
|
+ }));
|
|
|
|
|
|
|
|
- m_httpServer->Post("/api/system/restart", [this](const httplib::Request& req, httplib::Response& res) {
|
|
|
|
|
|
|
+ m_httpServer->Post("/api/system/restart", withAuth([this](const httplib::Request& req, httplib::Response& res) {
|
|
|
handleSystemRestart(req, res);
|
|
handleSystemRestart(req, res);
|
|
|
- });
|
|
|
|
|
|
|
+ }));
|
|
|
|
|
|
|
|
- // Models list endpoint
|
|
|
|
|
- m_httpServer->Get("/api/models", [this](const httplib::Request& req, httplib::Response& res) {
|
|
|
|
|
|
|
+ // Models list endpoint (now protected - require authentication)
|
|
|
|
|
+ m_httpServer->Get("/api/models", withAuth([this](const httplib::Request& req, httplib::Response& res) {
|
|
|
handleModelsList(req, res);
|
|
handleModelsList(req, res);
|
|
|
- });
|
|
|
|
|
|
|
+ }));
|
|
|
|
|
|
|
|
// Model-specific endpoints
|
|
// Model-specific endpoints
|
|
|
m_httpServer->Get("/api/models/(.*)", [this](const httplib::Request& req, httplib::Response& res) {
|
|
m_httpServer->Get("/api/models/(.*)", [this](const httplib::Request& req, httplib::Response& res) {
|
|
|
handleModelInfo(req, res);
|
|
handleModelInfo(req, res);
|
|
|
});
|
|
});
|
|
|
|
|
|
|
|
- m_httpServer->Post("/api/models/(.*)/load", [this](const httplib::Request& req, httplib::Response& res) {
|
|
|
|
|
|
|
+ m_httpServer->Post("/api/models/(.*)/load", withAuth([this](const httplib::Request& req, httplib::Response& res) {
|
|
|
handleLoadModelById(req, res);
|
|
handleLoadModelById(req, res);
|
|
|
- });
|
|
|
|
|
|
|
+ }));
|
|
|
|
|
|
|
|
- m_httpServer->Post("/api/models/(.*)/unload", [this](const httplib::Request& req, httplib::Response& res) {
|
|
|
|
|
|
|
+ m_httpServer->Post("/api/models/(.*)/unload", withAuth([this](const httplib::Request& req, httplib::Response& res) {
|
|
|
handleUnloadModelById(req, res);
|
|
handleUnloadModelById(req, res);
|
|
|
- });
|
|
|
|
|
|
|
+ }));
|
|
|
|
|
|
|
|
- // Model management endpoints
|
|
|
|
|
- m_httpServer->Get("/api/models/types", [this](const httplib::Request& req, httplib::Response& res) {
|
|
|
|
|
|
|
+ // Model management endpoints (now protected - require authentication)
|
|
|
|
|
+ m_httpServer->Get("/api/models/types", withAuth([this](const httplib::Request& req, httplib::Response& res) {
|
|
|
handleModelTypes(req, res);
|
|
handleModelTypes(req, res);
|
|
|
- });
|
|
|
|
|
|
|
+ }));
|
|
|
|
|
|
|
|
- m_httpServer->Get("/api/models/directories", [this](const httplib::Request& req, httplib::Response& res) {
|
|
|
|
|
|
|
+ m_httpServer->Get("/api/models/directories", withAuth([this](const httplib::Request& req, httplib::Response& res) {
|
|
|
handleModelDirectories(req, res);
|
|
handleModelDirectories(req, res);
|
|
|
- });
|
|
|
|
|
|
|
+ }));
|
|
|
|
|
|
|
|
- m_httpServer->Post("/api/models/refresh", [this](const httplib::Request& req, httplib::Response& res) {
|
|
|
|
|
|
|
+ m_httpServer->Post("/api/models/refresh", withAuth([this](const httplib::Request& req, httplib::Response& res) {
|
|
|
handleRefreshModels(req, res);
|
|
handleRefreshModels(req, res);
|
|
|
- });
|
|
|
|
|
|
|
+ }));
|
|
|
|
|
|
|
|
- m_httpServer->Post("/api/models/hash", [this](const httplib::Request& req, httplib::Response& res) {
|
|
|
|
|
|
|
+ m_httpServer->Post("/api/models/hash", withAuth([this](const httplib::Request& req, httplib::Response& res) {
|
|
|
handleHashModels(req, res);
|
|
handleHashModels(req, res);
|
|
|
- });
|
|
|
|
|
|
|
+ }));
|
|
|
|
|
|
|
|
- m_httpServer->Post("/api/models/convert", [this](const httplib::Request& req, httplib::Response& res) {
|
|
|
|
|
|
|
+ m_httpServer->Post("/api/models/convert", withAuth([this](const httplib::Request& req, httplib::Response& res) {
|
|
|
handleConvertModel(req, res);
|
|
handleConvertModel(req, res);
|
|
|
- });
|
|
|
|
|
|
|
+ }));
|
|
|
|
|
|
|
|
- m_httpServer->Get("/api/models/stats", [this](const httplib::Request& req, httplib::Response& res) {
|
|
|
|
|
|
|
+ m_httpServer->Get("/api/models/stats", withAuth([this](const httplib::Request& req, httplib::Response& res) {
|
|
|
handleModelStats(req, res);
|
|
handleModelStats(req, res);
|
|
|
- });
|
|
|
|
|
|
|
+ }));
|
|
|
|
|
|
|
|
- m_httpServer->Post("/api/models/batch", [this](const httplib::Request& req, httplib::Response& res) {
|
|
|
|
|
|
|
+ m_httpServer->Post("/api/models/batch", withAuth([this](const httplib::Request& req, httplib::Response& res) {
|
|
|
handleBatchModels(req, res);
|
|
handleBatchModels(req, res);
|
|
|
- });
|
|
|
|
|
|
|
+ }));
|
|
|
|
|
|
|
|
- // Model validation endpoints
|
|
|
|
|
- m_httpServer->Post("/api/models/validate", [this](const httplib::Request& req, httplib::Response& res) {
|
|
|
|
|
|
|
+ // Model validation endpoints (already protected with withAuth)
|
|
|
|
|
+ m_httpServer->Post("/api/models/validate", withAuth([this](const httplib::Request& req, httplib::Response& res) {
|
|
|
handleValidateModel(req, res);
|
|
handleValidateModel(req, res);
|
|
|
- });
|
|
|
|
|
|
|
+ }));
|
|
|
|
|
|
|
|
- m_httpServer->Post("/api/models/compatible", [this](const httplib::Request& req, httplib::Response& res) {
|
|
|
|
|
|
|
+ m_httpServer->Post("/api/models/compatible", withAuth([this](const httplib::Request& req, httplib::Response& res) {
|
|
|
handleCheckCompatibility(req, res);
|
|
handleCheckCompatibility(req, res);
|
|
|
- });
|
|
|
|
|
|
|
+ }));
|
|
|
|
|
|
|
|
- m_httpServer->Post("/api/models/requirements", [this](const httplib::Request& req, httplib::Response& res) {
|
|
|
|
|
|
|
+ m_httpServer->Post("/api/models/requirements", withAuth([this](const httplib::Request& req, httplib::Response& res) {
|
|
|
handleModelRequirements(req, res);
|
|
handleModelRequirements(req, res);
|
|
|
- });
|
|
|
|
|
|
|
+ }));
|
|
|
|
|
|
|
|
- // Queue status endpoint
|
|
|
|
|
- m_httpServer->Get("/api/queue/status", [this](const httplib::Request& req, httplib::Response& res) {
|
|
|
|
|
|
|
+ // Queue status endpoint (now protected - require authentication)
|
|
|
|
|
+ m_httpServer->Get("/api/queue/status", withAuth([this](const httplib::Request& req, httplib::Response& res) {
|
|
|
handleQueueStatus(req, res);
|
|
handleQueueStatus(req, res);
|
|
|
- });
|
|
|
|
|
|
|
+ }));
|
|
|
|
|
|
|
|
// Download job output file endpoint (must be before job status endpoint to match more specific pattern first)
|
|
// Download job output file endpoint (must be before job status endpoint to match more specific pattern first)
|
|
|
|
|
+ // Note: This endpoint is public to allow frontend to display generated images without authentication
|
|
|
m_httpServer->Get("/api/queue/job/(.*)/output/(.*)", [this](const httplib::Request& req, httplib::Response& res) {
|
|
m_httpServer->Get("/api/queue/job/(.*)/output/(.*)", [this](const httplib::Request& req, httplib::Response& res) {
|
|
|
handleDownloadOutput(req, res);
|
|
handleDownloadOutput(req, res);
|
|
|
});
|
|
});
|
|
|
|
|
|
|
|
- // Job status endpoint
|
|
|
|
|
- m_httpServer->Get("/api/queue/job/(.*)", [this](const httplib::Request& req, httplib::Response& res) {
|
|
|
|
|
|
|
+ // Job status endpoint (now protected - require authentication)
|
|
|
|
|
+ m_httpServer->Get("/api/queue/job/(.*)", withAuth([this](const httplib::Request& req, httplib::Response& res) {
|
|
|
handleJobStatus(req, res);
|
|
handleJobStatus(req, res);
|
|
|
- });
|
|
|
|
|
|
|
+ }));
|
|
|
|
|
|
|
|
- // Cancel job endpoint
|
|
|
|
|
- m_httpServer->Post("/api/queue/cancel", [this](const httplib::Request& req, httplib::Response& res) {
|
|
|
|
|
|
|
+ // Cancel job endpoint (protected)
|
|
|
|
|
+ m_httpServer->Post("/api/queue/cancel", withAuth([this](const httplib::Request& req, httplib::Response& res) {
|
|
|
handleCancelJob(req, res);
|
|
handleCancelJob(req, res);
|
|
|
- });
|
|
|
|
|
|
|
+ }));
|
|
|
|
|
|
|
|
- // Clear queue endpoint
|
|
|
|
|
- m_httpServer->Post("/api/queue/clear", [this](const httplib::Request& req, httplib::Response& res) {
|
|
|
|
|
|
|
+ // Clear queue endpoint (protected)
|
|
|
|
|
+ m_httpServer->Post("/api/queue/clear", withAuth([this](const httplib::Request& req, httplib::Response& res) {
|
|
|
handleClearQueue(req, res);
|
|
handleClearQueue(req, res);
|
|
|
- });
|
|
|
|
|
|
|
+ }));
|
|
|
|
|
|
|
|
// Serve static web UI files if uiDir is configured
|
|
// Serve static web UI files if uiDir is configured
|
|
|
if (!m_uiDir.empty() && std::filesystem::exists(m_uiDir)) {
|
|
if (!m_uiDir.empty() && std::filesystem::exists(m_uiDir)) {
|
|
@@ -325,8 +351,31 @@ void Server::registerEndpoints() {
|
|
|
<< " apiBasePath: '/api',\n"
|
|
<< " apiBasePath: '/api',\n"
|
|
|
<< " host: '" << m_host << "',\n"
|
|
<< " host: '" << m_host << "',\n"
|
|
|
<< " port: " << m_port << ",\n"
|
|
<< " port: " << m_port << ",\n"
|
|
|
- << " uiVersion: '" << uiVersion << "'\n"
|
|
|
|
|
- << "};\n";
|
|
|
|
|
|
|
+ << " uiVersion: '" << uiVersion << "',\n";
|
|
|
|
|
+
|
|
|
|
|
+ // Add authentication method information
|
|
|
|
|
+ if (m_authMiddleware) {
|
|
|
|
|
+ auto authConfig = m_authMiddleware->getConfig();
|
|
|
|
|
+ std::string authMethod = "none";
|
|
|
|
|
+ switch (authConfig.authMethod) {
|
|
|
|
|
+ case AuthMethod::UNIX:
|
|
|
|
|
+ authMethod = "unix";
|
|
|
|
|
+ break;
|
|
|
|
|
+ case AuthMethod::JWT:
|
|
|
|
|
+ authMethod = "jwt";
|
|
|
|
|
+ break;
|
|
|
|
|
+ default:
|
|
|
|
|
+ authMethod = "none";
|
|
|
|
|
+ break;
|
|
|
|
|
+ }
|
|
|
|
|
+ configJs << " authMethod: '" << authMethod << "',\n"
|
|
|
|
|
+ << " authEnabled: " << (authConfig.authMethod != AuthMethod::NONE ? "true" : "false") << "\n";
|
|
|
|
|
+ } else {
|
|
|
|
|
+ configJs << " authMethod: 'none',\n"
|
|
|
|
|
+ << " authEnabled: false\n";
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ configJs << "};\n";
|
|
|
|
|
|
|
|
// No cache for config.js - always fetch fresh
|
|
// No cache for config.js - always fetch fresh
|
|
|
res.set_header("Cache-Control", "no-cache, no-store, must-revalidate");
|
|
res.set_header("Cache-Control", "no-cache, no-store, must-revalidate");
|
|
@@ -373,10 +422,163 @@ void Server::registerEndpoints() {
|
|
|
}
|
|
}
|
|
|
});
|
|
});
|
|
|
|
|
|
|
|
- // Mount the static file directory at /ui
|
|
|
|
|
- if (!m_httpServer->set_mount_point("/ui", m_uiDir)) {
|
|
|
|
|
- std::cerr << "Failed to mount UI directory: " << m_uiDir << std::endl;
|
|
|
|
|
- }
|
|
|
|
|
|
|
+ // Create a handler for UI routes with authentication check
|
|
|
|
|
+ auto uiHandler = [this](const httplib::Request& req, httplib::Response& res) {
|
|
|
|
|
+ // Check if authentication is enabled
|
|
|
|
|
+ if (m_authMiddleware) {
|
|
|
|
|
+ auto authConfig = m_authMiddleware->getConfig();
|
|
|
|
|
+ if (authConfig.authMethod != AuthMethod::NONE) {
|
|
|
|
|
+ // Authentication is enabled, check if user is authenticated
|
|
|
|
|
+ AuthContext authContext = m_authMiddleware->authenticate(req, res);
|
|
|
|
|
+
|
|
|
|
|
+ // For Unix auth, we need to check if the user is authenticated
|
|
|
|
|
+ // The authenticateUnix function will return a guest context for UI requests
|
|
|
|
|
+ // when no Authorization header is present, but we still need to show the login page
|
|
|
|
|
+ if (!authContext.authenticated) {
|
|
|
|
|
+ // Check if this is a request for a static asset (JS, CSS, images)
|
|
|
|
|
+ // These should be served even without authentication to allow the login page to work
|
|
|
|
|
+ bool isStaticAsset = false;
|
|
|
|
|
+ std::string path = req.path;
|
|
|
|
|
+ if (path.find(".js") != std::string::npos ||
|
|
|
|
|
+ path.find(".css") != std::string::npos ||
|
|
|
|
|
+ path.find(".png") != std::string::npos ||
|
|
|
|
|
+ path.find(".jpg") != std::string::npos ||
|
|
|
|
|
+ path.find(".jpeg") != std::string::npos ||
|
|
|
|
|
+ path.find(".svg") != std::string::npos ||
|
|
|
|
|
+ path.find(".ico") != std::string::npos ||
|
|
|
|
|
+ path.find("/_next/") != std::string::npos) {
|
|
|
|
|
+ isStaticAsset = true;
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ // For static assets, allow them to be served without authentication
|
|
|
|
|
+ if (isStaticAsset) {
|
|
|
|
|
+ // Continue to serve the file
|
|
|
|
|
+ } else {
|
|
|
|
|
+ // For HTML requests, redirect to login page
|
|
|
|
|
+ if (req.path.find(".html") != std::string::npos ||
|
|
|
|
|
+ req.path == "/ui/" || req.path == "/ui") {
|
|
|
|
|
+ // Serve the login page instead of the requested page
|
|
|
|
|
+ std::string loginPagePath = m_uiDir + "/login.html";
|
|
|
|
|
+ if (std::filesystem::exists(loginPagePath)) {
|
|
|
|
|
+ std::ifstream loginFile(loginPagePath);
|
|
|
|
|
+ if (loginFile.is_open()) {
|
|
|
|
|
+ std::string content((std::istreambuf_iterator<char>(loginFile)),
|
|
|
|
|
+ std::istreambuf_iterator<char>());
|
|
|
|
|
+ res.set_content(content, "text/html");
|
|
|
|
|
+ return;
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+ // If login.html doesn't exist, serve a simple login page
|
|
|
|
|
+ std::string simpleLoginPage = R"(
|
|
|
|
|
+<!DOCTYPE html>
|
|
|
|
|
+<html>
|
|
|
|
|
+<head>
|
|
|
|
|
+ <title>Login Required</title>
|
|
|
|
|
+ <style>
|
|
|
|
|
+ body { font-family: Arial, sans-serif; max-width: 500px; margin: 100px auto; padding: 20px; }
|
|
|
|
|
+ .form-group { margin-bottom: 15px; }
|
|
|
|
|
+ label { display: block; margin-bottom: 5px; }
|
|
|
|
|
+ input { width: 100%; padding: 8px; box-sizing: border-box; }
|
|
|
|
|
+ button { background-color: #007bff; color: white; padding: 10px 15px; border: none; cursor: pointer; }
|
|
|
|
|
+ .error { color: red; margin-top: 10px; }
|
|
|
|
|
+ </style>
|
|
|
|
|
+</head>
|
|
|
|
|
+<body>
|
|
|
|
|
+ <h1>Login Required</h1>
|
|
|
|
|
+ <p>Please enter your username to continue.</p>
|
|
|
|
|
+ <form id="loginForm">
|
|
|
|
|
+ <div class="form-group">
|
|
|
|
|
+ <label for="username">Username:</label>
|
|
|
|
|
+ <input type="text" id="username" name="username" required>
|
|
|
|
|
+ </div>
|
|
|
|
|
+ <button type="submit">Login</button>
|
|
|
|
|
+ </form>
|
|
|
|
|
+ <div id="error" class="error"></div>
|
|
|
|
|
+ <script>
|
|
|
|
|
+ document.getElementById('loginForm').addEventListener('submit', async (e) => {
|
|
|
|
|
+ e.preventDefault();
|
|
|
|
|
+ const username = document.getElementById('username').value;
|
|
|
|
|
+ const errorDiv = document.getElementById('error');
|
|
|
|
|
+
|
|
|
|
|
+ try {
|
|
|
|
|
+ const response = await fetch('/api/auth/login', {
|
|
|
|
|
+ method: 'POST',
|
|
|
|
|
+ headers: { 'Content-Type': 'application/json' },
|
|
|
|
|
+ body: JSON.stringify({ username })
|
|
|
|
|
+ });
|
|
|
|
|
+
|
|
|
|
|
+ if (response.ok) {
|
|
|
|
|
+ const data = await response.json();
|
|
|
|
|
+ localStorage.setItem('auth_token', data.token);
|
|
|
|
|
+ localStorage.setItem('unix_user', username);
|
|
|
|
|
+ window.location.reload();
|
|
|
|
|
+ } else {
|
|
|
|
|
+ const error = await response.json();
|
|
|
|
|
+ errorDiv.textContent = error.message || 'Login failed';
|
|
|
|
|
+ }
|
|
|
|
|
+ } catch (err) {
|
|
|
|
|
+ errorDiv.textContent = 'Login failed: ' + err.message;
|
|
|
|
|
+ }
|
|
|
|
|
+ });
|
|
|
|
|
+ </script>
|
|
|
|
|
+</body>
|
|
|
|
|
+</html>
|
|
|
|
|
+)";
|
|
|
|
|
+ res.set_content(simpleLoginPage, "text/html");
|
|
|
|
|
+ return;
|
|
|
|
|
+ } else {
|
|
|
|
|
+ // For non-HTML files, return unauthorized
|
|
|
|
|
+ m_authMiddleware->sendAuthError(res, "Authentication required", "AUTH_REQUIRED");
|
|
|
|
|
+ return;
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ // If we get here, either auth is disabled or user is authenticated
|
|
|
|
|
+ // Serve the requested file
|
|
|
|
|
+ std::string filePath = req.path.substr(3); // Remove "/ui" prefix
|
|
|
|
|
+ if (filePath.empty() || filePath == "/") {
|
|
|
|
|
+ filePath = "/index.html";
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ std::string fullPath = m_uiDir + filePath;
|
|
|
|
|
+ if (std::filesystem::exists(fullPath) && std::filesystem::is_regular_file(fullPath)) {
|
|
|
|
|
+ std::ifstream file(fullPath, std::ios::binary);
|
|
|
|
|
+ if (file.is_open()) {
|
|
|
|
|
+ std::string content((std::istreambuf_iterator<char>(file)),
|
|
|
|
|
+ std::istreambuf_iterator<char>());
|
|
|
|
|
+
|
|
|
|
|
+ // Determine content type based on file extension
|
|
|
|
|
+ std::string contentType = "text/plain";
|
|
|
|
|
+ if (filePath.find(".html") != std::string::npos) {
|
|
|
|
|
+ contentType = "text/html";
|
|
|
|
|
+ } else if (filePath.find(".js") != std::string::npos) {
|
|
|
|
|
+ contentType = "application/javascript";
|
|
|
|
|
+ } else if (filePath.find(".css") != std::string::npos) {
|
|
|
|
|
+ contentType = "text/css";
|
|
|
|
|
+ } else if (filePath.find(".png") != std::string::npos) {
|
|
|
|
|
+ contentType = "image/png";
|
|
|
|
|
+ } else if (filePath.find(".jpg") != std::string::npos || filePath.find(".jpeg") != std::string::npos) {
|
|
|
|
|
+ contentType = "image/jpeg";
|
|
|
|
|
+ } else if (filePath.find(".svg") != std::string::npos) {
|
|
|
|
|
+ contentType = "image/svg+xml";
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ res.set_content(content, contentType);
|
|
|
|
|
+ } else {
|
|
|
|
|
+ res.status = 404;
|
|
|
|
|
+ res.set_content("File not found", "text/plain");
|
|
|
|
|
+ }
|
|
|
|
|
+ } else {
|
|
|
|
|
+ res.status = 404;
|
|
|
|
|
+ res.set_content("File not found", "text/plain");
|
|
|
|
|
+ }
|
|
|
|
|
+ };
|
|
|
|
|
+
|
|
|
|
|
+ // Set up UI routes with authentication
|
|
|
|
|
+ m_httpServer->Get("/ui/.*", uiHandler);
|
|
|
|
|
|
|
|
// Redirect /ui to /ui/ to ensure proper routing
|
|
// Redirect /ui to /ui/ to ensure proper routing
|
|
|
m_httpServer->Get("/ui", [](const httplib::Request& req, httplib::Response& res) {
|
|
m_httpServer->Get("/ui", [](const httplib::Request& req, httplib::Response& res) {
|
|
@@ -385,6 +587,273 @@ void Server::registerEndpoints() {
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
|
|
+void Server::setAuthComponents(std::shared_ptr<UserManager> userManager, std::shared_ptr<AuthMiddleware> authMiddleware) {
|
|
|
|
|
+ m_userManager = userManager;
|
|
|
|
|
+ m_authMiddleware = authMiddleware;
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+void Server::registerAuthEndpoints() {
|
|
|
|
|
+ // Login endpoint
|
|
|
|
|
+ m_httpServer->Post("/api/auth/login", [this](const httplib::Request& req, httplib::Response& res) {
|
|
|
|
|
+ handleLogin(req, res);
|
|
|
|
|
+ });
|
|
|
|
|
+
|
|
|
|
|
+ // Logout endpoint
|
|
|
|
|
+ m_httpServer->Post("/api/auth/logout", [this](const httplib::Request& req, httplib::Response& res) {
|
|
|
|
|
+ handleLogout(req, res);
|
|
|
|
|
+ });
|
|
|
|
|
+
|
|
|
|
|
+ // Token validation endpoint
|
|
|
|
|
+ m_httpServer->Get("/api/auth/validate", [this](const httplib::Request& req, httplib::Response& res) {
|
|
|
|
|
+ handleValidateToken(req, res);
|
|
|
|
|
+ });
|
|
|
|
|
+
|
|
|
|
|
+ // Refresh token endpoint
|
|
|
|
|
+ m_httpServer->Post("/api/auth/refresh", [this](const httplib::Request& req, httplib::Response& res) {
|
|
|
|
|
+ handleRefreshToken(req, res);
|
|
|
|
|
+ });
|
|
|
|
|
+
|
|
|
|
|
+ // Get current user endpoint
|
|
|
|
|
+ m_httpServer->Get("/api/auth/me", [this](const httplib::Request& req, httplib::Response& res) {
|
|
|
|
|
+ handleGetCurrentUser(req, res);
|
|
|
|
|
+ });
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+void Server::handleLogin(const httplib::Request& req, httplib::Response& res) {
|
|
|
|
|
+ std::string requestId = generateRequestId();
|
|
|
|
|
+
|
|
|
|
|
+ try {
|
|
|
|
|
+ if (!m_userManager || !m_authMiddleware) {
|
|
|
|
|
+ sendErrorResponse(res, "Authentication system not available", 500, "AUTH_UNAVAILABLE", requestId);
|
|
|
|
|
+ return;
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ // Parse request body
|
|
|
|
|
+ json requestJson;
|
|
|
|
|
+ try {
|
|
|
|
|
+ requestJson = json::parse(req.body);
|
|
|
|
|
+ } catch (const json::parse_error& e) {
|
|
|
|
|
+ sendErrorResponse(res, std::string("Invalid JSON: ") + e.what(), 400, "JSON_PARSE_ERROR", requestId);
|
|
|
|
|
+ return;
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ // Check if using Unix authentication
|
|
|
|
|
+ if (m_authMiddleware->getConfig().authMethod == AuthMethod::UNIX) {
|
|
|
|
|
+ // For Unix auth, get username and password from request body
|
|
|
|
|
+ std::string username = requestJson.value("username", "");
|
|
|
|
|
+ std::string password = requestJson.value("password", "");
|
|
|
|
|
+
|
|
|
|
|
+ if (username.empty()) {
|
|
|
|
|
+ sendErrorResponse(res, "Missing username", 400, "MISSING_USERNAME", requestId);
|
|
|
|
|
+ return;
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ // Check if PAM is enabled - if so, password is required
|
|
|
|
|
+ if (m_userManager->isPamAuthEnabled() && password.empty()) {
|
|
|
|
|
+ sendErrorResponse(res, "Password is required for Unix authentication", 400, "MISSING_PASSWORD", requestId);
|
|
|
|
|
+ return;
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ // Authenticate Unix user (with or without password depending on PAM)
|
|
|
|
|
+ auto result = m_userManager->authenticateUnix(username, password);
|
|
|
|
|
+
|
|
|
|
|
+ if (!result.success) {
|
|
|
|
|
+ sendErrorResponse(res, result.errorMessage, 401, "UNIX_AUTH_FAILED", requestId);
|
|
|
|
|
+ return;
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ // Generate simple token for Unix auth
|
|
|
|
|
+ std::string token = "unix_token_" + std::to_string(std::chrono::duration_cast<std::chrono::seconds>(
|
|
|
|
|
+ std::chrono::system_clock::now().time_since_epoch()).count()) + "_" + username;
|
|
|
|
|
+
|
|
|
|
|
+ json response = {
|
|
|
|
|
+ {"token", token},
|
|
|
|
|
+ {"user", {
|
|
|
|
|
+ {"id", result.userId},
|
|
|
|
|
+ {"username", result.username},
|
|
|
|
|
+ {"role", result.role},
|
|
|
|
|
+ {"permissions", result.permissions}
|
|
|
|
|
+ }},
|
|
|
|
|
+ {"message", "Unix authentication successful"}
|
|
|
|
|
+ };
|
|
|
|
|
+
|
|
|
|
|
+ sendJsonResponse(res, response);
|
|
|
|
|
+ return;
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ // For non-Unix auth, validate required fields
|
|
|
|
|
+ if (!requestJson.contains("username") || !requestJson.contains("password")) {
|
|
|
|
|
+ sendErrorResponse(res, "Missing username or password", 400, "MISSING_CREDENTIALS", requestId);
|
|
|
|
|
+ return;
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ std::string username = requestJson["username"];
|
|
|
|
|
+ std::string password = requestJson["password"];
|
|
|
|
|
+
|
|
|
|
|
+ // Authenticate user
|
|
|
|
|
+ auto result = m_userManager->authenticateUser(username, password);
|
|
|
|
|
+
|
|
|
|
|
+ if (!result.success) {
|
|
|
|
|
+ sendErrorResponse(res, result.errorMessage, 401, "INVALID_CREDENTIALS", requestId);
|
|
|
|
|
+ return;
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ // Generate JWT token if using JWT auth
|
|
|
|
|
+ std::string token;
|
|
|
|
|
+ if (m_authMiddleware->getConfig().authMethod == AuthMethod::JWT) {
|
|
|
|
|
+ // For now, create a simple token (in a real implementation, use JWT)
|
|
|
|
|
+ token = "token_" + std::to_string(std::chrono::duration_cast<std::chrono::seconds>(
|
|
|
|
|
+ std::chrono::system_clock::now().time_since_epoch()).count()) + "_" + username;
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ json response = {
|
|
|
|
|
+ {"token", token},
|
|
|
|
|
+ {"user", {
|
|
|
|
|
+ {"id", result.userId},
|
|
|
|
|
+ {"username", result.username},
|
|
|
|
|
+ {"role", result.role},
|
|
|
|
|
+ {"permissions", result.permissions}
|
|
|
|
|
+ }},
|
|
|
|
|
+ {"message", "Login successful"}
|
|
|
|
|
+ };
|
|
|
|
|
+
|
|
|
|
|
+ sendJsonResponse(res, response);
|
|
|
|
|
+
|
|
|
|
|
+ } catch (const std::exception& e) {
|
|
|
|
|
+ sendErrorResponse(res, std::string("Login failed: ") + e.what(), 500, "LOGIN_ERROR", requestId);
|
|
|
|
|
+ }
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+void Server::handleLogout(const httplib::Request& req, httplib::Response& res) {
|
|
|
|
|
+ std::string requestId = generateRequestId();
|
|
|
|
|
+
|
|
|
|
|
+ try {
|
|
|
|
|
+ // For now, just return success (in a real implementation, invalidate the token)
|
|
|
|
|
+ json response = {
|
|
|
|
|
+ {"message", "Logout successful"}
|
|
|
|
|
+ };
|
|
|
|
|
+
|
|
|
|
|
+ sendJsonResponse(res, response);
|
|
|
|
|
+
|
|
|
|
|
+ } catch (const std::exception& e) {
|
|
|
|
|
+ sendErrorResponse(res, std::string("Logout failed: ") + e.what(), 500, "LOGOUT_ERROR", requestId);
|
|
|
|
|
+ }
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+void Server::handleValidateToken(const httplib::Request& req, httplib::Response& res) {
|
|
|
|
|
+ std::string requestId = generateRequestId();
|
|
|
|
|
+
|
|
|
|
|
+ try {
|
|
|
|
|
+ if (!m_userManager || !m_authMiddleware) {
|
|
|
|
|
+ sendErrorResponse(res, "Authentication system not available", 500, "AUTH_UNAVAILABLE", requestId);
|
|
|
|
|
+ return;
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ // Extract token from header
|
|
|
|
|
+ std::string authHeader = req.get_header_value("Authorization");
|
|
|
|
|
+ if (authHeader.empty()) {
|
|
|
|
|
+ sendErrorResponse(res, "Missing authorization token", 401, "MISSING_TOKEN", requestId);
|
|
|
|
|
+ return;
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ // Simple token validation (in a real implementation, validate JWT)
|
|
|
|
|
+ // For now, just check if it starts with "token_"
|
|
|
|
|
+ if (authHeader.find("Bearer ") != 0) {
|
|
|
|
|
+ sendErrorResponse(res, "Invalid authorization header format", 401, "INVALID_HEADER", requestId);
|
|
|
|
|
+ return;
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ std::string token = authHeader.substr(7); // Remove "Bearer "
|
|
|
|
|
+
|
|
|
|
|
+ if (token.find("token_") != 0) {
|
|
|
|
|
+ sendErrorResponse(res, "Invalid token", 401, "INVALID_TOKEN", requestId);
|
|
|
|
|
+ return;
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ // Extract username from token (simple format: token_timestamp_username)
|
|
|
|
|
+ size_t last_underscore = token.find_last_of('_');
|
|
|
|
|
+ if (last_underscore == std::string::npos) {
|
|
|
|
|
+ sendErrorResponse(res, "Invalid token format", 401, "INVALID_TOKEN", requestId);
|
|
|
|
|
+ return;
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ std::string username = token.substr(last_underscore + 1);
|
|
|
|
|
+
|
|
|
|
|
+ // Get user info
|
|
|
|
|
+ auto userInfo = m_userManager->getUserInfoByUsername(username);
|
|
|
|
|
+ if (userInfo.id.empty()) {
|
|
|
|
|
+ sendErrorResponse(res, "User not found", 401, "USER_NOT_FOUND", requestId);
|
|
|
|
|
+ return;
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ json response = {
|
|
|
|
|
+ {"user", {
|
|
|
|
|
+ {"id", userInfo.id},
|
|
|
|
|
+ {"username", userInfo.username},
|
|
|
|
|
+ {"role", userInfo.role},
|
|
|
|
|
+ {"permissions", userInfo.permissions}
|
|
|
|
|
+ }},
|
|
|
|
|
+ {"valid", true}
|
|
|
|
|
+ };
|
|
|
|
|
+
|
|
|
|
|
+ sendJsonResponse(res, response);
|
|
|
|
|
+
|
|
|
|
|
+ } catch (const std::exception& e) {
|
|
|
|
|
+ sendErrorResponse(res, std::string("Token validation failed: ") + e.what(), 500, "VALIDATION_ERROR", requestId);
|
|
|
|
|
+ }
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+void Server::handleRefreshToken(const httplib::Request& req, httplib::Response& res) {
|
|
|
|
|
+ std::string requestId = generateRequestId();
|
|
|
|
|
+
|
|
|
|
|
+ try {
|
|
|
|
|
+ // For now, just return a new token (in a real implementation, refresh JWT)
|
|
|
|
|
+ json response = {
|
|
|
|
|
+ {"token", "new_token_" + std::to_string(std::chrono::duration_cast<std::chrono::seconds>(
|
|
|
|
|
+ std::chrono::system_clock::now().time_since_epoch()).count())},
|
|
|
|
|
+ {"message", "Token refreshed successfully"}
|
|
|
|
|
+ };
|
|
|
|
|
+
|
|
|
|
|
+ sendJsonResponse(res, response);
|
|
|
|
|
+
|
|
|
|
|
+ } catch (const std::exception& e) {
|
|
|
|
|
+ sendErrorResponse(res, std::string("Token refresh failed: ") + e.what(), 500, "REFRESH_ERROR", requestId);
|
|
|
|
|
+ }
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+void Server::handleGetCurrentUser(const httplib::Request& req, httplib::Response& res) {
|
|
|
|
|
+ std::string requestId = generateRequestId();
|
|
|
|
|
+
|
|
|
|
|
+ try {
|
|
|
|
|
+ if (!m_userManager || !m_authMiddleware) {
|
|
|
|
|
+ sendErrorResponse(res, "Authentication system not available", 500, "AUTH_UNAVAILABLE", requestId);
|
|
|
|
|
+ return;
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ // Authenticate the request
|
|
|
|
|
+ AuthContext authContext = m_authMiddleware->authenticate(req, res);
|
|
|
|
|
+
|
|
|
|
|
+ if (!authContext.authenticated) {
|
|
|
|
|
+ sendErrorResponse(res, "Authentication required", 401, "AUTH_REQUIRED", requestId);
|
|
|
|
|
+ return;
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ json response = {
|
|
|
|
|
+ {"user", {
|
|
|
|
|
+ {"id", authContext.userId},
|
|
|
|
|
+ {"username", authContext.username},
|
|
|
|
|
+ {"role", authContext.role},
|
|
|
|
|
+ {"permissions", authContext.permissions}
|
|
|
|
|
+ }}
|
|
|
|
|
+ };
|
|
|
|
|
+
|
|
|
|
|
+ sendJsonResponse(res, response);
|
|
|
|
|
+
|
|
|
|
|
+ } catch (const std::exception& e) {
|
|
|
|
|
+ sendErrorResponse(res, std::string("Get current user failed: ") + e.what(), 500, "USER_ERROR", requestId);
|
|
|
|
|
+ }
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
void Server::setupCORS() {
|
|
void Server::setupCORS() {
|
|
|
// Use post-routing handler to set CORS headers after the response is generated
|
|
// Use post-routing handler to set CORS headers after the response is generated
|
|
|
// This ensures we don't duplicate headers that may be set by other handlers
|
|
// This ensures we don't duplicate headers that may be set by other handlers
|
|
@@ -863,34 +1332,70 @@ void Server::handleDownloadOutput(const httplib::Request& req, httplib::Response
|
|
|
try {
|
|
try {
|
|
|
// Extract job ID and filename from URL path
|
|
// Extract job ID and filename from URL path
|
|
|
if (req.matches.size() < 3) {
|
|
if (req.matches.size() < 3) {
|
|
|
- sendErrorResponse(res, "Invalid request: job ID and filename required", 400);
|
|
|
|
|
|
|
+ sendErrorResponse(res, "Invalid request: job ID and filename required", 400, "INVALID_REQUEST", "");
|
|
|
return;
|
|
return;
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
std::string jobId = req.matches[1];
|
|
std::string jobId = req.matches[1];
|
|
|
std::string filename = req.matches[2];
|
|
std::string filename = req.matches[2];
|
|
|
|
|
|
|
|
- // Construct file path using the same logic as when saving:
|
|
|
|
|
|
|
+ // Validate inputs
|
|
|
|
|
+ if (jobId.empty() || filename.empty()) {
|
|
|
|
|
+ sendErrorResponse(res, "Job ID and filename cannot be empty", 400, "INVALID_PARAMETERS", "");
|
|
|
|
|
+ return;
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ // Construct absolute file path using the same logic as when saving:
|
|
|
// {outputDir}/{jobId}/{filename}
|
|
// {outputDir}/{jobId}/{filename}
|
|
|
- std::string fullPath = m_outputDir + "/" + jobId + "/" + filename;
|
|
|
|
|
|
|
+ std::string fullPath = std::filesystem::absolute(m_outputDir + "/" + jobId + "/" + filename).string();
|
|
|
|
|
+
|
|
|
|
|
+ // Log the request for debugging
|
|
|
|
|
+ std::cout << "Image download request: jobId=" << jobId << ", filename=" << filename
|
|
|
|
|
+ << ", fullPath=" << fullPath << std::endl;
|
|
|
|
|
|
|
|
// Check if file exists
|
|
// Check if file exists
|
|
|
if (!std::filesystem::exists(fullPath)) {
|
|
if (!std::filesystem::exists(fullPath)) {
|
|
|
- sendErrorResponse(res, "Output file not found: " + fullPath, 404);
|
|
|
|
|
|
|
+ std::cerr << "Output file not found: " << fullPath << std::endl;
|
|
|
|
|
+ sendErrorResponse(res, "Output file not found: " + filename, 404, "FILE_NOT_FOUND", "");
|
|
|
return;
|
|
return;
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
- // Check if file exists on filesystem
|
|
|
|
|
|
|
+ // Check file size to detect zero-byte files
|
|
|
|
|
+ auto fileSize = std::filesystem::file_size(fullPath);
|
|
|
|
|
+ if (fileSize == 0) {
|
|
|
|
|
+ std::cerr << "Output file is zero bytes: " << fullPath << std::endl;
|
|
|
|
|
+ sendErrorResponse(res, "Output file is empty (corrupted generation)", 500, "EMPTY_FILE", "");
|
|
|
|
|
+ return;
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ // Check if file is accessible
|
|
|
std::ifstream file(fullPath, std::ios::binary);
|
|
std::ifstream file(fullPath, std::ios::binary);
|
|
|
if (!file.is_open()) {
|
|
if (!file.is_open()) {
|
|
|
- sendErrorResponse(res, "Output file not accessible", 404);
|
|
|
|
|
|
|
+ std::cerr << "Failed to open output file: " << fullPath << std::endl;
|
|
|
|
|
+ sendErrorResponse(res, "Output file not accessible", 500, "FILE_ACCESS_ERROR", "");
|
|
|
return;
|
|
return;
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
// Read file contents
|
|
// Read file contents
|
|
|
- std::ostringstream fileContent;
|
|
|
|
|
- fileContent << file.rdbuf();
|
|
|
|
|
- file.close();
|
|
|
|
|
|
|
+ std::string fileContent;
|
|
|
|
|
+ try {
|
|
|
|
|
+ fileContent = std::string(
|
|
|
|
|
+ std::istreambuf_iterator<char>(file),
|
|
|
|
|
+ std::istreambuf_iterator<char>()
|
|
|
|
|
+ );
|
|
|
|
|
+ file.close();
|
|
|
|
|
+ } catch (const std::exception& e) {
|
|
|
|
|
+ std::cerr << "Failed to read file content: " << e.what() << std::endl;
|
|
|
|
|
+ sendErrorResponse(res, "Failed to read file content", 500, "FILE_READ_ERROR", "");
|
|
|
|
|
+ return;
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ // Verify we actually read data
|
|
|
|
|
+ if (fileContent.empty()) {
|
|
|
|
|
+ std::cerr << "File content is empty after read: " << fullPath << std::endl;
|
|
|
|
|
+ sendErrorResponse(res, "File content is empty after read", 500, "EMPTY_CONTENT", "");
|
|
|
|
|
+ return;
|
|
|
|
|
+ }
|
|
|
|
|
|
|
|
// Determine content type based on file extension
|
|
// Determine content type based on file extension
|
|
|
std::string contentType = "application/octet-stream";
|
|
std::string contentType = "application/octet-stream";
|
|
@@ -903,16 +1408,27 @@ void Server::handleDownloadOutput(const httplib::Request& req, httplib::Response
|
|
|
contentType = "video/mp4";
|
|
contentType = "video/mp4";
|
|
|
} else if (Utils::endsWith(filename, ".gif")) {
|
|
} else if (Utils::endsWith(filename, ".gif")) {
|
|
|
contentType = "image/gif";
|
|
contentType = "image/gif";
|
|
|
|
|
+ } else if (Utils::endsWith(filename, ".webp")) {
|
|
|
|
|
+ contentType = "image/webp";
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
- // Set response headers
|
|
|
|
|
|
|
+ // Set response headers for proper browser handling
|
|
|
res.set_header("Content-Type", contentType);
|
|
res.set_header("Content-Type", contentType);
|
|
|
- //res.set_header("Content-Disposition", "attachment; filename=\"" + filename + "\"");
|
|
|
|
|
- res.set_content(fileContent.str(), contentType);
|
|
|
|
|
|
|
+ res.set_header("Content-Length", std::to_string(fileContent.length()));
|
|
|
|
|
+ res.set_header("Cache-Control", "public, max-age=3600"); // Cache for 1 hour
|
|
|
|
|
+ res.set_header("Access-Control-Allow-Origin", "*"); // CORS for image access
|
|
|
|
|
+ // Uncomment if you want to force download instead of inline display:
|
|
|
|
|
+ // res.set_header("Content-Disposition", "attachment; filename=\"" + filename + "\"");
|
|
|
|
|
+
|
|
|
|
|
+ // Set the content
|
|
|
|
|
+ res.set_content(fileContent, contentType);
|
|
|
res.status = 200;
|
|
res.status = 200;
|
|
|
|
|
|
|
|
|
|
+ std::cout << "Successfully served image: " << filename << " (" << fileContent.length() << " bytes)" << std::endl;
|
|
|
|
|
+
|
|
|
} catch (const std::exception& e) {
|
|
} catch (const std::exception& e) {
|
|
|
- sendErrorResponse(res, std::string("Failed to download file: ") + e.what(), 500);
|
|
|
|
|
|
|
+ std::cerr << "Exception in handleDownloadOutput: " << e.what() << std::endl;
|
|
|
|
|
+ sendErrorResponse(res, std::string("Failed to download file: ") + e.what(), 500, "DOWNLOAD_ERROR", "");
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
|
|
|
|
@@ -1966,6 +2482,194 @@ void Server::handleUpscale(const httplib::Request& req, httplib::Response& res)
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
|
|
+void Server::handleInpainting(const httplib::Request& req, httplib::Response& res) {
|
|
|
|
|
+ std::string requestId = generateRequestId();
|
|
|
|
|
+
|
|
|
|
|
+ try {
|
|
|
|
|
+ if (!m_generationQueue) {
|
|
|
|
|
+ sendErrorResponse(res, "Generation queue not available", 500, "QUEUE_UNAVAILABLE", requestId);
|
|
|
|
|
+ return;
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ json requestJson = json::parse(req.body);
|
|
|
|
|
+
|
|
|
|
|
+ // Validate required fields for inpainting
|
|
|
|
|
+ if (!requestJson.contains("prompt") || !requestJson["prompt"].is_string()) {
|
|
|
|
|
+ sendErrorResponse(res, "Missing or invalid 'prompt' field", 400, "INVALID_PARAMETERS", requestId);
|
|
|
|
|
+ return;
|
|
|
|
|
+ }
|
|
|
|
|
+ if (!requestJson.contains("source_image") || !requestJson["source_image"].is_string()) {
|
|
|
|
|
+ sendErrorResponse(res, "Missing or invalid 'source_image' field", 400, "INVALID_PARAMETERS", requestId);
|
|
|
|
|
+ return;
|
|
|
|
|
+ }
|
|
|
|
|
+ if (!requestJson.contains("mask_image") || !requestJson["mask_image"].is_string()) {
|
|
|
|
|
+ sendErrorResponse(res, "Missing or invalid 'mask_image' field", 400, "INVALID_PARAMETERS", requestId);
|
|
|
|
|
+ return;
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ // Validate all parameters
|
|
|
|
|
+ auto [isValid, errorMessage] = validateGenerationParameters(requestJson);
|
|
|
|
|
+ if (!isValid) {
|
|
|
|
|
+ sendErrorResponse(res, errorMessage, 400, "INVALID_PARAMETERS", requestId);
|
|
|
|
|
+ return;
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ // Check if any model is loaded
|
|
|
|
|
+ if (!m_modelManager) {
|
|
|
|
|
+ sendErrorResponse(res, "Model manager not available", 500, "MODEL_MANAGER_UNAVAILABLE", requestId);
|
|
|
|
|
+ return;
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ // Get currently loaded checkpoint model
|
|
|
|
|
+ auto allModels = m_modelManager->getAllModels();
|
|
|
|
|
+ std::string loadedModelName;
|
|
|
|
|
+
|
|
|
|
|
+ for (const auto& [modelName, modelInfo] : allModels) {
|
|
|
|
|
+ if (modelInfo.type == ModelType::CHECKPOINT && modelInfo.isLoaded) {
|
|
|
|
|
+ loadedModelName = modelName;
|
|
|
|
|
+ break;
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ if (loadedModelName.empty()) {
|
|
|
|
|
+ sendErrorResponse(res, "No checkpoint model loaded. Please load a checkpoint model first using POST /api/models/{hash}/load", 400, "NO_CHECKPOINT_LOADED", requestId);
|
|
|
|
|
+ return;
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ // Load the source image
|
|
|
|
|
+ std::string sourceImageInput = requestJson["source_image"];
|
|
|
|
|
+ auto [sourceImageData, sourceImgWidth, sourceImgHeight, sourceImgChannels, sourceSuccess, sourceLoadError] = loadImageFromInput(sourceImageInput);
|
|
|
|
|
+
|
|
|
|
|
+ if (!sourceSuccess) {
|
|
|
|
|
+ sendErrorResponse(res, "Failed to load source image: " + sourceLoadError, 400, "IMAGE_LOAD_ERROR", requestId);
|
|
|
|
|
+ return;
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ // Load the mask image
|
|
|
|
|
+ std::string maskImageInput = requestJson["mask_image"];
|
|
|
|
|
+ auto [maskImageData, maskImgWidth, maskImgHeight, maskImgChannels, maskSuccess, maskLoadError] = loadImageFromInput(maskImageInput);
|
|
|
|
|
+
|
|
|
|
|
+ if (!maskSuccess) {
|
|
|
|
|
+ sendErrorResponse(res, "Failed to load mask image: " + maskLoadError, 400, "MASK_LOAD_ERROR", requestId);
|
|
|
|
|
+ return;
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ // Validate that source and mask images have compatible dimensions
|
|
|
|
|
+ if (sourceImgWidth != maskImgWidth || sourceImgHeight != maskImgHeight) {
|
|
|
|
|
+ sendErrorResponse(res, "Source and mask images must have the same dimensions", 400, "DIMENSION_MISMATCH", requestId);
|
|
|
|
|
+ return;
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ // Create generation request specifically for inpainting
|
|
|
|
|
+ GenerationRequest genRequest;
|
|
|
|
|
+ genRequest.id = requestId;
|
|
|
|
|
+ genRequest.requestType = GenerationRequest::RequestType::INPAINTING;
|
|
|
|
|
+ genRequest.modelName = loadedModelName; // Use the currently loaded model
|
|
|
|
|
+ genRequest.prompt = requestJson["prompt"];
|
|
|
|
|
+ genRequest.negativePrompt = requestJson.value("negative_prompt", "");
|
|
|
|
|
+ genRequest.width = requestJson.value("width", sourceImgWidth); // Default to input image dimensions
|
|
|
|
|
+ genRequest.height = requestJson.value("height", sourceImgHeight);
|
|
|
|
|
+ genRequest.batchCount = requestJson.value("batch_count", 1);
|
|
|
|
|
+ genRequest.steps = requestJson.value("steps", 20);
|
|
|
|
|
+ genRequest.cfgScale = requestJson.value("cfg_scale", 7.5f);
|
|
|
|
|
+ genRequest.seed = requestJson.value("seed", "random");
|
|
|
|
|
+ genRequest.strength = requestJson.value("strength", 0.75f);
|
|
|
|
|
+
|
|
|
|
|
+ // Set source image data
|
|
|
|
|
+ genRequest.initImageData = sourceImageData;
|
|
|
|
|
+ genRequest.initImageWidth = sourceImgWidth;
|
|
|
|
|
+ genRequest.initImageHeight = sourceImgHeight;
|
|
|
|
|
+ genRequest.initImageChannels = sourceImgChannels;
|
|
|
|
|
+
|
|
|
|
|
+ // Set mask image data
|
|
|
|
|
+ genRequest.maskImageData = maskImageData;
|
|
|
|
|
+ genRequest.maskImageWidth = maskImgWidth;
|
|
|
|
|
+ genRequest.maskImageHeight = maskImgHeight;
|
|
|
|
|
+ genRequest.maskImageChannels = maskImgChannels;
|
|
|
|
|
+
|
|
|
|
|
+ // Parse optional parameters
|
|
|
|
|
+ if (requestJson.contains("sampling_method")) {
|
|
|
|
|
+ genRequest.samplingMethod = parseSamplingMethod(requestJson["sampling_method"]);
|
|
|
|
|
+ }
|
|
|
|
|
+ if (requestJson.contains("scheduler")) {
|
|
|
|
|
+ genRequest.scheduler = parseScheduler(requestJson["scheduler"]);
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ // Optional VAE model
|
|
|
|
|
+ if (requestJson.contains("vae_model") && requestJson["vae_model"].is_string()) {
|
|
|
|
|
+ std::string vaeModelId = requestJson["vae_model"];
|
|
|
|
|
+ if (!vaeModelId.empty()) {
|
|
|
|
|
+ auto vaeInfo = m_modelManager->getModelInfo(vaeModelId);
|
|
|
|
|
+ if (!vaeInfo.name.empty() && vaeInfo.type == ModelType::VAE) {
|
|
|
|
|
+ genRequest.vaePath = vaeInfo.path;
|
|
|
|
|
+ } else {
|
|
|
|
|
+ sendErrorResponse(res, "VAE model not found or invalid: " + vaeModelId, 400, "INVALID_VAE_MODEL", requestId);
|
|
|
|
|
+ return;
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ // Optional TAESD model
|
|
|
|
|
+ if (requestJson.contains("taesd_model") && requestJson["taesd_model"].is_string()) {
|
|
|
|
|
+ std::string taesdModelId = requestJson["taesd_model"];
|
|
|
|
|
+ if (!taesdModelId.empty()) {
|
|
|
|
|
+ auto taesdInfo = m_modelManager->getModelInfo(taesdModelId);
|
|
|
|
|
+ if (!taesdInfo.name.empty() && taesdInfo.type == ModelType::TAESD) {
|
|
|
|
|
+ genRequest.taesdPath = taesdInfo.path;
|
|
|
|
|
+ } else {
|
|
|
|
|
+ sendErrorResponse(res, "TAESD model not found or invalid: " + taesdModelId, 400, "INVALID_TAESD_MODEL", requestId);
|
|
|
|
|
+ return;
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ // Enqueue request
|
|
|
|
|
+ auto future = m_generationQueue->enqueueRequest(genRequest);
|
|
|
|
|
+
|
|
|
|
|
+ json params = {
|
|
|
|
|
+ {"prompt", genRequest.prompt},
|
|
|
|
|
+ {"negative_prompt", genRequest.negativePrompt},
|
|
|
|
|
+ {"source_image", requestJson["source_image"]},
|
|
|
|
|
+ {"mask_image", requestJson["mask_image"]},
|
|
|
|
|
+ {"model", genRequest.modelName},
|
|
|
|
|
+ {"width", genRequest.width},
|
|
|
|
|
+ {"height", genRequest.height},
|
|
|
|
|
+ {"batch_count", genRequest.batchCount},
|
|
|
|
|
+ {"steps", genRequest.steps},
|
|
|
|
|
+ {"cfg_scale", genRequest.cfgScale},
|
|
|
|
|
+ {"seed", genRequest.seed},
|
|
|
|
|
+ {"strength", genRequest.strength},
|
|
|
|
|
+ {"sampling_method", samplingMethodToString(genRequest.samplingMethod)},
|
|
|
|
|
+ {"scheduler", schedulerToString(genRequest.scheduler)}
|
|
|
|
|
+ };
|
|
|
|
|
+
|
|
|
|
|
+ // Add VAE/TAESD if specified
|
|
|
|
|
+ if (!genRequest.vaePath.empty()) {
|
|
|
|
|
+ params["vae_model"] = requestJson.value("vae_model", "");
|
|
|
|
|
+ }
|
|
|
|
|
+ if (!genRequest.taesdPath.empty()) {
|
|
|
|
|
+ params["taesd_model"] = requestJson.value("taesd_model", "");
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ json response = {
|
|
|
|
|
+ {"request_id", requestId},
|
|
|
|
|
+ {"status", "queued"},
|
|
|
|
|
+ {"message", "Inpainting generation request queued successfully"},
|
|
|
|
|
+ {"queue_position", m_generationQueue->getQueueSize()},
|
|
|
|
|
+ {"estimated_time_seconds", estimateGenerationTime(genRequest) / 1000},
|
|
|
|
|
+ {"estimated_memory_mb", estimateMemoryUsage(genRequest) / (1024 * 1024)},
|
|
|
|
|
+ {"type", "inpainting"},
|
|
|
|
|
+ {"parameters", params}
|
|
|
|
|
+ };
|
|
|
|
|
+
|
|
|
|
|
+ sendJsonResponse(res, response, 202);
|
|
|
|
|
+ } catch (const json::parse_error& e) {
|
|
|
|
|
+ sendErrorResponse(res, std::string("Invalid JSON: ") + e.what(), 400, "JSON_PARSE_ERROR", requestId);
|
|
|
|
|
+ } catch (const std::exception& e) {
|
|
|
|
|
+ sendErrorResponse(res, std::string("Inpainting request failed: ") + e.what(), 500, "INTERNAL_ERROR", requestId);
|
|
|
|
|
+ }
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
// Utility endpoints
|
|
// Utility endpoints
|
|
|
void Server::handleSamplers(const httplib::Request& req, httplib::Response& res) {
|
|
void Server::handleSamplers(const httplib::Request& req, httplib::Response& res) {
|
|
|
try {
|
|
try {
|