#include #include #include #include #include #include #include #include #include #include "server.h" #include "model_manager.h" #include "generation_queue.h" #include "server_config.h" #include "logger.h" #include "user_manager.h" #include "auth_middleware.h" // Global flag for signal handling std::atomic g_running(true); // Global pointer to server instance for signal handler access Server* g_server = nullptr; // Signal handler for graceful shutdown void signalHandler(int signal) { LOG_INFO("Received signal " + std::to_string(signal) + ", shutting down gracefully..."); g_running.store(false); // Stop the server directly from signal handler if (g_server != nullptr) { g_server->stop(); } // Give a brief moment for cleanup, then force exit std::this_thread::sleep_for(std::chrono::milliseconds(100)); LOG_INFO("Exiting process..."); Logger::getInstance().close(); exit(0); } // Helper function to resolve directory path std::string resolveDirectoryPath(const std::string& path, const std::string& modelsDir) { if (path.empty()) { return ""; } std::filesystem::path dirPath(path); // If the path is absolute, use it as-is if (dirPath.is_absolute()) { return path; } // If the path is relative and models-dir is specified, prepend models-dir if (!modelsDir.empty()) { std::filesystem::path baseDir(modelsDir); return (baseDir / dirPath).string(); } // If no models-dir, return the relative path as-is return path; } // Parse command line arguments ServerConfig parseArguments(int argc, char* argv[]) { ServerConfig config; // Track which parameters were explicitly set bool modelsDirSet = false; bool checkpointsSet = false; bool controlnetSet = false; bool embeddingsSet = false; bool esrganSet = false; bool loraSet = false; bool taesdSet = false; bool vaeSet = false; for (int i = 1; i < argc; i++) { std::string arg = argv[i]; if (arg == "--host" && i + 1 < argc) { config.host = argv[++i]; } else if (arg == "--port" && i + 1 < argc) { config.port = std::stoi(argv[++i]); } else if (arg == "--models-dir" && i + 1 < argc) { config.modelsDir = argv[++i]; modelsDirSet = true; } else if (arg == "--checkpoints" && i + 1 < argc) { config.checkpoints = argv[++i]; checkpointsSet = true; } else if (arg == "--controlnet-dir" && i + 1 < argc) { config.controlnetDir = argv[++i]; controlnetSet = true; } else if (arg == "--embeddings-dir" && i + 1 < argc) { config.embeddingsDir = argv[++i]; embeddingsSet = true; } else if (arg == "--esrgan-dir" && i + 1 < argc) { config.esrganDir = argv[++i]; esrganSet = true; } else if (arg == "--lora-dir" && i + 1 < argc) { config.loraDir = argv[++i]; loraSet = true; } else if (arg == "--taesd-dir" && i + 1 < argc) { config.taesdDir = argv[++i]; taesdSet = true; } else if (arg == "--vae-dir" && i + 1 < argc) { config.vaeDir = argv[++i]; vaeSet = true; } else if (arg == "--max-concurrent" && i + 1 < argc) { config.maxConcurrentGenerations = std::stoi(argv[++i]); } else if (arg == "--queue-dir" && i + 1 < argc) { config.queueDir = argv[++i]; } else if (arg == "--output-dir" && i + 1 < argc) { config.outputDir = argv[++i]; } else if (arg == "--ui-dir" && i + 1 < argc) { config.uiDir = argv[++i]; } else if (arg == "--verbose" || arg == "-v") { config.verbose = true; } else if (arg == "--log-file" && i + 1 < argc) { config.enableFileLogging = true; config.logFilePath = argv[++i]; } else if (arg == "--enable-file-logging") { config.enableFileLogging = true; } else if ((arg == "--auth-method" || arg == "--auth") && i + 1 < argc) { std::string method = argv[++i]; if (method == "none") { config.auth.authMethod = AuthMethod::NONE; } else if (method == "jwt") { config.auth.authMethod = AuthMethod::JWT; } else if (method == "api-key") { config.auth.authMethod = AuthMethod::API_KEY; } else if (method == "unix") { config.auth.authMethod = AuthMethod::UNIX; } else if (method == "pam") { config.auth.authMethod = AuthMethod::PAM; } else if (method == "optional") { config.auth.authMethod = AuthMethod::OPTIONAL; } else { std::cerr << "Invalid auth method: " << method << std::endl; exit(1); } } else if (arg == "--jwt-secret" && i + 1 < argc) { config.auth.jwtSecret = argv[++i]; } else if (arg == "--jwt-expiration" && i + 1 < argc) { config.auth.jwtExpirationMinutes = std::stoi(argv[++i]); } else if (arg == "--enable-guest-access") { config.auth.enableGuestAccess = true; } else if (arg == "--enable-unix-auth") { // Deprecated flag - show warning and set auth method to UNIX std::cerr << "Warning: --enable-unix-auth is deprecated. Use --auth unix instead." << std::endl; config.auth.authMethod = AuthMethod::UNIX; } else if (arg == "--enable-pam-auth") { // Deprecated flag - show warning and set auth method to PAM std::cerr << "Warning: --enable-pam-auth is deprecated. Use --auth pam instead." << std::endl; config.auth.authMethod = AuthMethod::PAM; } else if (arg == "--pam-service-name" && i + 1 < argc) { config.auth.pamServiceName = argv[++i]; } else if (arg == "--auth-data-dir" && i + 1 < argc) { config.auth.dataDir = argv[++i]; } else if (arg == "--public-paths" && i + 1 < argc) { config.auth.customPublicPaths = argv[++i]; } else if (arg == "--help" || arg == "-h") { std::cout << "stable-diffusion.cpp-rest server\n" << "Usage: " << argv[0] << " [options]\n\n" << "Required Options:\n" << " --models-dir Base models directory path (required)\n" << "\n" << "Server Options:\n" << " --host Host address to bind to (default: 0.0.0.0)\n" << " --port Port number to listen on (default: 8080)\n" << " --max-concurrent Maximum concurrent generations (default: 1)\n" << " --queue-dir Queue persistence directory (default: ./queue)\n" << " --output-dir Output files directory (default: ./output)\n" << " --ui-dir Web UI static files directory (optional)\n" << " --verbose, -v Enable verbose logging\n" << " --enable-file-logging Enable logging to file\n" << " --log-file Log file path (default: /var/log/stable-diffusion-rest/server.log)\n" << "\n" << "Authentication Options:\n" << " --auth Authentication method (none, jwt, api-key, unix, pam, optional)\n" << " --auth-method Authentication method (alias for --auth)\n" << " --jwt-secret JWT secret key (auto-generated if not provided)\n" << " --jwt-expiration JWT token expiration time (default: 60)\n" << " --enable-guest-access Allow unauthenticated guest access\n" << " --pam-service-name PAM service name (default: stable-diffusion-rest)\n" << " --auth-data-dir Directory for authentication data (default: ./auth)\n" << " --public-paths Comma-separated list of public paths (default: /api/health,/api/status)\n" << "\n" << "Deprecated Options (will be removed in future version):\n" << " --enable-unix-auth Deprecated: Use --auth unix instead\n" << " --enable-pam-auth Deprecated: Use --auth pam instead\n" << "\n" << "Model Directory Options:\n" << " All model directories are optional and default to standard folder names\n" << " under --models-dir. Only specify these if your folder names differ.\n" << "\n" << " --checkpoints Checkpoints directory (default: checkpoints)\n" << " --controlnet-dir ControlNet models directory (default: controlnet)\n" << " --embeddings-dir Embeddings directory (default: embeddings)\n" << " --esrgan-dir ESRGAN models directory (default: ESRGAN)\n" << " --lora-dir LoRA models directory (default: loras)\n" << " --taesd-dir TAESD models directory (default: TAESD)\n" << " --vae-dir VAE models directory (default: vae)\n" << "\n" << "Other Options:\n" << " --help, -h Show this help message\n" << "\n" << "Path Resolution:\n" << " - Absolute paths are used as-is\n" << " - Relative paths are resolved relative to --models-dir\n" << " - Default folder names match standard SD model structure\n" << "\n" << "Examples:\n" << " # Use all defaults (requires standard folder structure)\n" << " " << argv[0] << " --models-dir /data/SD_MODELS\n" << "\n" << " # Override specific folders\n" << " " << argv[0] << " --models-dir /data/SD_MODELS --checkpoints my_checkpoints\n" << "\n" << " # Use absolute path for one folder\n" << " " << argv[0] << " --models-dir /data/SD_MODELS --lora-dir /other/path/loras\n" << std::endl; exit(0); } else { std::cerr << "Unknown argument: " << arg << std::endl; std::cerr << "Use --help for usage information" << std::endl; exit(1); } } // Validate required parameters if (!modelsDirSet) { std::cerr << "Error: --models-dir is required" << std::endl; std::cerr << "Use --help for usage information" << std::endl; exit(1); } // Set defaults for model directories if not explicitly set if (!checkpointsSet) { config.checkpoints = "checkpoints"; } if (!controlnetSet) { config.controlnetDir = "controlnet"; } if (!embeddingsSet) { config.embeddingsDir = "embeddings"; } if (!esrganSet) { config.esrganDir = "ESRGAN"; } if (!loraSet) { config.loraDir = "loras"; } if (!taesdSet) { config.taesdDir = "TAESD"; } if (!vaeSet) { config.vaeDir = "vae"; } // Resolve all directory paths (absolute paths used as-is, relative resolved from models-dir) config.checkpoints = resolveDirectoryPath(config.checkpoints, config.modelsDir); config.controlnetDir = resolveDirectoryPath(config.controlnetDir, config.modelsDir); config.embeddingsDir = resolveDirectoryPath(config.embeddingsDir, config.modelsDir); config.esrganDir = resolveDirectoryPath(config.esrganDir, config.modelsDir); config.loraDir = resolveDirectoryPath(config.loraDir, config.modelsDir); config.taesdDir = resolveDirectoryPath(config.taesdDir, config.modelsDir); config.vaeDir = resolveDirectoryPath(config.vaeDir, config.modelsDir); return config; } int main(int argc, char* argv[]) { // Parse command line arguments ServerConfig config = parseArguments(argc, argv); // Initialize logger LogLevel minLevel = config.verbose ? LogLevel::DEBUG : LogLevel::INFO; Logger::getInstance().initialize(config.enableFileLogging, config.logFilePath, minLevel); // Create log directory if file logging is enabled if (config.enableFileLogging) { try { std::filesystem::path logPath(config.logFilePath); std::filesystem::create_directories(logPath.parent_path()); } catch (const std::filesystem::filesystem_error& e) { LOG_ERROR("Failed to create log directory: " + std::string(e.what())); } } LOG_INFO("=== Stable Diffusion REST Server Starting ==="); if (config.enableFileLogging) { LOG_INFO("File logging enabled: " + config.logFilePath); } // Create queue and output directories if they don't exist try { std::filesystem::create_directories(config.queueDir); std::filesystem::create_directories(config.outputDir); } catch (const std::filesystem::filesystem_error& e) { LOG_WARNING("Failed to create directories: " + std::string(e.what())); } if (config.verbose) { std::cout << "\n=== Configuration ===" << std::endl; std::cout << "Server:" << std::endl; std::cout << " Host: " << config.host << std::endl; std::cout << " Port: " << config.port << std::endl; std::cout << " Max concurrent generations: " << config.maxConcurrentGenerations << std::endl; std::cout << " Queue directory: " << config.queueDir << std::endl; std::cout << " Output directory: " << config.outputDir << std::endl; std::cout << "\nModel Directories:" << std::endl; std::cout << " Base models directory: " << config.modelsDir << std::endl; std::cout << " Checkpoints: " << config.checkpoints << std::endl; std::cout << " ControlNet: " << config.controlnetDir << std::endl; std::cout << " Embeddings: " << config.embeddingsDir << std::endl; std::cout << " ESRGAN: " << config.esrganDir << std::endl; std::cout << " LoRA: " << config.loraDir << std::endl; std::cout << " TAESD: " << config.taesdDir << std::endl; std::cout << " VAE: " << config.vaeDir << std::endl; std::cout << std::endl; } // Validate directory paths auto validateDirectory = [](const std::string& path, const std::string& name, bool required) -> bool { if (path.empty()) { if (required) { std::cerr << "Error: " << name << " directory is required but not specified" << std::endl; return false; } return true; // Empty path is valid for optional directories } std::filesystem::path dirPath(path); if (!std::filesystem::exists(dirPath)) { if (required) { std::cerr << "Error: " << name << " directory does not exist: " << path << std::endl; return false; } else { std::cerr << "Warning: " << name << " directory does not exist: " << path << std::endl; return true; // Optional directory can be missing } } if (!std::filesystem::is_directory(dirPath)) { std::cerr << "Error: " << name << " path is not a directory: " << path << std::endl; return false; } return true; }; // Validate required directories bool allValid = true; // Validate base models directory (required - must exist) if (!validateDirectory(config.modelsDir, "Base models", true)) { allValid = false; } // Validate all model directories (will warn but not fail if missing) validateDirectory(config.checkpoints, "Checkpoints", false); validateDirectory(config.controlnetDir, "ControlNet", false); validateDirectory(config.embeddingsDir, "Embeddings", false); validateDirectory(config.esrganDir, "ESRGAN", false); validateDirectory(config.loraDir, "LoRA", false); validateDirectory(config.taesdDir, "TAESD", false); validateDirectory(config.vaeDir, "VAE", false); // Validate UI directory if specified if (!config.uiDir.empty()) { if (!validateDirectory(config.uiDir, "Web UI", false)) { std::cerr << "\nError: Web UI directory is invalid" << std::endl; return 1; } // Check if the UI directory is readable std::filesystem::path uiPath(config.uiDir); if (!std::filesystem::exists(uiPath / "index.html") && !std::filesystem::exists(uiPath / "index.htm")) { std::cerr << "Warning: Web UI directory does not contain an index.html or index.htm file: " << config.uiDir << std::endl; } } if (!allValid) { std::cerr << "\nError: Base models directory is invalid or missing" << std::endl; return 1; } // Set up signal handlers for graceful shutdown signal(SIGINT, signalHandler); // Ctrl+C signal(SIGTERM, signalHandler); // Termination signal try { // Initialize authentication system if (config.verbose) { std::cout << "Initializing authentication system..." << std::endl; } auto userManager = std::make_shared(config.auth.dataDir, static_cast(config.auth.authMethod)); if (!userManager->initialize()) { std::cerr << "Error: Failed to initialize user manager" << std::endl; return 1; } if (config.verbose) { std::cout << "User manager initialized" << std::endl; std::cout << "Authentication method: "; switch (config.auth.authMethod) { case AuthMethod::NONE: std::cout << "None"; break; case AuthMethod::JWT: std::cout << "JWT"; break; case AuthMethod::API_KEY: std::cout << "API Key"; break; case AuthMethod::UNIX: std::cout << "Unix"; break; case AuthMethod::OPTIONAL: std::cout << "Optional"; break; } std::cout << std::endl; } // Initialize authentication middleware auto authMiddleware = std::make_shared(config.auth, userManager); if (!authMiddleware->initialize()) { std::cerr << "Error: Failed to initialize authentication middleware" << std::endl; return 1; } // Initialize components auto modelManager = std::make_unique(); auto generationQueue = std::make_unique(modelManager.get(), config.maxConcurrentGenerations, config.queueDir, config.outputDir); auto server = std::make_unique(modelManager.get(), generationQueue.get(), config.outputDir, config.uiDir); // Set authentication components in server server->setAuthComponents(userManager, authMiddleware); // Set global server pointer for signal handler access g_server = server.get(); // Configure model manager with directory parameters if (config.verbose) { std::cout << "Configuring model manager..." << std::endl; } if (!modelManager->configureFromServerConfig(config)) { std::cerr << "Error: Failed to configure model manager with server config" << std::endl; return 1; } if (config.verbose) { std::cout << "Model manager configured with per-type directories" << std::endl; } // Scan models directory if (config.verbose) { std::cout << "Scanning models directory..." << std::endl; } if (!modelManager->scanModelsDirectory()) { std::cerr << "Warning: Failed to scan models directory" << std::endl; } else if (config.verbose) { std::cout << "Found " << modelManager->getAvailableModelsCount() << " models" << std::endl; } // Start the generation queue generationQueue->start(); if (config.verbose) { std::cout << "Generation queue started" << std::endl; } // Start the HTTP server if (!server->start(config.host, config.port)) { std::cerr << "Failed to start server" << std::endl; return 1; } std::cout << "Server initialized successfully" << std::endl; std::cout << "Server listening on " << config.host << ":" << config.port << std::endl; std::cout << "Press Ctrl+C to stop the server" << std::endl; // Give server a moment to start std::this_thread::sleep_for(std::chrono::milliseconds(100)); // Main server loop - wait for shutdown signal while (g_running.load() && server->isRunning()) { std::this_thread::sleep_for(std::chrono::milliseconds(100)); } // Graceful shutdown std::cout << "Shutting down server..." << std::endl; // Stop the server first to stop accepting new requests server->stop(); // Stop the generation queue generationQueue->stop(); // Wait for server thread to finish server->waitForStop(); // Clear global server pointer g_server = nullptr; std::cout << "Server shutdown complete" << std::endl; } catch (const std::exception& e) { std::cerr << "Error: " << e.what() << std::endl; return 1; } return 0; }