| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510 |
#include <signal.h>

#include <algorithm>
#include <atomic>
#include <chrono>
#include <cstdlib>
#include <filesystem>
#include <iostream>
#include <limits>
#include <memory>
#include <string>
#include <thread>

#include "server.h"
#include "model_manager.h"
#include "generation_queue.h"
#include "server_config.h"
#include "logger.h"
#include "user_manager.h"
#include "auth_middleware.h"
- // Global flag for signal handling
- std::atomic<bool> g_running(true);
- // Global pointer to server instance for signal handler access
- Server* g_server = nullptr;
- // Signal handler for graceful shutdown
- void signalHandler(int signal) {
- LOG_INFO("Received signal " + std::to_string(signal) + ", shutting down gracefully...");
- g_running.store(false);
- // Stop the server directly from signal handler
- if (g_server != nullptr) {
- g_server->stop();
- }
- // Give a brief moment for cleanup, then force exit
- std::this_thread::sleep_for(std::chrono::milliseconds(100));
- LOG_INFO("Exiting process...");
- Logger::getInstance().close();
- exit(0);
- }
// Resolve a model sub-directory option against the base models directory.
//
//   - an empty `path` yields an empty string;
//   - an absolute `path` is returned unchanged;
//   - a relative `path` is joined onto `modelsDir` when `modelsDir` is non-empty,
//     and returned unchanged otherwise.
//
// @param path       directory as given on the command line (or its default name)
// @param modelsDir  base models directory (--models-dir); may be empty
// @return           the resolved directory path as a string
std::string resolveDirectoryPath(const std::string& path, const std::string& modelsDir) {
    if (path.empty()) {
        return std::string();
    }

    const std::filesystem::path candidate{path};

    // Nothing to anchor against: absolute paths stand alone, and without a base
    // directory a relative path is left for the caller to interpret.
    const bool keepAsGiven = candidate.is_absolute() || modelsDir.empty();
    if (keepAsGiven) {
        return path;
    }

    const std::filesystem::path base{modelsDir};
    return (base / candidate).string();
}
- // Parse command line arguments
- ServerConfig parseArguments(int argc, char* argv[]) {
- ServerConfig config;
- // Track which parameters were explicitly set
- bool modelsDirSet = false;
- bool checkpointsSet = false;
- bool controlnetSet = false;
- bool embeddingsSet = false;
- bool esrganSet = false;
- bool loraSet = false;
- bool taesdSet = false;
- bool vaeSet = false;
- for (int i = 1; i < argc; i++) {
- std::string arg = argv[i];
- if (arg == "--host" && i + 1 < argc) {
- config.host = argv[++i];
- } else if (arg == "--port" && i + 1 < argc) {
- config.port = std::stoi(argv[++i]);
- } else if (arg == "--models-dir" && i + 1 < argc) {
- config.modelsDir = argv[++i];
- modelsDirSet = true;
- } else if (arg == "--checkpoints" && i + 1 < argc) {
- config.checkpoints = argv[++i];
- checkpointsSet = true;
- } else if (arg == "--controlnet-dir" && i + 1 < argc) {
- config.controlnetDir = argv[++i];
- controlnetSet = true;
- } else if (arg == "--embeddings-dir" && i + 1 < argc) {
- config.embeddingsDir = argv[++i];
- embeddingsSet = true;
- } else if (arg == "--esrgan-dir" && i + 1 < argc) {
- config.esrganDir = argv[++i];
- esrganSet = true;
- } else if (arg == "--lora-dir" && i + 1 < argc) {
- config.loraDir = argv[++i];
- loraSet = true;
- } else if (arg == "--taesd-dir" && i + 1 < argc) {
- config.taesdDir = argv[++i];
- taesdSet = true;
- } else if (arg == "--vae-dir" && i + 1 < argc) {
- config.vaeDir = argv[++i];
- vaeSet = true;
- } else if (arg == "--max-concurrent" && i + 1 < argc) {
- config.maxConcurrentGenerations = std::stoi(argv[++i]);
- } else if (arg == "--queue-dir" && i + 1 < argc) {
- config.queueDir = argv[++i];
- } else if (arg == "--output-dir" && i + 1 < argc) {
- config.outputDir = argv[++i];
- } else if (arg == "--ui-dir" && i + 1 < argc) {
- config.uiDir = argv[++i];
- } else if (arg == "--verbose" || arg == "-v") {
- config.verbose = true;
- } else if (arg == "--log-file" && i + 1 < argc) {
- config.enableFileLogging = true;
- config.logFilePath = argv[++i];
- } else if (arg == "--enable-file-logging") {
- config.enableFileLogging = true;
- } else if ((arg == "--auth-method" || arg == "--auth") && i + 1 < argc) {
- std::string method = argv[++i];
- if (method == "none") {
- config.auth.authMethod = AuthMethod::NONE;
- } else if (method == "jwt") {
- config.auth.authMethod = AuthMethod::JWT;
- } else if (method == "api-key") {
- config.auth.authMethod = AuthMethod::API_KEY;
- } else if (method == "unix") {
- config.auth.authMethod = AuthMethod::UNIX;
- } else if (method == "pam") {
- config.auth.authMethod = AuthMethod::PAM;
- } else if (method == "optional") {
- config.auth.authMethod = AuthMethod::OPTIONAL;
- } else {
- std::cerr << "Invalid auth method: " << method << std::endl;
- exit(1);
- }
- } else if (arg == "--jwt-secret" && i + 1 < argc) {
- config.auth.jwtSecret = argv[++i];
- } else if (arg == "--jwt-expiration" && i + 1 < argc) {
- config.auth.jwtExpirationMinutes = std::stoi(argv[++i]);
- } else if (arg == "--enable-guest-access") {
- config.auth.enableGuestAccess = true;
- } else if (arg == "--enable-unix-auth") {
- // Deprecated flag - show warning and set auth method to UNIX
- std::cerr << "Warning: --enable-unix-auth is deprecated. Use --auth unix instead." << std::endl;
- config.auth.authMethod = AuthMethod::UNIX;
- } else if (arg == "--enable-pam-auth") {
- // Deprecated flag - show warning and set auth method to PAM
- std::cerr << "Warning: --enable-pam-auth is deprecated. Use --auth pam instead." << std::endl;
- config.auth.authMethod = AuthMethod::PAM;
- } else if (arg == "--pam-service-name" && i + 1 < argc) {
- config.auth.pamServiceName = argv[++i];
- } else if (arg == "--auth-data-dir" && i + 1 < argc) {
- config.auth.dataDir = argv[++i];
- } else if (arg == "--public-paths" && i + 1 < argc) {
- config.auth.customPublicPaths = argv[++i];
- } else if (arg == "--help" || arg == "-h") {
- std::cout << "stable-diffusion.cpp-rest server\n"
- << "Usage: " << argv[0] << " [options]\n\n"
- << "Required Options:\n"
- << " --models-dir <dir> Base models directory path (required)\n"
- << "\n"
- << "Server Options:\n"
- << " --host <host> Host address to bind to (default: 0.0.0.0)\n"
- << " --port <port> Port number to listen on (default: 8080)\n"
- << " --max-concurrent <num> Maximum concurrent generations (default: 1)\n"
- << " --queue-dir <dir> Queue persistence directory (default: ./queue)\n"
- << " --output-dir <dir> Output files directory (default: ./output)\n"
- << " --ui-dir <dir> Web UI static files directory (optional)\n"
- << " --verbose, -v Enable verbose logging\n"
- << " --enable-file-logging Enable logging to file\n"
- << " --log-file <path> Log file path (default: /var/log/stable-diffusion-rest/server.log)\n"
- << "\n"
- << "Authentication Options:\n"
- << " --auth <method> Authentication method (none, jwt, api-key, unix, pam, optional)\n"
- << " --auth-method <method> Authentication method (alias for --auth)\n"
- << " --jwt-secret <secret> JWT secret key (auto-generated if not provided)\n"
- << " --jwt-expiration <minutes> JWT token expiration time (default: 60)\n"
- << " --enable-guest-access Allow unauthenticated guest access\n"
- << " --pam-service-name <name> PAM service name (default: stable-diffusion-rest)\n"
- << " --auth-data-dir <dir> Directory for authentication data (default: ./auth)\n"
- << " --public-paths <paths> Comma-separated list of public paths (default: /api/health,/api/status)\n"
- << "\n"
- << "Deprecated Options (will be removed in future version):\n"
- << " --enable-unix-auth Deprecated: Use --auth unix instead\n"
- << " --enable-pam-auth Deprecated: Use --auth pam instead\n"
- << "\n"
- << "Model Directory Options:\n"
- << " All model directories are optional and default to standard folder names\n"
- << " under --models-dir. Only specify these if your folder names differ.\n"
- << "\n"
- << " --checkpoints <dir> Checkpoints directory (default: checkpoints)\n"
- << " --controlnet-dir <dir> ControlNet models directory (default: controlnet)\n"
- << " --embeddings-dir <dir> Embeddings directory (default: embeddings)\n"
- << " --esrgan-dir <dir> ESRGAN models directory (default: ESRGAN)\n"
- << " --lora-dir <dir> LoRA models directory (default: loras)\n"
- << " --taesd-dir <dir> TAESD models directory (default: TAESD)\n"
- << " --vae-dir <dir> VAE models directory (default: vae)\n"
- << "\n"
- << "Other Options:\n"
- << " --help, -h Show this help message\n"
- << "\n"
- << "Path Resolution:\n"
- << " - Absolute paths are used as-is\n"
- << " - Relative paths are resolved relative to --models-dir\n"
- << " - Default folder names match standard SD model structure\n"
- << "\n"
- << "Examples:\n"
- << " # Use all defaults (requires standard folder structure)\n"
- << " " << argv[0] << " --models-dir /data/SD_MODELS\n"
- << "\n"
- << " # Override specific folders\n"
- << " " << argv[0] << " --models-dir /data/SD_MODELS --checkpoints my_checkpoints\n"
- << "\n"
- << " # Use absolute path for one folder\n"
- << " " << argv[0] << " --models-dir /data/SD_MODELS --lora-dir /other/path/loras\n"
- << std::endl;
- exit(0);
- } else {
- std::cerr << "Unknown argument: " << arg << std::endl;
- std::cerr << "Use --help for usage information" << std::endl;
- exit(1);
- }
- }
- // Validate required parameters
- if (!modelsDirSet) {
- std::cerr << "Error: --models-dir is required" << std::endl;
- std::cerr << "Use --help for usage information" << std::endl;
- exit(1);
- }
- // Set defaults for model directories if not explicitly set
- if (!checkpointsSet) {
- config.checkpoints = "checkpoints";
- }
- if (!controlnetSet) {
- config.controlnetDir = "controlnet";
- }
- if (!embeddingsSet) {
- config.embeddingsDir = "embeddings";
- }
- if (!esrganSet) {
- config.esrganDir = "ESRGAN";
- }
- if (!loraSet) {
- config.loraDir = "loras";
- }
- if (!taesdSet) {
- config.taesdDir = "TAESD";
- }
- if (!vaeSet) {
- config.vaeDir = "vae";
- }
- // Resolve all directory paths (absolute paths used as-is, relative resolved from models-dir)
- config.checkpoints = resolveDirectoryPath(config.checkpoints, config.modelsDir);
- config.controlnetDir = resolveDirectoryPath(config.controlnetDir, config.modelsDir);
- config.embeddingsDir = resolveDirectoryPath(config.embeddingsDir, config.modelsDir);
- config.esrganDir = resolveDirectoryPath(config.esrganDir, config.modelsDir);
- config.loraDir = resolveDirectoryPath(config.loraDir, config.modelsDir);
- config.taesdDir = resolveDirectoryPath(config.taesdDir, config.modelsDir);
- config.vaeDir = resolveDirectoryPath(config.vaeDir, config.modelsDir);
- return config;
- }
- int main(int argc, char* argv[]) {
- // Parse command line arguments
- ServerConfig config = parseArguments(argc, argv);
- // Initialize logger
- LogLevel minLevel = config.verbose ? LogLevel::DEBUG : LogLevel::INFO;
- Logger::getInstance().initialize(config.enableFileLogging, config.logFilePath, minLevel);
- // Create log directory if file logging is enabled
- if (config.enableFileLogging) {
- try {
- std::filesystem::path logPath(config.logFilePath);
- std::filesystem::create_directories(logPath.parent_path());
- } catch (const std::filesystem::filesystem_error& e) {
- LOG_ERROR("Failed to create log directory: " + std::string(e.what()));
- }
- }
- LOG_INFO("=== Stable Diffusion REST Server Starting ===");
- if (config.enableFileLogging) {
- LOG_INFO("File logging enabled: " + config.logFilePath);
- }
- // Create queue and output directories if they don't exist
- try {
- std::filesystem::create_directories(config.queueDir);
- std::filesystem::create_directories(config.outputDir);
- } catch (const std::filesystem::filesystem_error& e) {
- LOG_WARNING("Failed to create directories: " + std::string(e.what()));
- }
- if (config.verbose) {
- std::cout << "\n=== Configuration ===" << std::endl;
- std::cout << "Server:" << std::endl;
- std::cout << " Host: " << config.host << std::endl;
- std::cout << " Port: " << config.port << std::endl;
- std::cout << " Max concurrent generations: " << config.maxConcurrentGenerations << std::endl;
- std::cout << " Queue directory: " << config.queueDir << std::endl;
- std::cout << " Output directory: " << config.outputDir << std::endl;
- std::cout << "\nModel Directories:" << std::endl;
- std::cout << " Base models directory: " << config.modelsDir << std::endl;
- std::cout << " Checkpoints: " << config.checkpoints << std::endl;
- std::cout << " ControlNet: " << config.controlnetDir << std::endl;
- std::cout << " Embeddings: " << config.embeddingsDir << std::endl;
- std::cout << " ESRGAN: " << config.esrganDir << std::endl;
- std::cout << " LoRA: " << config.loraDir << std::endl;
- std::cout << " TAESD: " << config.taesdDir << std::endl;
- std::cout << " VAE: " << config.vaeDir << std::endl;
- std::cout << std::endl;
- }
- // Validate directory paths
- auto validateDirectory = [](const std::string& path, const std::string& name, bool required) -> bool {
- if (path.empty()) {
- if (required) {
- std::cerr << "Error: " << name << " directory is required but not specified" << std::endl;
- return false;
- }
- return true; // Empty path is valid for optional directories
- }
- std::filesystem::path dirPath(path);
- if (!std::filesystem::exists(dirPath)) {
- if (required) {
- std::cerr << "Error: " << name << " directory does not exist: " << path << std::endl;
- return false;
- } else {
- std::cerr << "Warning: " << name << " directory does not exist: " << path << std::endl;
- return true; // Optional directory can be missing
- }
- }
- if (!std::filesystem::is_directory(dirPath)) {
- std::cerr << "Error: " << name << " path is not a directory: " << path << std::endl;
- return false;
- }
- return true;
- };
- // Validate required directories
- bool allValid = true;
- // Validate base models directory (required - must exist)
- if (!validateDirectory(config.modelsDir, "Base models", true)) {
- allValid = false;
- }
- // Validate all model directories (will warn but not fail if missing)
- validateDirectory(config.checkpoints, "Checkpoints", false);
- validateDirectory(config.controlnetDir, "ControlNet", false);
- validateDirectory(config.embeddingsDir, "Embeddings", false);
- validateDirectory(config.esrganDir, "ESRGAN", false);
- validateDirectory(config.loraDir, "LoRA", false);
- validateDirectory(config.taesdDir, "TAESD", false);
- validateDirectory(config.vaeDir, "VAE", false);
- // Validate UI directory if specified
- if (!config.uiDir.empty()) {
- if (!validateDirectory(config.uiDir, "Web UI", false)) {
- std::cerr << "\nError: Web UI directory is invalid" << std::endl;
- return 1;
- }
- // Check if the UI directory is readable
- std::filesystem::path uiPath(config.uiDir);
- if (!std::filesystem::exists(uiPath / "index.html") &&
- !std::filesystem::exists(uiPath / "index.htm")) {
- std::cerr << "Warning: Web UI directory does not contain an index.html or index.htm file: " << config.uiDir << std::endl;
- }
- }
- if (!allValid) {
- std::cerr << "\nError: Base models directory is invalid or missing" << std::endl;
- return 1;
- }
- // Set up signal handlers for graceful shutdown
- signal(SIGINT, signalHandler); // Ctrl+C
- signal(SIGTERM, signalHandler); // Termination signal
- try {
- // Initialize authentication system
- if (config.verbose) {
- std::cout << "Initializing authentication system..." << std::endl;
- }
- auto userManager = std::make_shared<UserManager>(config.auth.dataDir,
- static_cast<UserManager::AuthMethod>(config.auth.authMethod));
- if (!userManager->initialize()) {
- std::cerr << "Error: Failed to initialize user manager" << std::endl;
- return 1;
- }
- if (config.verbose) {
- std::cout << "User manager initialized" << std::endl;
- std::cout << "Authentication method: ";
- switch (config.auth.authMethod) {
- case AuthMethod::NONE: std::cout << "None"; break;
- case AuthMethod::JWT: std::cout << "JWT"; break;
- case AuthMethod::API_KEY: std::cout << "API Key"; break;
- case AuthMethod::UNIX: std::cout << "Unix"; break;
- case AuthMethod::OPTIONAL: std::cout << "Optional"; break;
- }
- std::cout << std::endl;
- }
- // Initialize authentication middleware
- auto authMiddleware = std::make_shared<AuthMiddleware>(config.auth, userManager);
- if (!authMiddleware->initialize()) {
- std::cerr << "Error: Failed to initialize authentication middleware" << std::endl;
- return 1;
- }
- // Initialize components
- auto modelManager = std::make_unique<ModelManager>();
- auto generationQueue = std::make_unique<GenerationQueue>(modelManager.get(), config.maxConcurrentGenerations,
- config.queueDir, config.outputDir);
- auto server = std::make_unique<Server>(modelManager.get(), generationQueue.get(), config.outputDir, config.uiDir);
- // Set authentication components in server
- server->setAuthComponents(userManager, authMiddleware);
- // Set global server pointer for signal handler access
- g_server = server.get();
- // Configure model manager with directory parameters
- if (config.verbose) {
- std::cout << "Configuring model manager..." << std::endl;
- }
- if (!modelManager->configureFromServerConfig(config)) {
- std::cerr << "Error: Failed to configure model manager with server config" << std::endl;
- return 1;
- }
- if (config.verbose) {
- std::cout << "Model manager configured with per-type directories" << std::endl;
- }
- // Scan models directory
- if (config.verbose) {
- std::cout << "Scanning models directory..." << std::endl;
- }
- if (!modelManager->scanModelsDirectory()) {
- std::cerr << "Warning: Failed to scan models directory" << std::endl;
- } else if (config.verbose) {
- std::cout << "Found " << modelManager->getAvailableModelsCount() << " models" << std::endl;
- }
- // Start the generation queue
- generationQueue->start();
- if (config.verbose) {
- std::cout << "Generation queue started" << std::endl;
- }
- // Start the HTTP server
- if (!server->start(config.host, config.port)) {
- std::cerr << "Failed to start server" << std::endl;
- return 1;
- }
- std::cout << "Server initialized successfully" << std::endl;
- std::cout << "Server listening on " << config.host << ":" << config.port << std::endl;
- std::cout << "Press Ctrl+C to stop the server" << std::endl;
- // Give server a moment to start
- std::this_thread::sleep_for(std::chrono::milliseconds(100));
- // Main server loop - wait for shutdown signal
- while (g_running.load() && server->isRunning()) {
- std::this_thread::sleep_for(std::chrono::milliseconds(100));
- }
- // Graceful shutdown
- std::cout << "Shutting down server..." << std::endl;
- // Stop the server first to stop accepting new requests
- server->stop();
- // Stop the generation queue
- generationQueue->stop();
- // Wait for server thread to finish
- server->waitForStop();
- // Clear global server pointer
- g_server = nullptr;
- std::cout << "Server shutdown complete" << std::endl;
- } catch (const std::exception& e) {
- std::cerr << "Error: " << e.what() << std::endl;
- return 1;
- }
- return 0;
- }
|