main.cpp 27 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589
#include <atomic>
#include <chrono>
#include <csignal>
#include <cstdlib>
#include <filesystem>
#include <iostream>
#include <limits>
#include <memory>
#include <string>
#include <thread>

#include "auth_middleware.h"
#include "generation_queue.h"
#include "logger.h"
#include "model_manager.h"
#include "server.h"
#include "server_config.h"
#include "user_manager.h"
#include "version.h"
  17. // Global flag for signal handling
  18. std::atomic<bool> g_running(true);
  19. // Global pointer to server instance for signal handler access
  20. Server* g_server = nullptr;
  21. // Global shutdown delay for signal handler
  22. int g_shutdownDelayMs = 100;
  23. // Signal handler for graceful shutdown
  24. void signalHandler(int signal) {
  25. LOG_INFO("Received signal " + std::to_string(signal) + ", shutting down gracefully...");
  26. g_running.store(false);
  27. // Stop the server directly from signal handler
  28. if (g_server != nullptr) {
  29. g_server->stop();
  30. }
  31. // Give a brief moment for cleanup, then force exit
  32. std::this_thread::sleep_for(std::chrono::milliseconds(g_shutdownDelayMs));
  33. LOG_INFO("Exiting process...");
  34. Logger::getInstance().close();
  35. exit(0);
  36. }
  37. // Helper function to resolve directory path
  38. std::string resolveDirectoryPath(const std::string& path, const std::string& modelsDir) {
  39. if (path.empty()) {
  40. return "";
  41. }
  42. std::filesystem::path dirPath(path);
  43. // If the path is absolute, use it as-is
  44. if (dirPath.is_absolute()) {
  45. return path;
  46. }
  47. // If the path is relative and models-dir is specified, prepend models-dir
  48. if (!modelsDir.empty()) {
  49. std::filesystem::path baseDir(modelsDir);
  50. return (baseDir / dirPath).string();
  51. }
  52. // If no models-dir, return the relative path as-is
  53. return path;
  54. }
  55. // Parse command line arguments
  56. ServerConfig parseArguments(int argc, char* argv[]) {
  57. ServerConfig config;
  58. // Track which parameters were explicitly set
  59. bool modelsDirSet = false;
  60. bool checkpointsSet = false;
  61. bool controlnetSet = false;
  62. bool embeddingsSet = false;
  63. bool esrganSet = false;
  64. bool loraSet = false;
  65. bool taesdSet = false;
  66. bool vaeSet = false;
  67. bool diffusionModelsSet = false;
  68. for (int i = 1; i < argc; i++) {
  69. std::string arg = argv[i];
  70. if (arg == "--host" && i + 1 < argc) {
  71. config.host = argv[++i];
  72. } else if (arg == "--port" && i + 1 < argc) {
  73. config.port = std::stoi(argv[++i]);
  74. } else if (arg == "--models-dir" && i + 1 < argc) {
  75. config.modelsDir = argv[++i];
  76. modelsDirSet = true;
  77. } else if (arg == "--checkpoints" && i + 1 < argc) {
  78. config.checkpoints = argv[++i];
  79. checkpointsSet = true;
  80. } else if (arg == "--controlnet-dir" && i + 1 < argc) {
  81. config.controlnetDir = argv[++i];
  82. controlnetSet = true;
  83. } else if (arg == "--embeddings-dir" && i + 1 < argc) {
  84. config.embeddingsDir = argv[++i];
  85. embeddingsSet = true;
  86. } else if (arg == "--esrgan-dir" && i + 1 < argc) {
  87. config.esrganDir = argv[++i];
  88. esrganSet = true;
  89. } else if (arg == "--lora-dir" && i + 1 < argc) {
  90. config.loraDir = argv[++i];
  91. loraSet = true;
  92. } else if (arg == "--taesd-dir" && i + 1 < argc) {
  93. config.taesdDir = argv[++i];
  94. taesdSet = true;
  95. } else if (arg == "--vae-dir" && i + 1 < argc) {
  96. config.vaeDir = argv[++i];
  97. vaeSet = true;
  98. } else if (arg == "--diffusion-models-dir" && i + 1 < argc) {
  99. config.diffusionModelsDir = argv[++i];
  100. diffusionModelsSet = true;
  101. } else if (arg == "--max-concurrent" && i + 1 < argc) {
  102. config.maxConcurrentGenerations = std::stoi(argv[++i]);
  103. } else if (arg == "--queue-dir" && i + 1 < argc) {
  104. config.queueDir = argv[++i];
  105. } else if (arg == "--output-dir" && i + 1 < argc) {
  106. config.outputDir = argv[++i];
  107. } else if (arg == "--ui-dir" && i + 1 < argc) {
  108. config.uiDir = argv[++i];
  109. } else if (arg == "--verbose" || arg == "-v") {
  110. config.verbose = true;
  111. } else if (arg == "--log-file" && i + 1 < argc) {
  112. config.enableFileLogging = true;
  113. config.logFilePath = argv[++i];
  114. } else if (arg == "--enable-file-logging") {
  115. config.enableFileLogging = true;
  116. } else if ((arg == "--auth-method" || arg == "--auth") && i + 1 < argc) {
  117. std::string method = argv[++i];
  118. if (method == "none") {
  119. config.auth.authMethod = AuthMethod::NONE;
  120. } else if (method == "jwt") {
  121. config.auth.authMethod = AuthMethod::JWT;
  122. } else if (method == "api-key") {
  123. config.auth.authMethod = AuthMethod::API_KEY;
  124. } else if (method == "unix") {
  125. config.auth.authMethod = AuthMethod::UNIX;
  126. } else if (method == "pam") {
  127. config.auth.authMethod = AuthMethod::PAM;
  128. } else if (method == "optional") {
  129. config.auth.authMethod = AuthMethod::OPTIONAL;
  130. } else {
  131. std::cerr << "Invalid auth method: " << method << std::endl;
  132. exit(1);
  133. }
  134. } else if (arg == "--jwt-secret" && i + 1 < argc) {
  135. config.auth.jwtSecret = argv[++i];
  136. } else if (arg == "--jwt-expiration" && i + 1 < argc) {
  137. config.auth.jwtExpirationMinutes = std::stoi(argv[++i]);
  138. } else if (arg == "--enable-guest-access") {
  139. config.auth.enableGuestAccess = true;
  140. } else if (arg == "--enable-unix-auth") {
  141. // Deprecated flag - show warning and set auth method to UNIX
  142. std::cerr << "Warning: --enable-unix-auth is deprecated. Use --auth unix instead." << std::endl;
  143. config.auth.authMethod = AuthMethod::UNIX;
  144. } else if (arg == "--enable-pam-auth") {
  145. // Deprecated flag - show warning and set auth method to PAM
  146. std::cerr << "Warning: --enable-pam-auth is deprecated. Use --auth pam instead." << std::endl;
  147. config.auth.authMethod = AuthMethod::PAM;
  148. } else if (arg == "--pam-service-name" && i + 1 < argc) {
  149. config.auth.pamServiceName = argv[++i];
  150. } else if (arg == "--auth-data-dir" && i + 1 < argc) {
  151. config.auth.dataDir = argv[++i];
  152. } else if (arg == "--public-paths" && i + 1 < argc) {
  153. config.auth.customPublicPaths = argv[++i];
  154. } else if (arg == "--connection-timeout" && i + 1 < argc) {
  155. config.connectionTimeoutMs = std::stoi(argv[++i]);
  156. } else if (arg == "--read-timeout" && i + 1 < argc) {
  157. config.readTimeoutMs = std::stoi(argv[++i]);
  158. } else if (arg == "--write-timeout" && i + 1 < argc) {
  159. config.writeTimeoutMs = std::stoi(argv[++i]);
  160. } else if (arg == "--max-prompt-length" && i + 1 < argc) {
  161. config.maxPromptLength = std::stoi(argv[++i]);
  162. } else if (arg == "--max-negative-prompt-length" && i + 1 < argc) {
  163. config.maxNegativePromptLength = std::stoi(argv[++i]);
  164. } else if (arg == "--shutdown-delay" && i + 1 < argc) {
  165. config.shutdownDelayMs = std::stoi(argv[++i]);
  166. } else if (arg == "--default-admin-username" && i + 1 < argc) {
  167. config.defaultAdminUsername = argv[++i];
  168. } else if (arg == "--default-admin-password" && i + 1 < argc) {
  169. config.defaultAdminPassword = argv[++i];
  170. } else if (arg == "--default-admin-email" && i + 1 < argc) {
  171. config.defaultAdminEmail = argv[++i];
  172. } else if (arg == "--jwt-audience" && i + 1 < argc) {
  173. config.auth.jwtAudience = argv[++i];
  174. } else if (arg == "--help" || arg == "-h") {
  175. std::cout << "stable-diffusion.cpp-rest server\n"
  176. << "Usage: " << argv[0] << " [options]\n\n"
  177. << "Required Options:\n"
  178. << " --models-dir <dir> Base models directory path (required)\n"
  179. << "\n"
  180. << "Server Options:\n"
  181. << " --host <host> Host address to bind to (default: 0.0.0.0)\n"
  182. << " --port <port> Port number to listen on (default: 8080)\n"
  183. << " --max-concurrent <num> Maximum concurrent generations (default: 1)\n"
  184. << " --queue-dir <dir> Queue persistence directory (default: ./queue)\n"
  185. << " --output-dir <dir> Output files directory (default: ./output)\n"
  186. << " --ui-dir <dir> Web UI static files directory (optional)\n"
  187. << " --verbose, -v Enable verbose logging\n"
  188. << " --enable-file-logging Enable logging to file\n"
  189. << " --log-file <path> Log file path (default: /var/log/stable-diffusion-rest/server.log)\n"
  190. << "\n"
  191. << "Network & Connection Options:\n"
  192. << " --connection-timeout <ms> Connection timeout in milliseconds (default: 500)\n"
  193. << " --read-timeout <ms> Read timeout in milliseconds (default: 500)\n"
  194. << " --write-timeout <ms> Write timeout in milliseconds (default: 500)\n"
  195. << "\n"
  196. << "Authentication Options:\n"
  197. << " --auth <method> Authentication method (none, jwt, api-key, unix, pam, optional)\n"
  198. << " --auth-method <method> Authentication method (alias for --auth)\n"
  199. << " --jwt-secret <secret> JWT secret key (auto-generated if not provided)\n"
  200. << " --jwt-expiration <minutes> JWT token expiration time (default: 60)\n"
  201. << " --enable-guest-access Allow unauthenticated guest access\n"
  202. << " --pam-service-name <name> PAM service name (default: stable-diffusion-rest)\n"
  203. << " --auth-data-dir <dir> Directory for authentication data (default: ./auth)\n"
  204. << " --public-paths <paths> Comma-separated list of public paths (default: /api/health,/api/status)\n"
  205. << " --jwt-audience <audience> JWT audience claim (default: stable-diffusion-rest)\n"
  206. << "\n"
  207. << "Deprecated Options (will be removed in future version):\n"
  208. << " --enable-unix-auth Deprecated: Use --auth unix instead\n"
  209. << " --enable-pam-auth Deprecated: Use --auth pam instead\n"
  210. << "\n"
  211. << "Model Directory Options:\n"
  212. << " All model directories are optional and default to standard folder names\n"
  213. << " under --models-dir. Only specify these if your folder names differ.\n"
  214. << "\n"
  215. << " --checkpoints <dir> Checkpoints directory (default: checkpoints)\n"
  216. << " --controlnet-dir <dir> ControlNet models directory (default: controlnet)\n"
  217. << " --embeddings-dir <dir> Embeddings directory (default: embeddings)\n"
  218. << " --esrgan-dir <dir> ESRGAN models directory (default: ESRGAN)\n"
  219. << " --lora-dir <dir> LoRA models directory (default: loras)\n"
  220. << " --taesd-dir <dir> TAESD models directory (default: TAESD)\n"
  221. << " --vae-dir <dir> VAE models directory (default: vae)\n"
  222. << " --diffusion-models-dir <dir> Diffusion models directory (default: diffusion_models)\n"
  223. << "\n"
  224. << "Generation Limits & Security:\n"
  225. << " --max-prompt-length <len> Maximum prompt character length (default: 10000)\n"
  226. << " --max-negative-prompt-length <len> Maximum negative prompt character length (default: 10000)\n"
  227. << " --shutdown-delay <ms> Graceful shutdown delay in milliseconds (default: 100)\n"
  228. << "\n"
  229. << "Default Admin Credentials:\n"
  230. << " --default-admin-username <name> Default admin username (default: admin)\n"
  231. << " --default-admin-password <pass> Default admin password (default: admin123)\n"
  232. << " --default-admin-email <email> Default admin email (default: admin@localhost)\n"
  233. << "\n"
  234. << "Other Options:\n"
  235. << " --help, -h Show this help message\n"
  236. << "\n"
  237. << "Path Resolution:\n"
  238. << " - Absolute paths are used as-is\n"
  239. << " - Relative paths are resolved relative to --models-dir\n"
  240. << " - Default folder names match standard SD model structure\n"
  241. << "\n"
  242. << "Examples:\n"
  243. << " # Use all defaults (requires standard folder structure)\n"
  244. << " " << argv[0] << " --models-dir /data/SD_MODELS\n"
  245. << "\n"
  246. << " # Override specific folders\n"
  247. << " " << argv[0] << " --models-dir /data/SD_MODELS --checkpoints my_checkpoints\n"
  248. << "\n"
  249. << " # Use absolute path for one folder\n"
  250. << " " << argv[0] << " --models-dir /data/SD_MODELS --lora-dir /other/path/loras\n"
  251. << std::endl;
  252. exit(0);
  253. } else {
  254. std::cerr << "Unknown argument: " << arg << std::endl;
  255. std::cerr << "Use --help for usage information" << std::endl;
  256. exit(1);
  257. }
  258. }
  259. // Validate required parameters
  260. if (!modelsDirSet) {
  261. std::cerr << "Error: --models-dir is required" << std::endl;
  262. std::cerr << "Use --help for usage information" << std::endl;
  263. exit(1);
  264. }
  265. // Set defaults for model directories if not explicitly set
  266. if (!checkpointsSet) {
  267. config.checkpoints = "checkpoints";
  268. }
  269. if (!controlnetSet) {
  270. config.controlnetDir = "controlnet";
  271. }
  272. if (!embeddingsSet) {
  273. config.embeddingsDir = "embeddings";
  274. }
  275. if (!esrganSet) {
  276. config.esrganDir = "ESRGAN";
  277. }
  278. if (!loraSet) {
  279. config.loraDir = "loras";
  280. }
  281. if (!taesdSet) {
  282. config.taesdDir = "TAESD";
  283. }
  284. if (!vaeSet) {
  285. config.vaeDir = "vae";
  286. }
  287. if (!diffusionModelsSet) {
  288. config.diffusionModelsDir = "diffusion_models";
  289. }
  290. // Resolve all directory paths (absolute paths used as-is, relative resolved from models-dir)
  291. config.checkpoints = resolveDirectoryPath(config.checkpoints, config.modelsDir);
  292. config.controlnetDir = resolveDirectoryPath(config.controlnetDir, config.modelsDir);
  293. config.embeddingsDir = resolveDirectoryPath(config.embeddingsDir, config.modelsDir);
  294. config.esrganDir = resolveDirectoryPath(config.esrganDir, config.modelsDir);
  295. config.loraDir = resolveDirectoryPath(config.loraDir, config.modelsDir);
  296. config.taesdDir = resolveDirectoryPath(config.taesdDir, config.modelsDir);
  297. config.vaeDir = resolveDirectoryPath(config.vaeDir, config.modelsDir);
  298. config.diffusionModelsDir = resolveDirectoryPath(config.diffusionModelsDir, config.modelsDir);
  299. return config;
  300. }
  301. int main(int argc, char* argv[]) {
  302. // Parse command line arguments
  303. ServerConfig config = parseArguments(argc, argv);
  304. // Initialize logger
  305. LogLevel minLevel = config.verbose ? LogLevel::DEBUG : LogLevel::INFO;
  306. Logger::getInstance().initialize(config.enableFileLogging, config.logFilePath, minLevel);
  307. // Create log directory if file logging is enabled
  308. if (config.enableFileLogging) {
  309. try {
  310. std::filesystem::path logPath(config.logFilePath);
  311. std::filesystem::create_directories(logPath.parent_path());
  312. } catch (const std::filesystem::filesystem_error& e) {
  313. LOG_ERROR("Failed to create log directory: " + std::string(e.what()));
  314. }
  315. }
  316. LOG_INFO("=== Stable Diffusion REST Server Starting ===");
  317. LOG_INFO("Version: " + sd_rest::VERSION_INFO.version_full + " (" + sd_rest::VERSION_INFO.version_type + ")");
  318. LOG_INFO("Commit: " + sd_rest::VERSION_INFO.commit_short + (sd_rest::VERSION_INFO.is_clean ? "" : " (dirty)"));
  319. LOG_INFO("Build time: " + sd_rest::VERSION_INFO.build_time);
  320. if (config.enableFileLogging) {
  321. LOG_INFO("File logging enabled: " + config.logFilePath);
  322. }
  323. // Create queue and output directories if they don't exist
  324. try {
  325. std::filesystem::create_directories(config.queueDir);
  326. std::filesystem::create_directories(config.outputDir);
  327. } catch (const std::filesystem::filesystem_error& e) {
  328. LOG_WARNING("Failed to create directories: " + std::string(e.what()));
  329. }
  330. if (config.verbose) {
  331. std::cout << "\n=== Configuration ===" << std::endl;
  332. std::cout << "Version: " << sd_rest::VERSION_INFO.version_full << " (" << sd_rest::VERSION_INFO.version_type << ")" << std::endl;
  333. std::cout << "Commit: " << sd_rest::VERSION_INFO.commit_short << (sd_rest::VERSION_INFO.is_clean ? "" : " (dirty)") << std::endl;
  334. std::cout << "Build time: " << sd_rest::VERSION_INFO.build_time << std::endl;
  335. std::cout << std::endl;
  336. std::cout << "Server:" << std::endl;
  337. std::cout << " Host: " << config.host << std::endl;
  338. std::cout << " Port: " << config.port << std::endl;
  339. std::cout << " Max concurrent generations: " << config.maxConcurrentGenerations << std::endl;
  340. std::cout << " Queue directory: " << config.queueDir << std::endl;
  341. std::cout << " Output directory: " << config.outputDir << std::endl;
  342. std::cout << "\nModel Directories:" << std::endl;
  343. std::cout << " Base models directory: " << config.modelsDir << std::endl;
  344. std::cout << " Checkpoints: " << config.checkpoints << std::endl;
  345. std::cout << " ControlNet: " << config.controlnetDir << std::endl;
  346. std::cout << " Embeddings: " << config.embeddingsDir << std::endl;
  347. std::cout << " ESRGAN: " << config.esrganDir << std::endl;
  348. std::cout << " LoRA: " << config.loraDir << std::endl;
  349. std::cout << " TAESD: " << config.taesdDir << std::endl;
  350. std::cout << " VAE: " << config.vaeDir << std::endl;
  351. std::cout << " Diffusion: " << config.diffusionModelsDir << std::endl;
  352. std::cout << std::endl;
  353. }
  354. // Validate directory paths
  355. auto validateDirectory = [](const std::string& path, const std::string& name, bool required) -> bool {
  356. if (path.empty()) {
  357. if (required) {
  358. std::cerr << "Error: " << name << " directory is required but not specified" << std::endl;
  359. return false;
  360. }
  361. return true; // Empty path is valid for optional directories
  362. }
  363. std::filesystem::path dirPath(path);
  364. if (!std::filesystem::exists(dirPath)) {
  365. if (required) {
  366. std::cerr << "Error: " << name << " directory does not exist: " << path << std::endl;
  367. return false;
  368. } else {
  369. std::cerr << "Warning: " << name << " directory does not exist: " << path << std::endl;
  370. return true; // Optional directory can be missing
  371. }
  372. }
  373. if (!std::filesystem::is_directory(dirPath)) {
  374. std::cerr << "Error: " << name << " path is not a directory: " << path << std::endl;
  375. return false;
  376. }
  377. return true;
  378. };
  379. // Validate required directories
  380. bool allValid = true;
  381. // Validate base models directory (required - must exist)
  382. if (!validateDirectory(config.modelsDir, "Base models", true)) {
  383. allValid = false;
  384. }
  385. // Validate all model directories (will warn but not fail if missing)
  386. validateDirectory(config.checkpoints, "Checkpoints", false);
  387. validateDirectory(config.controlnetDir, "ControlNet", false);
  388. validateDirectory(config.embeddingsDir, "Embeddings", false);
  389. validateDirectory(config.esrganDir, "ESRGAN", false);
  390. validateDirectory(config.loraDir, "LoRA", false);
  391. validateDirectory(config.taesdDir, "TAESD", false);
  392. validateDirectory(config.vaeDir, "VAE", false);
  393. validateDirectory(config.diffusionModelsDir, "Diffusion Models", false);
  394. // Validate UI directory if specified
  395. if (!config.uiDir.empty()) {
  396. if (!validateDirectory(config.uiDir, "Web UI", false)) {
  397. std::cerr << "\nError: Web UI directory is invalid" << std::endl;
  398. return 1;
  399. }
  400. // Check if the UI directory is readable
  401. std::filesystem::path uiPath(config.uiDir);
  402. if (!std::filesystem::exists(uiPath / "index.html") &&
  403. !std::filesystem::exists(uiPath / "index.htm")) {
  404. std::cerr << "Warning: Web UI directory does not contain an index.html or index.htm file: " << config.uiDir << std::endl;
  405. }
  406. }
  407. if (!allValid) {
  408. std::cerr << "\nError: Base models directory is invalid or missing" << std::endl;
  409. return 1;
  410. }
  411. // Set up signal handlers for graceful shutdown
  412. signal(SIGINT, signalHandler); // Ctrl+C
  413. signal(SIGTERM, signalHandler); // Termination signal
  414. try {
  415. // Initialize authentication system
  416. if (config.verbose) {
  417. std::cout << "Initializing authentication system..." << std::endl;
  418. }
  419. auto userManager = std::make_shared<UserManager>(config.auth.dataDir,
  420. static_cast<UserManager::AuthMethod>(config.auth.authMethod),
  421. config.defaultAdminUsername,
  422. config.defaultAdminPassword,
  423. config.defaultAdminEmail);
  424. if (!userManager->initialize()) {
  425. std::cerr << "Error: Failed to initialize user manager" << std::endl;
  426. return 1;
  427. }
  428. if (config.verbose) {
  429. std::cout << "User manager initialized" << std::endl;
  430. std::cout << "Authentication method: ";
  431. switch (config.auth.authMethod) {
  432. case AuthMethod::NONE:
  433. std::cout << "None";
  434. break;
  435. case AuthMethod::JWT:
  436. std::cout << "JWT";
  437. break;
  438. case AuthMethod::API_KEY:
  439. std::cout << "API Key";
  440. break;
  441. case AuthMethod::UNIX:
  442. std::cout << "Unix";
  443. break;
  444. case AuthMethod::OPTIONAL:
  445. std::cout << "Optional";
  446. break;
  447. case AuthMethod::PAM:
  448. std::cout << "PAM";
  449. break;
  450. }
  451. std::cout << std::endl;
  452. }
  453. // Initialize authentication middleware
  454. auto authMiddleware = std::make_shared<AuthMiddleware>(config.auth, userManager);
  455. if (!authMiddleware->initialize()) {
  456. std::cerr << "Error: Failed to initialize authentication middleware" << std::endl;
  457. return 1;
  458. }
  459. // Initialize components
  460. auto modelManager = std::make_unique<ModelManager>();
  461. auto generationQueue = std::make_unique<GenerationQueue>(modelManager.get(), config.maxConcurrentGenerations,
  462. config.queueDir, config.outputDir);
  463. auto server = std::make_unique<Server>(modelManager.get(), generationQueue.get(), config.outputDir, config.uiDir, config);
  464. // Set authentication components in server
  465. server->setAuthComponents(userManager, authMiddleware);
  466. // Set global server pointer for signal handler access
  467. g_server = server.get();
  468. // Set global shutdown delay from config
  469. g_shutdownDelayMs = config.shutdownDelayMs;
  470. // Configure model manager with directory parameters
  471. if (config.verbose) {
  472. std::cout << "Configuring model manager..." << std::endl;
  473. }
  474. if (!modelManager->configureFromServerConfig(config)) {
  475. std::cerr << "Error: Failed to configure model manager with server config" << std::endl;
  476. return 1;
  477. }
  478. if (config.verbose) {
  479. std::cout << "Model manager configured with per-type directories" << std::endl;
  480. }
  481. // Scan models directory
  482. if (config.verbose) {
  483. std::cout << "Scanning models directory..." << std::endl;
  484. }
  485. if (!modelManager->scanModelsDirectory()) {
  486. std::cerr << "Warning: Failed to scan models directory" << std::endl;
  487. } else if (config.verbose) {
  488. std::cout << "Found " << modelManager->getAvailableModelsCount() << " models" << std::endl;
  489. }
  490. // Start the generation queue
  491. generationQueue->start();
  492. if (config.verbose) {
  493. std::cout << "Generation queue started" << std::endl;
  494. }
  495. // Start the HTTP server
  496. if (!server->start(config.host, config.port)) {
  497. std::cerr << "Failed to start server" << std::endl;
  498. return 1;
  499. }
  500. std::cout << "Server initialized successfully" << std::endl;
  501. std::cout << "Server listening on " << config.host << ":" << config.port << std::endl;
  502. std::cout << "Press Ctrl+C to stop the server" << std::endl;
  503. // Give server a moment to start
  504. std::this_thread::sleep_for(std::chrono::milliseconds(100));
  505. // Main server loop - wait for shutdown signal
  506. while (g_running.load() && server->isRunning()) {
  507. std::this_thread::sleep_for(std::chrono::milliseconds(100));
  508. }
  509. // Graceful shutdown
  510. std::cout << "Shutting down server..." << std::endl;
  511. // Stop the server first to stop accepting new requests
  512. server->stop();
  513. // Unload all models to ensure contexts are properly freed
  514. std::cout << "Unloading all models..." << std::endl;
  515. modelManager->unloadAllModels();
  516. // Stop the generation queue
  517. generationQueue->stop();
  518. // Wait for server thread to finish
  519. server->waitForStop();
  520. // Clear global server pointer
  521. g_server = nullptr;
  522. std::cout << "Server shutdown complete" << std::endl;
  523. } catch (const std::exception& e) {
  524. std::cerr << "Error: " << e.what() << std::endl;
  525. return 1;
  526. }
  527. return 0;
  528. }