|
@@ -1,19 +1,19 @@
|
|
|
-#include <iostream>
|
|
|
|
|
-#include <memory>
|
|
|
|
|
#include <signal.h>
|
|
#include <signal.h>
|
|
|
-#include <string>
|
|
|
|
|
|
|
+#include <algorithm>
|
|
|
#include <atomic>
|
|
#include <atomic>
|
|
|
-#include <thread>
|
|
|
|
|
#include <chrono>
|
|
#include <chrono>
|
|
|
#include <filesystem>
|
|
#include <filesystem>
|
|
|
-#include <algorithm>
|
|
|
|
|
-#include "server.h"
|
|
|
|
|
-#include "model_manager.h"
|
|
|
|
|
|
|
+#include <iostream>
|
|
|
|
|
+#include <memory>
|
|
|
|
|
+#include <string>
|
|
|
|
|
+#include <thread>
|
|
|
|
|
+#include "auth_middleware.h"
|
|
|
#include "generation_queue.h"
|
|
#include "generation_queue.h"
|
|
|
-#include "server_config.h"
|
|
|
|
|
#include "logger.h"
|
|
#include "logger.h"
|
|
|
|
|
+#include "model_manager.h"
|
|
|
|
|
+#include "server.h"
|
|
|
|
|
+#include "server_config.h"
|
|
|
#include "user_manager.h"
|
|
#include "user_manager.h"
|
|
|
-#include "auth_middleware.h"
|
|
|
|
|
|
|
|
|
|
// Global flag for signal handling
|
|
// Global flag for signal handling
|
|
|
std::atomic<bool> g_running(true);
|
|
std::atomic<bool> g_running(true);
|
|
@@ -66,14 +66,14 @@ ServerConfig parseArguments(int argc, char* argv[]) {
|
|
|
ServerConfig config;
|
|
ServerConfig config;
|
|
|
|
|
|
|
|
// Track which parameters were explicitly set
|
|
// Track which parameters were explicitly set
|
|
|
- bool modelsDirSet = false;
|
|
|
|
|
|
|
+ bool modelsDirSet = false;
|
|
|
bool checkpointsSet = false;
|
|
bool checkpointsSet = false;
|
|
|
- bool controlnetSet = false;
|
|
|
|
|
- bool embeddingsSet = false;
|
|
|
|
|
- bool esrganSet = false;
|
|
|
|
|
- bool loraSet = false;
|
|
|
|
|
- bool taesdSet = false;
|
|
|
|
|
- bool vaeSet = false;
|
|
|
|
|
|
|
+ bool controlnetSet = false;
|
|
|
|
|
+ bool embeddingsSet = false;
|
|
|
|
|
+ bool esrganSet = false;
|
|
|
|
|
+ bool loraSet = false;
|
|
|
|
|
+ bool taesdSet = false;
|
|
|
|
|
+ bool vaeSet = false;
|
|
|
|
|
|
|
|
for (int i = 1; i < argc; i++) {
|
|
for (int i = 1; i < argc; i++) {
|
|
|
std::string arg = argv[i];
|
|
std::string arg = argv[i];
|
|
@@ -84,28 +84,28 @@ ServerConfig parseArguments(int argc, char* argv[]) {
|
|
|
config.port = std::stoi(argv[++i]);
|
|
config.port = std::stoi(argv[++i]);
|
|
|
} else if (arg == "--models-dir" && i + 1 < argc) {
|
|
} else if (arg == "--models-dir" && i + 1 < argc) {
|
|
|
config.modelsDir = argv[++i];
|
|
config.modelsDir = argv[++i];
|
|
|
- modelsDirSet = true;
|
|
|
|
|
|
|
+ modelsDirSet = true;
|
|
|
} else if (arg == "--checkpoints" && i + 1 < argc) {
|
|
} else if (arg == "--checkpoints" && i + 1 < argc) {
|
|
|
config.checkpoints = argv[++i];
|
|
config.checkpoints = argv[++i];
|
|
|
- checkpointsSet = true;
|
|
|
|
|
|
|
+ checkpointsSet = true;
|
|
|
} else if (arg == "--controlnet-dir" && i + 1 < argc) {
|
|
} else if (arg == "--controlnet-dir" && i + 1 < argc) {
|
|
|
config.controlnetDir = argv[++i];
|
|
config.controlnetDir = argv[++i];
|
|
|
- controlnetSet = true;
|
|
|
|
|
|
|
+ controlnetSet = true;
|
|
|
} else if (arg == "--embeddings-dir" && i + 1 < argc) {
|
|
} else if (arg == "--embeddings-dir" && i + 1 < argc) {
|
|
|
config.embeddingsDir = argv[++i];
|
|
config.embeddingsDir = argv[++i];
|
|
|
- embeddingsSet = true;
|
|
|
|
|
|
|
+ embeddingsSet = true;
|
|
|
} else if (arg == "--esrgan-dir" && i + 1 < argc) {
|
|
} else if (arg == "--esrgan-dir" && i + 1 < argc) {
|
|
|
config.esrganDir = argv[++i];
|
|
config.esrganDir = argv[++i];
|
|
|
- esrganSet = true;
|
|
|
|
|
|
|
+ esrganSet = true;
|
|
|
} else if (arg == "--lora-dir" && i + 1 < argc) {
|
|
} else if (arg == "--lora-dir" && i + 1 < argc) {
|
|
|
config.loraDir = argv[++i];
|
|
config.loraDir = argv[++i];
|
|
|
- loraSet = true;
|
|
|
|
|
|
|
+ loraSet = true;
|
|
|
} else if (arg == "--taesd-dir" && i + 1 < argc) {
|
|
} else if (arg == "--taesd-dir" && i + 1 < argc) {
|
|
|
config.taesdDir = argv[++i];
|
|
config.taesdDir = argv[++i];
|
|
|
- taesdSet = true;
|
|
|
|
|
|
|
+ taesdSet = true;
|
|
|
} else if (arg == "--vae-dir" && i + 1 < argc) {
|
|
} else if (arg == "--vae-dir" && i + 1 < argc) {
|
|
|
config.vaeDir = argv[++i];
|
|
config.vaeDir = argv[++i];
|
|
|
- vaeSet = true;
|
|
|
|
|
|
|
+ vaeSet = true;
|
|
|
} else if (arg == "--max-concurrent" && i + 1 < argc) {
|
|
} else if (arg == "--max-concurrent" && i + 1 < argc) {
|
|
|
config.maxConcurrentGenerations = std::stoi(argv[++i]);
|
|
config.maxConcurrentGenerations = std::stoi(argv[++i]);
|
|
|
} else if (arg == "--queue-dir" && i + 1 < argc) {
|
|
} else if (arg == "--queue-dir" && i + 1 < argc) {
|
|
@@ -118,7 +118,7 @@ ServerConfig parseArguments(int argc, char* argv[]) {
|
|
|
config.verbose = true;
|
|
config.verbose = true;
|
|
|
} else if (arg == "--log-file" && i + 1 < argc) {
|
|
} else if (arg == "--log-file" && i + 1 < argc) {
|
|
|
config.enableFileLogging = true;
|
|
config.enableFileLogging = true;
|
|
|
- config.logFilePath = argv[++i];
|
|
|
|
|
|
|
+ config.logFilePath = argv[++i];
|
|
|
} else if (arg == "--enable-file-logging") {
|
|
} else if (arg == "--enable-file-logging") {
|
|
|
config.enableFileLogging = true;
|
|
config.enableFileLogging = true;
|
|
|
} else if ((arg == "--auth-method" || arg == "--auth") && i + 1 < argc) {
|
|
} else if ((arg == "--auth-method" || arg == "--auth") && i + 1 < argc) {
|
|
@@ -259,19 +259,18 @@ ServerConfig parseArguments(int argc, char* argv[]) {
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
// Resolve all directory paths (absolute paths used as-is, relative resolved from models-dir)
|
|
// Resolve all directory paths (absolute paths used as-is, relative resolved from models-dir)
|
|
|
- config.checkpoints = resolveDirectoryPath(config.checkpoints, config.modelsDir);
|
|
|
|
|
|
|
+ config.checkpoints = resolveDirectoryPath(config.checkpoints, config.modelsDir);
|
|
|
config.controlnetDir = resolveDirectoryPath(config.controlnetDir, config.modelsDir);
|
|
config.controlnetDir = resolveDirectoryPath(config.controlnetDir, config.modelsDir);
|
|
|
config.embeddingsDir = resolveDirectoryPath(config.embeddingsDir, config.modelsDir);
|
|
config.embeddingsDir = resolveDirectoryPath(config.embeddingsDir, config.modelsDir);
|
|
|
- config.esrganDir = resolveDirectoryPath(config.esrganDir, config.modelsDir);
|
|
|
|
|
- config.loraDir = resolveDirectoryPath(config.loraDir, config.modelsDir);
|
|
|
|
|
- config.taesdDir = resolveDirectoryPath(config.taesdDir, config.modelsDir);
|
|
|
|
|
- config.vaeDir = resolveDirectoryPath(config.vaeDir, config.modelsDir);
|
|
|
|
|
|
|
+ config.esrganDir = resolveDirectoryPath(config.esrganDir, config.modelsDir);
|
|
|
|
|
+ config.loraDir = resolveDirectoryPath(config.loraDir, config.modelsDir);
|
|
|
|
|
+ config.taesdDir = resolveDirectoryPath(config.taesdDir, config.modelsDir);
|
|
|
|
|
+ config.vaeDir = resolveDirectoryPath(config.vaeDir, config.modelsDir);
|
|
|
|
|
|
|
|
return config;
|
|
return config;
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
int main(int argc, char* argv[]) {
|
|
int main(int argc, char* argv[]) {
|
|
|
-
|
|
|
|
|
// Parse command line arguments
|
|
// Parse command line arguments
|
|
|
ServerConfig config = parseArguments(argc, argv);
|
|
ServerConfig config = parseArguments(argc, argv);
|
|
|
|
|
|
|
@@ -329,7 +328,7 @@ int main(int argc, char* argv[]) {
|
|
|
std::cerr << "Error: " << name << " directory is required but not specified" << std::endl;
|
|
std::cerr << "Error: " << name << " directory is required but not specified" << std::endl;
|
|
|
return false;
|
|
return false;
|
|
|
}
|
|
}
|
|
|
- return true; // Empty path is valid for optional directories
|
|
|
|
|
|
|
+ return true; // Empty path is valid for optional directories
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
std::filesystem::path dirPath(path);
|
|
std::filesystem::path dirPath(path);
|
|
@@ -339,7 +338,7 @@ int main(int argc, char* argv[]) {
|
|
|
return false;
|
|
return false;
|
|
|
} else {
|
|
} else {
|
|
|
std::cerr << "Warning: " << name << " directory does not exist: " << path << std::endl;
|
|
std::cerr << "Warning: " << name << " directory does not exist: " << path << std::endl;
|
|
|
- return true; // Optional directory can be missing
|
|
|
|
|
|
|
+ return true; // Optional directory can be missing
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
|
|
|
|
@@ -399,7 +398,7 @@ int main(int argc, char* argv[]) {
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
auto userManager = std::make_shared<UserManager>(config.auth.dataDir,
|
|
auto userManager = std::make_shared<UserManager>(config.auth.dataDir,
|
|
|
- static_cast<UserManager::AuthMethod>(config.auth.authMethod));
|
|
|
|
|
|
|
+ static_cast<UserManager::AuthMethod>(config.auth.authMethod));
|
|
|
if (!userManager->initialize()) {
|
|
if (!userManager->initialize()) {
|
|
|
std::cerr << "Error: Failed to initialize user manager" << std::endl;
|
|
std::cerr << "Error: Failed to initialize user manager" << std::endl;
|
|
|
return 1;
|
|
return 1;
|
|
@@ -409,11 +408,24 @@ int main(int argc, char* argv[]) {
|
|
|
std::cout << "User manager initialized" << std::endl;
|
|
std::cout << "User manager initialized" << std::endl;
|
|
|
std::cout << "Authentication method: ";
|
|
std::cout << "Authentication method: ";
|
|
|
switch (config.auth.authMethod) {
|
|
switch (config.auth.authMethod) {
|
|
|
- case AuthMethod::NONE: std::cout << "None"; break;
|
|
|
|
|
- case AuthMethod::JWT: std::cout << "JWT"; break;
|
|
|
|
|
- case AuthMethod::API_KEY: std::cout << "API Key"; break;
|
|
|
|
|
- case AuthMethod::UNIX: std::cout << "Unix"; break;
|
|
|
|
|
- case AuthMethod::OPTIONAL: std::cout << "Optional"; break;
|
|
|
|
|
|
|
+ case AuthMethod::NONE:
|
|
|
|
|
+ std::cout << "None";
|
|
|
|
|
+ break;
|
|
|
|
|
+ case AuthMethod::JWT:
|
|
|
|
|
+ std::cout << "JWT";
|
|
|
|
|
+ break;
|
|
|
|
|
+ case AuthMethod::API_KEY:
|
|
|
|
|
+ std::cout << "API Key";
|
|
|
|
|
+ break;
|
|
|
|
|
+ case AuthMethod::UNIX:
|
|
|
|
|
+ std::cout << "Unix";
|
|
|
|
|
+ break;
|
|
|
|
|
+ case AuthMethod::OPTIONAL:
|
|
|
|
|
+ std::cout << "Optional";
|
|
|
|
|
+ break;
|
|
|
|
|
+ case AuthMethod::PAM:
|
|
|
|
|
+ std::cout << "PAM";
|
|
|
|
|
+ break;
|
|
|
}
|
|
}
|
|
|
std::cout << std::endl;
|
|
std::cout << std::endl;
|
|
|
}
|
|
}
|
|
@@ -426,10 +438,10 @@ int main(int argc, char* argv[]) {
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
// Initialize components
|
|
// Initialize components
|
|
|
- auto modelManager = std::make_unique<ModelManager>();
|
|
|
|
|
|
|
+ auto modelManager = std::make_unique<ModelManager>();
|
|
|
auto generationQueue = std::make_unique<GenerationQueue>(modelManager.get(), config.maxConcurrentGenerations,
|
|
auto generationQueue = std::make_unique<GenerationQueue>(modelManager.get(), config.maxConcurrentGenerations,
|
|
|
config.queueDir, config.outputDir);
|
|
config.queueDir, config.outputDir);
|
|
|
- auto server = std::make_unique<Server>(modelManager.get(), generationQueue.get(), config.outputDir, config.uiDir);
|
|
|
|
|
|
|
+ auto server = std::make_unique<Server>(modelManager.get(), generationQueue.get(), config.outputDir, config.uiDir);
|
|
|
|
|
|
|
|
// Set authentication components in server
|
|
// Set authentication components in server
|
|
|
server->setAuthComponents(userManager, authMiddleware);
|
|
server->setAuthComponents(userManager, authMiddleware);
|