main.cpp 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522
  1. #include <signal.h>
  2. #include <algorithm>
  3. #include <atomic>
  4. #include <chrono>
  5. #include <filesystem>
  6. #include <iostream>
  7. #include <memory>
  8. #include <string>
  9. #include <thread>
  10. #include "auth_middleware.h"
  11. #include "generation_queue.h"
  12. #include "logger.h"
  13. #include "model_manager.h"
  14. #include "server.h"
  15. #include "server_config.h"
  16. #include "user_manager.h"
  17. // Global flag for signal handling
  18. std::atomic<bool> g_running(true);
  19. // Global pointer to server instance for signal handler access
  20. Server* g_server = nullptr;
  21. // Signal handler for graceful shutdown
  22. void signalHandler(int signal) {
  23. LOG_INFO("Received signal " + std::to_string(signal) + ", shutting down gracefully...");
  24. g_running.store(false);
  25. // Stop the server directly from signal handler
  26. if (g_server != nullptr) {
  27. g_server->stop();
  28. }
  29. // Give a brief moment for cleanup, then force exit
  30. std::this_thread::sleep_for(std::chrono::milliseconds(100));
  31. LOG_INFO("Exiting process...");
  32. Logger::getInstance().close();
  33. exit(0);
  34. }
  /**
   * @brief Resolve a configured directory against the base models directory.
   *
   * @param path      Directory as given by the user (may be empty).
   * @param modelsDir Base models directory (may be empty).
   * @return An empty string when @p path is empty; @p path unchanged when it is
   *         absolute or when no base directory is available; otherwise
   *         @p modelsDir joined with @p path.
   */
  std::string resolveDirectoryPath(const std::string& path, const std::string& modelsDir) {
      namespace fs = std::filesystem;

      if (path.empty()) {
          return std::string();
      }

      const fs::path requested(path);
      // Only a relative path with a known base directory needs rewriting.
      const bool needsBase = !requested.is_absolute() && !modelsDir.empty();
      return needsBase ? (fs::path(modelsDir) / requested).string() : path;
  }
  53. // Parse command line arguments
  54. ServerConfig parseArguments(int argc, char* argv[]) {
  55. ServerConfig config;
  56. // Track which parameters were explicitly set
  57. bool modelsDirSet = false;
  58. bool checkpointsSet = false;
  59. bool controlnetSet = false;
  60. bool embeddingsSet = false;
  61. bool esrganSet = false;
  62. bool loraSet = false;
  63. bool taesdSet = false;
  64. bool vaeSet = false;
  65. for (int i = 1; i < argc; i++) {
  66. std::string arg = argv[i];
  67. if (arg == "--host" && i + 1 < argc) {
  68. config.host = argv[++i];
  69. } else if (arg == "--port" && i + 1 < argc) {
  70. config.port = std::stoi(argv[++i]);
  71. } else if (arg == "--models-dir" && i + 1 < argc) {
  72. config.modelsDir = argv[++i];
  73. modelsDirSet = true;
  74. } else if (arg == "--checkpoints" && i + 1 < argc) {
  75. config.checkpoints = argv[++i];
  76. checkpointsSet = true;
  77. } else if (arg == "--controlnet-dir" && i + 1 < argc) {
  78. config.controlnetDir = argv[++i];
  79. controlnetSet = true;
  80. } else if (arg == "--embeddings-dir" && i + 1 < argc) {
  81. config.embeddingsDir = argv[++i];
  82. embeddingsSet = true;
  83. } else if (arg == "--esrgan-dir" && i + 1 < argc) {
  84. config.esrganDir = argv[++i];
  85. esrganSet = true;
  86. } else if (arg == "--lora-dir" && i + 1 < argc) {
  87. config.loraDir = argv[++i];
  88. loraSet = true;
  89. } else if (arg == "--taesd-dir" && i + 1 < argc) {
  90. config.taesdDir = argv[++i];
  91. taesdSet = true;
  92. } else if (arg == "--vae-dir" && i + 1 < argc) {
  93. config.vaeDir = argv[++i];
  94. vaeSet = true;
  95. } else if (arg == "--max-concurrent" && i + 1 < argc) {
  96. config.maxConcurrentGenerations = std::stoi(argv[++i]);
  97. } else if (arg == "--queue-dir" && i + 1 < argc) {
  98. config.queueDir = argv[++i];
  99. } else if (arg == "--output-dir" && i + 1 < argc) {
  100. config.outputDir = argv[++i];
  101. } else if (arg == "--ui-dir" && i + 1 < argc) {
  102. config.uiDir = argv[++i];
  103. } else if (arg == "--verbose" || arg == "-v") {
  104. config.verbose = true;
  105. } else if (arg == "--log-file" && i + 1 < argc) {
  106. config.enableFileLogging = true;
  107. config.logFilePath = argv[++i];
  108. } else if (arg == "--enable-file-logging") {
  109. config.enableFileLogging = true;
  110. } else if ((arg == "--auth-method" || arg == "--auth") && i + 1 < argc) {
  111. std::string method = argv[++i];
  112. if (method == "none") {
  113. config.auth.authMethod = AuthMethod::NONE;
  114. } else if (method == "jwt") {
  115. config.auth.authMethod = AuthMethod::JWT;
  116. } else if (method == "api-key") {
  117. config.auth.authMethod = AuthMethod::API_KEY;
  118. } else if (method == "unix") {
  119. config.auth.authMethod = AuthMethod::UNIX;
  120. } else if (method == "pam") {
  121. config.auth.authMethod = AuthMethod::PAM;
  122. } else if (method == "optional") {
  123. config.auth.authMethod = AuthMethod::OPTIONAL;
  124. } else {
  125. std::cerr << "Invalid auth method: " << method << std::endl;
  126. exit(1);
  127. }
  128. } else if (arg == "--jwt-secret" && i + 1 < argc) {
  129. config.auth.jwtSecret = argv[++i];
  130. } else if (arg == "--jwt-expiration" && i + 1 < argc) {
  131. config.auth.jwtExpirationMinutes = std::stoi(argv[++i]);
  132. } else if (arg == "--enable-guest-access") {
  133. config.auth.enableGuestAccess = true;
  134. } else if (arg == "--enable-unix-auth") {
  135. // Deprecated flag - show warning and set auth method to UNIX
  136. std::cerr << "Warning: --enable-unix-auth is deprecated. Use --auth unix instead." << std::endl;
  137. config.auth.authMethod = AuthMethod::UNIX;
  138. } else if (arg == "--enable-pam-auth") {
  139. // Deprecated flag - show warning and set auth method to PAM
  140. std::cerr << "Warning: --enable-pam-auth is deprecated. Use --auth pam instead." << std::endl;
  141. config.auth.authMethod = AuthMethod::PAM;
  142. } else if (arg == "--pam-service-name" && i + 1 < argc) {
  143. config.auth.pamServiceName = argv[++i];
  144. } else if (arg == "--auth-data-dir" && i + 1 < argc) {
  145. config.auth.dataDir = argv[++i];
  146. } else if (arg == "--public-paths" && i + 1 < argc) {
  147. config.auth.customPublicPaths = argv[++i];
  148. } else if (arg == "--help" || arg == "-h") {
  149. std::cout << "stable-diffusion.cpp-rest server\n"
  150. << "Usage: " << argv[0] << " [options]\n\n"
  151. << "Required Options:\n"
  152. << " --models-dir <dir> Base models directory path (required)\n"
  153. << "\n"
  154. << "Server Options:\n"
  155. << " --host <host> Host address to bind to (default: 0.0.0.0)\n"
  156. << " --port <port> Port number to listen on (default: 8080)\n"
  157. << " --max-concurrent <num> Maximum concurrent generations (default: 1)\n"
  158. << " --queue-dir <dir> Queue persistence directory (default: ./queue)\n"
  159. << " --output-dir <dir> Output files directory (default: ./output)\n"
  160. << " --ui-dir <dir> Web UI static files directory (optional)\n"
  161. << " --verbose, -v Enable verbose logging\n"
  162. << " --enable-file-logging Enable logging to file\n"
  163. << " --log-file <path> Log file path (default: /var/log/stable-diffusion-rest/server.log)\n"
  164. << "\n"
  165. << "Authentication Options:\n"
  166. << " --auth <method> Authentication method (none, jwt, api-key, unix, pam, optional)\n"
  167. << " --auth-method <method> Authentication method (alias for --auth)\n"
  168. << " --jwt-secret <secret> JWT secret key (auto-generated if not provided)\n"
  169. << " --jwt-expiration <minutes> JWT token expiration time (default: 60)\n"
  170. << " --enable-guest-access Allow unauthenticated guest access\n"
  171. << " --pam-service-name <name> PAM service name (default: stable-diffusion-rest)\n"
  172. << " --auth-data-dir <dir> Directory for authentication data (default: ./auth)\n"
  173. << " --public-paths <paths> Comma-separated list of public paths (default: /api/health,/api/status)\n"
  174. << "\n"
  175. << "Deprecated Options (will be removed in future version):\n"
  176. << " --enable-unix-auth Deprecated: Use --auth unix instead\n"
  177. << " --enable-pam-auth Deprecated: Use --auth pam instead\n"
  178. << "\n"
  179. << "Model Directory Options:\n"
  180. << " All model directories are optional and default to standard folder names\n"
  181. << " under --models-dir. Only specify these if your folder names differ.\n"
  182. << "\n"
  183. << " --checkpoints <dir> Checkpoints directory (default: checkpoints)\n"
  184. << " --controlnet-dir <dir> ControlNet models directory (default: controlnet)\n"
  185. << " --embeddings-dir <dir> Embeddings directory (default: embeddings)\n"
  186. << " --esrgan-dir <dir> ESRGAN models directory (default: ESRGAN)\n"
  187. << " --lora-dir <dir> LoRA models directory (default: loras)\n"
  188. << " --taesd-dir <dir> TAESD models directory (default: TAESD)\n"
  189. << " --vae-dir <dir> VAE models directory (default: vae)\n"
  190. << "\n"
  191. << "Other Options:\n"
  192. << " --help, -h Show this help message\n"
  193. << "\n"
  194. << "Path Resolution:\n"
  195. << " - Absolute paths are used as-is\n"
  196. << " - Relative paths are resolved relative to --models-dir\n"
  197. << " - Default folder names match standard SD model structure\n"
  198. << "\n"
  199. << "Examples:\n"
  200. << " # Use all defaults (requires standard folder structure)\n"
  201. << " " << argv[0] << " --models-dir /data/SD_MODELS\n"
  202. << "\n"
  203. << " # Override specific folders\n"
  204. << " " << argv[0] << " --models-dir /data/SD_MODELS --checkpoints my_checkpoints\n"
  205. << "\n"
  206. << " # Use absolute path for one folder\n"
  207. << " " << argv[0] << " --models-dir /data/SD_MODELS --lora-dir /other/path/loras\n"
  208. << std::endl;
  209. exit(0);
  210. } else {
  211. std::cerr << "Unknown argument: " << arg << std::endl;
  212. std::cerr << "Use --help for usage information" << std::endl;
  213. exit(1);
  214. }
  215. }
  216. // Validate required parameters
  217. if (!modelsDirSet) {
  218. std::cerr << "Error: --models-dir is required" << std::endl;
  219. std::cerr << "Use --help for usage information" << std::endl;
  220. exit(1);
  221. }
  222. // Set defaults for model directories if not explicitly set
  223. if (!checkpointsSet) {
  224. config.checkpoints = "checkpoints";
  225. }
  226. if (!controlnetSet) {
  227. config.controlnetDir = "controlnet";
  228. }
  229. if (!embeddingsSet) {
  230. config.embeddingsDir = "embeddings";
  231. }
  232. if (!esrganSet) {
  233. config.esrganDir = "ESRGAN";
  234. }
  235. if (!loraSet) {
  236. config.loraDir = "loras";
  237. }
  238. if (!taesdSet) {
  239. config.taesdDir = "TAESD";
  240. }
  241. if (!vaeSet) {
  242. config.vaeDir = "vae";
  243. }
  244. // Resolve all directory paths (absolute paths used as-is, relative resolved from models-dir)
  245. config.checkpoints = resolveDirectoryPath(config.checkpoints, config.modelsDir);
  246. config.controlnetDir = resolveDirectoryPath(config.controlnetDir, config.modelsDir);
  247. config.embeddingsDir = resolveDirectoryPath(config.embeddingsDir, config.modelsDir);
  248. config.esrganDir = resolveDirectoryPath(config.esrganDir, config.modelsDir);
  249. config.loraDir = resolveDirectoryPath(config.loraDir, config.modelsDir);
  250. config.taesdDir = resolveDirectoryPath(config.taesdDir, config.modelsDir);
  251. config.vaeDir = resolveDirectoryPath(config.vaeDir, config.modelsDir);
  252. return config;
  253. }
  254. int main(int argc, char* argv[]) {
  255. // Parse command line arguments
  256. ServerConfig config = parseArguments(argc, argv);
  257. // Initialize logger
  258. LogLevel minLevel = config.verbose ? LogLevel::DEBUG : LogLevel::INFO;
  259. Logger::getInstance().initialize(config.enableFileLogging, config.logFilePath, minLevel);
  260. // Create log directory if file logging is enabled
  261. if (config.enableFileLogging) {
  262. try {
  263. std::filesystem::path logPath(config.logFilePath);
  264. std::filesystem::create_directories(logPath.parent_path());
  265. } catch (const std::filesystem::filesystem_error& e) {
  266. LOG_ERROR("Failed to create log directory: " + std::string(e.what()));
  267. }
  268. }
  269. LOG_INFO("=== Stable Diffusion REST Server Starting ===");
  270. if (config.enableFileLogging) {
  271. LOG_INFO("File logging enabled: " + config.logFilePath);
  272. }
  273. // Create queue and output directories if they don't exist
  274. try {
  275. std::filesystem::create_directories(config.queueDir);
  276. std::filesystem::create_directories(config.outputDir);
  277. } catch (const std::filesystem::filesystem_error& e) {
  278. LOG_WARNING("Failed to create directories: " + std::string(e.what()));
  279. }
  280. if (config.verbose) {
  281. std::cout << "\n=== Configuration ===" << std::endl;
  282. std::cout << "Server:" << std::endl;
  283. std::cout << " Host: " << config.host << std::endl;
  284. std::cout << " Port: " << config.port << std::endl;
  285. std::cout << " Max concurrent generations: " << config.maxConcurrentGenerations << std::endl;
  286. std::cout << " Queue directory: " << config.queueDir << std::endl;
  287. std::cout << " Output directory: " << config.outputDir << std::endl;
  288. std::cout << "\nModel Directories:" << std::endl;
  289. std::cout << " Base models directory: " << config.modelsDir << std::endl;
  290. std::cout << " Checkpoints: " << config.checkpoints << std::endl;
  291. std::cout << " ControlNet: " << config.controlnetDir << std::endl;
  292. std::cout << " Embeddings: " << config.embeddingsDir << std::endl;
  293. std::cout << " ESRGAN: " << config.esrganDir << std::endl;
  294. std::cout << " LoRA: " << config.loraDir << std::endl;
  295. std::cout << " TAESD: " << config.taesdDir << std::endl;
  296. std::cout << " VAE: " << config.vaeDir << std::endl;
  297. std::cout << std::endl;
  298. }
  299. // Validate directory paths
  300. auto validateDirectory = [](const std::string& path, const std::string& name, bool required) -> bool {
  301. if (path.empty()) {
  302. if (required) {
  303. std::cerr << "Error: " << name << " directory is required but not specified" << std::endl;
  304. return false;
  305. }
  306. return true; // Empty path is valid for optional directories
  307. }
  308. std::filesystem::path dirPath(path);
  309. if (!std::filesystem::exists(dirPath)) {
  310. if (required) {
  311. std::cerr << "Error: " << name << " directory does not exist: " << path << std::endl;
  312. return false;
  313. } else {
  314. std::cerr << "Warning: " << name << " directory does not exist: " << path << std::endl;
  315. return true; // Optional directory can be missing
  316. }
  317. }
  318. if (!std::filesystem::is_directory(dirPath)) {
  319. std::cerr << "Error: " << name << " path is not a directory: " << path << std::endl;
  320. return false;
  321. }
  322. return true;
  323. };
  324. // Validate required directories
  325. bool allValid = true;
  326. // Validate base models directory (required - must exist)
  327. if (!validateDirectory(config.modelsDir, "Base models", true)) {
  328. allValid = false;
  329. }
  330. // Validate all model directories (will warn but not fail if missing)
  331. validateDirectory(config.checkpoints, "Checkpoints", false);
  332. validateDirectory(config.controlnetDir, "ControlNet", false);
  333. validateDirectory(config.embeddingsDir, "Embeddings", false);
  334. validateDirectory(config.esrganDir, "ESRGAN", false);
  335. validateDirectory(config.loraDir, "LoRA", false);
  336. validateDirectory(config.taesdDir, "TAESD", false);
  337. validateDirectory(config.vaeDir, "VAE", false);
  338. // Validate UI directory if specified
  339. if (!config.uiDir.empty()) {
  340. if (!validateDirectory(config.uiDir, "Web UI", false)) {
  341. std::cerr << "\nError: Web UI directory is invalid" << std::endl;
  342. return 1;
  343. }
  344. // Check if the UI directory is readable
  345. std::filesystem::path uiPath(config.uiDir);
  346. if (!std::filesystem::exists(uiPath / "index.html") &&
  347. !std::filesystem::exists(uiPath / "index.htm")) {
  348. std::cerr << "Warning: Web UI directory does not contain an index.html or index.htm file: " << config.uiDir << std::endl;
  349. }
  350. }
  351. if (!allValid) {
  352. std::cerr << "\nError: Base models directory is invalid or missing" << std::endl;
  353. return 1;
  354. }
  355. // Set up signal handlers for graceful shutdown
  356. signal(SIGINT, signalHandler); // Ctrl+C
  357. signal(SIGTERM, signalHandler); // Termination signal
  358. try {
  359. // Initialize authentication system
  360. if (config.verbose) {
  361. std::cout << "Initializing authentication system..." << std::endl;
  362. }
  363. auto userManager = std::make_shared<UserManager>(config.auth.dataDir,
  364. static_cast<UserManager::AuthMethod>(config.auth.authMethod));
  365. if (!userManager->initialize()) {
  366. std::cerr << "Error: Failed to initialize user manager" << std::endl;
  367. return 1;
  368. }
  369. if (config.verbose) {
  370. std::cout << "User manager initialized" << std::endl;
  371. std::cout << "Authentication method: ";
  372. switch (config.auth.authMethod) {
  373. case AuthMethod::NONE:
  374. std::cout << "None";
  375. break;
  376. case AuthMethod::JWT:
  377. std::cout << "JWT";
  378. break;
  379. case AuthMethod::API_KEY:
  380. std::cout << "API Key";
  381. break;
  382. case AuthMethod::UNIX:
  383. std::cout << "Unix";
  384. break;
  385. case AuthMethod::OPTIONAL:
  386. std::cout << "Optional";
  387. break;
  388. case AuthMethod::PAM:
  389. std::cout << "PAM";
  390. break;
  391. }
  392. std::cout << std::endl;
  393. }
  394. // Initialize authentication middleware
  395. auto authMiddleware = std::make_shared<AuthMiddleware>(config.auth, userManager);
  396. if (!authMiddleware->initialize()) {
  397. std::cerr << "Error: Failed to initialize authentication middleware" << std::endl;
  398. return 1;
  399. }
  400. // Initialize components
  401. auto modelManager = std::make_unique<ModelManager>();
  402. auto generationQueue = std::make_unique<GenerationQueue>(modelManager.get(), config.maxConcurrentGenerations,
  403. config.queueDir, config.outputDir);
  404. auto server = std::make_unique<Server>(modelManager.get(), generationQueue.get(), config.outputDir, config.uiDir);
  405. // Set authentication components in server
  406. server->setAuthComponents(userManager, authMiddleware);
  407. // Set global server pointer for signal handler access
  408. g_server = server.get();
  409. // Configure model manager with directory parameters
  410. if (config.verbose) {
  411. std::cout << "Configuring model manager..." << std::endl;
  412. }
  413. if (!modelManager->configureFromServerConfig(config)) {
  414. std::cerr << "Error: Failed to configure model manager with server config" << std::endl;
  415. return 1;
  416. }
  417. if (config.verbose) {
  418. std::cout << "Model manager configured with per-type directories" << std::endl;
  419. }
  420. // Scan models directory
  421. if (config.verbose) {
  422. std::cout << "Scanning models directory..." << std::endl;
  423. }
  424. if (!modelManager->scanModelsDirectory()) {
  425. std::cerr << "Warning: Failed to scan models directory" << std::endl;
  426. } else if (config.verbose) {
  427. std::cout << "Found " << modelManager->getAvailableModelsCount() << " models" << std::endl;
  428. }
  429. // Start the generation queue
  430. generationQueue->start();
  431. if (config.verbose) {
  432. std::cout << "Generation queue started" << std::endl;
  433. }
  434. // Start the HTTP server
  435. if (!server->start(config.host, config.port)) {
  436. std::cerr << "Failed to start server" << std::endl;
  437. return 1;
  438. }
  439. std::cout << "Server initialized successfully" << std::endl;
  440. std::cout << "Server listening on " << config.host << ":" << config.port << std::endl;
  441. std::cout << "Press Ctrl+C to stop the server" << std::endl;
  442. // Give server a moment to start
  443. std::this_thread::sleep_for(std::chrono::milliseconds(100));
  444. // Main server loop - wait for shutdown signal
  445. while (g_running.load() && server->isRunning()) {
  446. std::this_thread::sleep_for(std::chrono::milliseconds(100));
  447. }
  448. // Graceful shutdown
  449. std::cout << "Shutting down server..." << std::endl;
  450. // Stop the server first to stop accepting new requests
  451. server->stop();
  452. // Stop the generation queue
  453. generationQueue->stop();
  454. // Wait for server thread to finish
  455. server->waitForStop();
  456. // Clear global server pointer
  457. g_server = nullptr;
  458. std::cout << "Server shutdown complete" << std::endl;
  459. } catch (const std::exception& e) {
  460. std::cerr << "Error: " << e.what() << std::endl;
  461. return 1;
  462. }
  463. return 0;
  464. }