| 12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761277127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502150315041505150615071508150915101511151215131514151515161517151815191520152115221523152415251526152715281529153015311532153315341535153615371538153915401541154215431544154515461547154815491550155115521553155415551556155715581559156015611562156315641565156615671568156915701571157215731574157515761577157815791580158115821583158415851586158715881589159015911592159315941595159615971598159916001601160216031604160516061607160816091610161116121613161416151616161716181619162016211622162316241625162616271628162916301631163216331634163516361637163816391640164116421643164416451646164716481649165016511652165316541655165616571658165916601661166216631664166516661667166816691670167116721673167416751676167716781679168016811682168316841685168616871688168916901691169216931694169516961697169816991700170117021703170417051706170717081709171017111712171317141715171617171718171917201721172217231724172517261727172817291730173117321733173417351736173717381739174017411742174317441745174617471748174917501751175217531754175517561757175817591760176117621763176417651766176717681769177017711772177317741775177617771778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227722782279228022812282228322842285228622872288228922902291229222932294229522962297229822992300230123022303230423052306230723082309231023112312231323142315231623172318231923202321232223232324232523262327232823292330233123322333233423352336233723382339234023412342234323442345234623472348234923502351235223532354235523562357235823592360236123622363236423652366236723682369237023712372237323742375237623772378237923802381238223832384238523862387238823892390239123922393239423952396239723982399240024012402240324042405240624072408240924102411241224132414241524162417241824192420242124222423242424252426242724282429243024312432243324342435243624372438243924402441244224432444244524462447244824492450245124522453245424552456245724582459246024612462246324642465246624672468246924702471247224732474247524762477247824792480248124822483248424852486248724882489249024912492249324942495249624972498249925002501250225032504250525062507250825092510251125122513251425152516251725182519252025212522252325242525252625272528252925302531253225332534253525362537253825392540254125422543254425452546254725482549255025512552255325542555255625572558255925602561256225632564256525662567256825692570257125722573257425752576257725782579258025812582258325842585258625872588258925902591259225932594259525962597259825992600260126022603260426052606260726082609261026112612261326142615261626172618261926202621262226232624262526262627262826292630263126322633263426352636263726382639264026412642264326442645264626472648264926502651265226532654265526562657265826592660266126622663266426652666266726682669267026712672267326742675267626772678267926802681268226832684268526862687268826892690269126922693269426952696269726982699270027012702270327042705270627072708270927102711271227132714271527162717271827192720272127222723272427252726272727282729273027312732273327342735273627372738273927402741274227432744274527462747274827492750275127522753275427552756275727582759276027612762276327642765276627672768276927702771277227732774277527762777277827792780278127822783278427852786278727882789279027912792279327942795279627972798279928002801280228032804280528062807280828092810281128122813281428152816281728182819282028212822282328242825282628272828282928302831283228332834283528362837283828392840284128422843284428452846284728482849285028512852285328542855285628572858285928602861286228632864286528662867286828692870287128722873287428752876287728782879288028812882288328842885288628872888288928902891289228932894289528962897289828992900290129022903290429052906290729082909291029112912291329142915291629172918291929202921292229232924292529262927292829292930293129322933293429352936293729382939294029412942294329442945294629472948294929502951295229532954295529562957295829592960296129622963296429652966296729682969297029712972297329742975297629772978297929802981298229832984298529862987298829892990299129922993299429952996299729982999300030013002300330043005300630073008300930103011301230133014301530163017301830193020302130223023302430253026302730283029303030313032303330343035303630373038303930403041304230433044304530463047304830493050305130523053305430553056305730583059306030613062306330643065306630673068306930703071307230733074307530763077307830793080308130823083308430853086308730883089309030913092309330943095309630973098309931003101310231033104310531063107310831093110311131123113311431153116311731183119312031213122312331243125312631273128312931303131313231333134313531363137313831393140314131423143314431453146314731483149315031513152315331543155315631573158315931603161316231633164316531663167316831693170317131723173317431753176317731783179318031813182318331843185318631873188318931903191319231933194319531963197319831993200320132023203320432053206320732083209321032113212321332143215321632173218321932203221322232233224322532263227322832293230323132323233323432353236323732383239324032413242324332443245324632473248324932503251325232533254325532563257325832593260326132623263326432653266326732683269327032713272327332743275327632773278327932803281328232833284328532863287328832893290329132923293329432953296329732983299330033013302330333043305330633073308330933103311331233133314331533163317331833193320332133223323332433253326332733283329333033313332333333343335333633373338333933403341334233433344334533463347334833493350335133523353335433553356335733583359336033613362336333643365336633673368336933703371337233733374337533763377337833793380338133823383338433853386338733883389339033913392339333943395339633973398339934003401340234033404340534063407340834093410341134123413341434153416341734183419342034213422342334243425342634273428342934303431343234333434343534363437343834393440344134423443344434453446344734483449345034513452345334543455345634573458345934603461346234633464346534663467346834693470347134723473347434753476347734783479348034813482348334843485348634873488348934903491349234933494349534963497349834993500350135023503350435053506350735083509351035113512351335143515351635173518351935203521352235233524352535263527352835293530353135323533353435353536353735383539354035413542354335443545354635473548354935503551355235533554355535563557355835593560356135623563356435653566356735683569357035713572357335743575357635773578357935803581358235833584358535863587358835893590359135923593359435953596359735983599360036013602360336043605360636073608360936103611361236133614361536163617361836193620362136223623362436253626362736283629363036313632363336343635363636373638363936403641364236433644364536463647364836493650365136523653365436553656365736583659366036613662366336643665366636673668366936703671367236733674367536763677367836793680368136823683368436853686368736883689369036913692369336943695369636973698369937003701370237033704370537063707370837093710371137123713371437153716371737183719372037213722372337243725372637273728372937303731373237333734373537363737373837393740374137423743374437453746374737483749375037513752375337543755375637573758375937603761376237633764376537663767376837693770377137723773377437753776377737783779378037813782378337843785378637873788378937903791379237933794379537963797379837993800380138023803380438053806380738083809381038113812381338143815381638173818381938203821 |
- #include "server.h"
- #include "model_manager.h"
- #include "generation_queue.h"
- #include "utils.h"
- #include <httplib.h>
- #include <nlohmann/json.hpp>
- #include <iostream>
- #include <sstream>
- #include <fstream>
- #include <chrono>
- #include <random>
- #include <iomanip>
- #include <algorithm>
- #include <thread>
- #include <filesystem>
- // Include stb_image for loading images (implementation is in generation_queue.cpp)
- #include "../stable-diffusion.cpp-src/thirdparty/stb_image.h"
- #include <sys/socket.h>
- #include <netinet/in.h>
- #include <unistd.h>
- #include <arpa/inet.h>
- using json = nlohmann::json;
- Server::Server(ModelManager* modelManager, GenerationQueue* generationQueue, const std::string& outputDir, const std::string& uiDir)
- : m_modelManager(modelManager)
- , m_generationQueue(generationQueue)
- , m_isRunning(false)
- , m_startupFailed(false)
- , m_port(8080)
- , m_outputDir(outputDir)
- , m_uiDir(uiDir)
- {
- m_httpServer = std::make_unique<httplib::Server>();
- }
- Server::~Server() {
- stop();
- }
- bool Server::start(const std::string& host, int port) {
- if (m_isRunning.load()) {
- return false;
- }
- m_host = host;
- m_port = port;
- // Validate host and port
- if (host.empty() || (port < 1 || port > 65535)) {
- return false;
- }
- // Set up CORS headers
- setupCORS();
- // Register API endpoints
- registerEndpoints();
- // Reset startup flags
- m_startupFailed.store(false);
- // Start server in a separate thread
- m_serverThread = std::thread(&Server::serverThreadFunction, this, host, port);
- // Wait for server to actually start and bind to the port
- // Give more time for server to actually start and bind
- for (int i = 0; i < 100; i++) { // Wait up to 10 seconds
- std::this_thread::sleep_for(std::chrono::milliseconds(100));
- // Check if startup failed early
- if (m_startupFailed.load()) {
- if (m_serverThread.joinable()) {
- m_serverThread.join();
- }
- return false;
- }
- if (m_isRunning.load()) {
- // Give it a moment more to ensure server is fully started
- std::this_thread::sleep_for(std::chrono::milliseconds(500));
- if (m_isRunning.load()) {
- return true;
- }
- }
- }
- if (m_isRunning.load()) {
- return true;
- } else {
- if (m_serverThread.joinable()) {
- m_serverThread.join();
- }
- return false;
- }
- }
- void Server::stop() {
- // Use atomic check to ensure thread safety
- bool wasRunning = m_isRunning.exchange(false);
- if (!wasRunning) {
- return; // Already stopped
- }
- if (m_httpServer) {
- m_httpServer->stop();
- // Give the server a moment to stop the blocking listen call
- std::this_thread::sleep_for(std::chrono::milliseconds(100));
- // If server thread is still running, try to force unblock the listen call
- // by making a quick connection to the server port
- if (m_serverThread.joinable()) {
- try {
- // Create a quick connection to interrupt the blocking listen
- httplib::Client client("127.0.0.1", m_port);
- client.set_connection_timeout(0, 500000); // 0.5 seconds
- client.set_read_timeout(0, 500000); // 0.5 seconds
- client.set_write_timeout(0, 500000); // 0.5 seconds
- auto res = client.Get("/api/health");
- // We don't care about the response, just trying to unblock
- } catch (...) {
- // Ignore any connection errors - we're just trying to unblock
- }
- }
- }
- if (m_serverThread.joinable()) {
- m_serverThread.join();
- }
- }
- bool Server::isRunning() const {
- return m_isRunning.load();
- }
- void Server::waitForStop() {
- if (m_serverThread.joinable()) {
- m_serverThread.join();
- }
- }
- void Server::registerEndpoints() {
- // Health check endpoint
- m_httpServer->Get("/api/health", [this](const httplib::Request& req, httplib::Response& res) {
- handleHealthCheck(req, res);
- });
- // API status endpoint
- m_httpServer->Get("/api/status", [this](const httplib::Request& req, httplib::Response& res) {
- handleApiStatus(req, res);
- });
- // Specialized generation endpoints
- m_httpServer->Post("/api/generate/text2img", [this](const httplib::Request& req, httplib::Response& res) {
- handleText2Img(req, res);
- });
- m_httpServer->Post("/api/generate/img2img", [this](const httplib::Request& req, httplib::Response& res) {
- handleImg2Img(req, res);
- });
- m_httpServer->Post("/api/generate/controlnet", [this](const httplib::Request& req, httplib::Response& res) {
- handleControlNet(req, res);
- });
- m_httpServer->Post("/api/generate/upscale", [this](const httplib::Request& req, httplib::Response& res) {
- handleUpscale(req, res);
- });
- // Utility endpoints
- m_httpServer->Get("/api/samplers", [this](const httplib::Request& req, httplib::Response& res) {
- handleSamplers(req, res);
- });
- m_httpServer->Get("/api/schedulers", [this](const httplib::Request& req, httplib::Response& res) {
- handleSchedulers(req, res);
- });
- m_httpServer->Get("/api/parameters", [this](const httplib::Request& req, httplib::Response& res) {
- handleParameters(req, res);
- });
- m_httpServer->Post("/api/validate", [this](const httplib::Request& req, httplib::Response& res) {
- handleValidate(req, res);
- });
- m_httpServer->Post("/api/estimate", [this](const httplib::Request& req, httplib::Response& res) {
- handleEstimate(req, res);
- });
- m_httpServer->Get("/api/config", [this](const httplib::Request& req, httplib::Response& res) {
- handleConfig(req, res);
- });
- m_httpServer->Get("/api/system", [this](const httplib::Request& req, httplib::Response& res) {
- handleSystem(req, res);
- });
- m_httpServer->Post("/api/system/restart", [this](const httplib::Request& req, httplib::Response& res) {
- handleSystemRestart(req, res);
- });
- // Models list endpoint
- m_httpServer->Get("/api/models", [this](const httplib::Request& req, httplib::Response& res) {
- handleModelsList(req, res);
- });
- // Model-specific endpoints
- m_httpServer->Get("/api/models/(.*)", [this](const httplib::Request& req, httplib::Response& res) {
- handleModelInfo(req, res);
- });
- m_httpServer->Post("/api/models/(.*)/load", [this](const httplib::Request& req, httplib::Response& res) {
- handleLoadModelById(req, res);
- });
- m_httpServer->Post("/api/models/(.*)/unload", [this](const httplib::Request& req, httplib::Response& res) {
- handleUnloadModelById(req, res);
- });
- // Model management endpoints
- m_httpServer->Get("/api/models/types", [this](const httplib::Request& req, httplib::Response& res) {
- handleModelTypes(req, res);
- });
- m_httpServer->Get("/api/models/directories", [this](const httplib::Request& req, httplib::Response& res) {
- handleModelDirectories(req, res);
- });
- m_httpServer->Post("/api/models/refresh", [this](const httplib::Request& req, httplib::Response& res) {
- handleRefreshModels(req, res);
- });
- m_httpServer->Post("/api/models/hash", [this](const httplib::Request& req, httplib::Response& res) {
- handleHashModels(req, res);
- });
- m_httpServer->Post("/api/models/convert", [this](const httplib::Request& req, httplib::Response& res) {
- handleConvertModel(req, res);
- });
- m_httpServer->Get("/api/models/stats", [this](const httplib::Request& req, httplib::Response& res) {
- handleModelStats(req, res);
- });
- m_httpServer->Post("/api/models/batch", [this](const httplib::Request& req, httplib::Response& res) {
- handleBatchModels(req, res);
- });
- // Model validation endpoints
- m_httpServer->Post("/api/models/validate", [this](const httplib::Request& req, httplib::Response& res) {
- handleValidateModel(req, res);
- });
- m_httpServer->Post("/api/models/compatible", [this](const httplib::Request& req, httplib::Response& res) {
- handleCheckCompatibility(req, res);
- });
- m_httpServer->Post("/api/models/requirements", [this](const httplib::Request& req, httplib::Response& res) {
- handleModelRequirements(req, res);
- });
- // Queue status endpoint
- m_httpServer->Get("/api/queue/status", [this](const httplib::Request& req, httplib::Response& res) {
- handleQueueStatus(req, res);
- });
- // Download job output file endpoint (must be before job status endpoint to match more specific pattern first)
- m_httpServer->Get("/api/queue/job/(.*)/output/(.*)", [this](const httplib::Request& req, httplib::Response& res) {
- handleDownloadOutput(req, res);
- });
- // Job status endpoint
- m_httpServer->Get("/api/queue/job/(.*)", [this](const httplib::Request& req, httplib::Response& res) {
- handleJobStatus(req, res);
- });
- // Cancel job endpoint
- m_httpServer->Post("/api/queue/cancel", [this](const httplib::Request& req, httplib::Response& res) {
- handleCancelJob(req, res);
- });
- // Clear queue endpoint
- m_httpServer->Post("/api/queue/clear", [this](const httplib::Request& req, httplib::Response& res) {
- handleClearQueue(req, res);
- });
- // Serve static web UI files if uiDir is configured
- if (!m_uiDir.empty() && std::filesystem::exists(m_uiDir)) {
- std::cout << "Serving static UI files from: " << m_uiDir << " at /ui" << std::endl;
- // Read UI version from version.json if available
- std::string uiVersion = "unknown";
- std::string versionFilePath = m_uiDir + "/version.json";
- if (std::filesystem::exists(versionFilePath)) {
- try {
- std::ifstream versionFile(versionFilePath);
- if (versionFile.is_open()) {
- nlohmann::json versionData = nlohmann::json::parse(versionFile);
- if (versionData.contains("version")) {
- uiVersion = versionData["version"].get<std::string>();
- }
- versionFile.close();
- }
- } catch (const std::exception& e) {
- std::cerr << "Failed to read UI version: " << e.what() << std::endl;
- }
- }
- std::cout << "UI version: " << uiVersion << std::endl;
- // Serve dynamic config.js that provides runtime configuration to the web UI
- m_httpServer->Get("/ui/config.js", [this, uiVersion](const httplib::Request& req, httplib::Response& res) {
- // Generate JavaScript configuration with current server settings
- std::ostringstream configJs;
- configJs << "// Auto-generated configuration\n"
- << "window.__SERVER_CONFIG__ = {\n"
- << " apiUrl: 'http://" << m_host << ":" << m_port << "',\n"
- << " apiBasePath: '/api',\n"
- << " host: '" << m_host << "',\n"
- << " port: " << m_port << ",\n"
- << " uiVersion: '" << uiVersion << "'\n"
- << "};\n";
- // No cache for config.js - always fetch fresh
- res.set_header("Cache-Control", "no-cache, no-store, must-revalidate");
- res.set_header("Pragma", "no-cache");
- res.set_header("Expires", "0");
- res.set_content(configJs.str(), "application/javascript");
- });
- // Set up file request handler for caching static assets
- m_httpServer->set_file_request_handler([uiVersion](const httplib::Request& req, httplib::Response& res) {
- // Add cache headers based on file type and version
- std::string path = req.path;
- // For versioned static assets (.js, .css, images), use long cache
- if (path.find("/_next/") != std::string::npos ||
- path.find(".js") != std::string::npos ||
- path.find(".css") != std::string::npos ||
- path.find(".png") != std::string::npos ||
- path.find(".jpg") != std::string::npos ||
- path.find(".svg") != std::string::npos ||
- path.find(".ico") != std::string::npos ||
- path.find(".woff") != std::string::npos ||
- path.find(".woff2") != std::string::npos ||
- path.find(".ttf") != std::string::npos) {
- // Long cache (1 year) for static assets
- res.set_header("Cache-Control", "public, max-age=31536000, immutable");
- // Add ETag based on UI version for cache validation
- res.set_header("ETag", "\"" + uiVersion + "\"");
- // Check If-None-Match for conditional requests
- if (req.has_header("If-None-Match")) {
- std::string clientETag = req.get_header_value("If-None-Match");
- if (clientETag == "\"" + uiVersion + "\"") {
- res.status = 304; // Not Modified
- return;
- }
- }
- } else if (path.find(".html") != std::string::npos || path == "/ui/" || path == "/ui") {
- // HTML files should revalidate but can be cached briefly
- res.set_header("Cache-Control", "public, max-age=0, must-revalidate");
- res.set_header("ETag", "\"" + uiVersion + "\"");
- }
- });
- // Mount the static file directory at /ui
- if (!m_httpServer->set_mount_point("/ui", m_uiDir)) {
- std::cerr << "Failed to mount UI directory: " << m_uiDir << std::endl;
- }
- // Redirect /ui to /ui/ to ensure proper routing
- m_httpServer->Get("/ui", [](const httplib::Request& req, httplib::Response& res) {
- res.set_redirect("/ui/");
- });
- }
- }
- void Server::setupCORS() {
- // Use post-routing handler to set CORS headers after the response is generated
- // This ensures we don't duplicate headers that may be set by other handlers
- m_httpServer->set_post_routing_handler([](const httplib::Request& req, httplib::Response& res) {
- // Only add CORS headers if they haven't been set already
- if (!res.has_header("Access-Control-Allow-Origin")) {
- res.set_header("Access-Control-Allow-Origin", "*");
- }
- if (!res.has_header("Access-Control-Allow-Methods")) {
- res.set_header("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS");
- }
- if (!res.has_header("Access-Control-Allow-Headers")) {
- res.set_header("Access-Control-Allow-Headers", "Content-Type, Authorization");
- }
- });
- // Handle OPTIONS requests for CORS preflight (API endpoints only)
- m_httpServer->Options("/api/.*", [](const httplib::Request&, httplib::Response& res) {
- res.set_header("Access-Control-Allow-Origin", "*");
- res.set_header("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS");
- res.set_header("Access-Control-Allow-Headers", "Content-Type, Authorization");
- res.status = 200;
- });
- }
- void Server::handleHealthCheck(const httplib::Request& req, httplib::Response& res) {
- try {
- json response = {
- {"status", "healthy"},
- {"timestamp", std::chrono::duration_cast<std::chrono::seconds>(
- std::chrono::system_clock::now().time_since_epoch()).count()},
- {"version", "1.0.0"}
- };
- sendJsonResponse(res, response);
- } catch (const std::exception& e) {
- sendErrorResponse(res, std::string("Health check failed: ") + e.what(), 500);
- }
- }
- void Server::handleApiStatus(const httplib::Request& req, httplib::Response& res) {
- try {
- json response = {
- {"server", {
- {"running", m_isRunning.load()},
- {"host", m_host},
- {"port", m_port}
- }},
- {"generation_queue", {
- {"running", m_generationQueue ? m_generationQueue->isRunning() : false},
- {"queue_size", m_generationQueue ? m_generationQueue->getQueueSize() : 0},
- {"active_generations", m_generationQueue ? m_generationQueue->getActiveGenerations() : 0}
- }},
- {"models", {
- {"loaded_count", m_modelManager ? m_modelManager->getLoadedModelsCount() : 0},
- {"available_count", m_modelManager ? m_modelManager->getAvailableModelsCount() : 0}
- }}
- };
- sendJsonResponse(res, response);
- } catch (const std::exception& e) {
- sendErrorResponse(res, std::string("Status check failed: ") + e.what(), 500);
- }
- }
- void Server::handleModelsList(const httplib::Request& req, httplib::Response& res) {
- std::string requestId = generateRequestId();
- try {
- if (!m_modelManager) {
- sendErrorResponse(res, "Model manager not available", 500, "MODEL_MANAGER_UNAVAILABLE", requestId);
- return;
- }
- // Parse query parameters for enhanced filtering
- std::string typeFilter = req.get_param_value("type");
- std::string searchQuery = req.get_param_value("search");
- std::string sortBy = req.get_param_value("sort_by");
- std::string sortOrder = req.get_param_value("sort_order");
- std::string dateFilter = req.get_param_value("date");
- std::string sizeFilter = req.get_param_value("size");
- // Pagination parameters
- int page = 1;
- int limit = 50;
- try {
- if (!req.get_param_value("page").empty()) {
- page = std::stoi(req.get_param_value("page"));
- if (page < 1) page = 1;
- }
- if (!req.get_param_value("limit").empty()) {
- limit = std::stoi(req.get_param_value("limit"));
- if (limit < 1) limit = 1;
- if (limit > 200) limit = 200; // Max limit to prevent performance issues
- }
- } catch (const std::exception& e) {
- sendErrorResponse(res, "Invalid pagination parameters", 400, "INVALID_PAGINATION", requestId);
- return;
- }
- // Filter parameters
- bool includeLoaded = req.get_param_value("loaded") == "true";
- bool includeUnloaded = req.get_param_value("unloaded") == "true";
- bool includeMetadata = req.get_param_value("include_metadata") == "true";
- bool includeThumbnails = req.get_param_value("include_thumbnails") == "true";
- // Get all models
- auto allModels = m_modelManager->getAllModels();
- json models = json::array();
- // Apply filters and build response
- for (const auto& pair : allModels) {
- const auto& modelInfo = pair.second;
- // Apply type filter
- if (!typeFilter.empty()) {
- ModelType filterType = ModelManager::stringToModelType(typeFilter);
- if (modelInfo.type != filterType) continue;
- }
- // Apply loaded/unloaded filters
- if (includeLoaded && !modelInfo.isLoaded) continue;
- if (includeUnloaded && modelInfo.isLoaded) continue;
- // Apply search filter (case-insensitive search in name and description)
- if (!searchQuery.empty()) {
- std::string searchLower = searchQuery;
- std::transform(searchLower.begin(), searchLower.end(), searchLower.begin(), ::tolower);
- std::string nameLower = modelInfo.name;
- std::transform(nameLower.begin(), nameLower.end(), nameLower.begin(), ::tolower);
- std::string descLower = modelInfo.description;
- std::transform(descLower.begin(), descLower.end(), descLower.begin(), ::tolower);
- if (nameLower.find(searchLower) == std::string::npos &&
- descLower.find(searchLower) == std::string::npos) {
- continue;
- }
- }
- // Apply date filter (simplified - expects "recent", "old", or YYYY-MM-DD)
- if (!dateFilter.empty()) {
- auto now = std::filesystem::file_time_type::clock::now();
- auto modelTime = modelInfo.modifiedAt;
- auto duration = std::chrono::duration_cast<std::chrono::hours>(now - modelTime).count();
- if (dateFilter == "recent" && duration > 24 * 7) continue; // Older than 1 week
- if (dateFilter == "old" && duration < 24 * 30) continue; // Newer than 1 month
- }
- // Apply size filter (expects "small", "medium", "large", or size in MB)
- if (!sizeFilter.empty()) {
- double sizeMB = modelInfo.fileSize / (1024.0 * 1024.0);
- if (sizeFilter == "small" && sizeMB > 1024) continue; // > 1GB
- if (sizeFilter == "medium" && (sizeMB < 1024 || sizeMB > 4096)) continue; // < 1GB or > 4GB
- if (sizeFilter == "large" && sizeMB < 4096) continue; // < 4GB
- // Try to parse as specific size in MB
- try {
- double maxSizeMB = std::stod(sizeFilter);
- if (sizeMB > maxSizeMB) continue;
- } catch (...) {
- // Ignore if parsing fails
- }
- }
- // Build model JSON with only essential information
- json modelJson = {
- {"name", modelInfo.name},
- {"type", ModelManager::modelTypeToString(modelInfo.type)},
- {"file_size", modelInfo.fileSize},
- {"file_size_mb", modelInfo.fileSize / (1024.0 * 1024.0)},
- {"sha256", modelInfo.sha256.empty() ? nullptr : json(modelInfo.sha256)},
- {"sha256_short", (modelInfo.sha256.empty() || modelInfo.sha256.length() < 10) ? nullptr : json(modelInfo.sha256.substr(0, 10))}
- };
- // Add architecture information if available (checkpoints only)
- if (!modelInfo.architecture.empty()) {
- modelJson["architecture"] = modelInfo.architecture;
- modelJson["recommended_vae"] = modelInfo.recommendedVAE.empty() ? nullptr : json(modelInfo.recommendedVAE);
- if (modelInfo.recommendedWidth > 0) {
- modelJson["recommended_width"] = modelInfo.recommendedWidth;
- }
- if (modelInfo.recommendedHeight > 0) {
- modelJson["recommended_height"] = modelInfo.recommendedHeight;
- }
- if (modelInfo.recommendedSteps > 0) {
- modelJson["recommended_steps"] = modelInfo.recommendedSteps;
- }
- if (!modelInfo.recommendedSampler.empty()) {
- modelJson["recommended_sampler"] = modelInfo.recommendedSampler;
- }
- if (!modelInfo.requiredModels.empty()) {
- modelJson["required_models"] = modelInfo.requiredModels;
- }
- if (!modelInfo.missingModels.empty()) {
- modelJson["missing_models"] = modelInfo.missingModels;
- modelJson["has_missing_dependencies"] = true;
- } else {
- modelJson["has_missing_dependencies"] = false;
- }
- }
- models.push_back(modelJson);
- }
- // Apply sorting
- if (!sortBy.empty()) {
- std::sort(models.begin(), models.end(), [&sortBy, &sortOrder](const json& a, const json& b) {
- bool ascending = sortOrder != "desc";
- if (sortBy == "name") {
- return ascending ? a["name"] < b["name"] : a["name"] > b["name"];
- } else if (sortBy == "size") {
- return ascending ? a["file_size"] < b["file_size"] : a["file_size"] > b["file_size"];
- } else if (sortBy == "date") {
- return ascending ? a["last_modified"] < b["last_modified"] : a["last_modified"] > b["last_modified"];
- } else if (sortBy == "type") {
- return ascending ? a["type"] < b["type"] : a["type"] > b["type"];
- } else if (sortBy == "loaded") {
- return ascending ? a["is_loaded"] < b["is_loaded"] : a["is_loaded"] > b["is_loaded"];
- }
- return false;
- });
- }
- // Apply pagination
- int totalCount = models.size();
- int totalPages = (totalCount + limit - 1) / limit;
- int startIndex = (page - 1) * limit;
- int endIndex = std::min(startIndex + limit, totalCount);
- json paginatedModels = json::array();
- for (int i = startIndex; i < endIndex; ++i) {
- paginatedModels.push_back(models[i]);
- }
- // Build comprehensive response
- json response = {
- {"models", paginatedModels},
- {"pagination", {
- {"page", page},
- {"limit", limit},
- {"total_count", totalCount},
- {"total_pages", totalPages},
- {"has_next", page < totalPages},
- {"has_prev", page > 1}
- }},
- {"filters_applied", {
- {"type", typeFilter.empty() ? json(nullptr) : json(typeFilter)},
- {"search", searchQuery.empty() ? json(nullptr) : json(searchQuery)},
- {"date", dateFilter.empty() ? json(nullptr) : json(dateFilter)},
- {"size", sizeFilter.empty() ? json(nullptr) : json(sizeFilter)},
- {"loaded", includeLoaded ? json(true) : json(nullptr)},
- {"unloaded", includeUnloaded ? json(true) : json(nullptr)}
- }},
- {"sorting", {
- {"sort_by", sortBy.empty() ? "name" : json(sortBy)},
- {"sort_order", sortOrder.empty() ? "asc" : json(sortOrder)}
- }},
- {"statistics", {
- {"loaded_count", m_modelManager->getLoadedModelsCount()},
- {"available_count", m_modelManager->getAvailableModelsCount()}
- }},
- {"request_id", requestId}
- };
- sendJsonResponse(res, response);
- } catch (const std::exception& e) {
- sendErrorResponse(res, std::string("Failed to list models: ") + e.what(), 500, "MODEL_LIST_ERROR", requestId);
- }
- }
- void Server::handleQueueStatus(const httplib::Request& req, httplib::Response& res) {
- try {
- if (!m_generationQueue) {
- sendErrorResponse(res, "Generation queue not available", 500);
- return;
- }
- // Get detailed queue status
- auto jobs = m_generationQueue->getQueueStatus();
- // Convert jobs to JSON
- json jobsJson = json::array();
- for (const auto& job : jobs) {
- std::string statusStr;
- switch (job.status) {
- case GenerationStatus::QUEUED: statusStr = "queued"; break;
- case GenerationStatus::PROCESSING: statusStr = "processing"; break;
- case GenerationStatus::COMPLETED: statusStr = "completed"; break;
- case GenerationStatus::FAILED: statusStr = "failed"; break;
- }
- // Convert time points to timestamps
- auto queuedTime = std::chrono::duration_cast<std::chrono::milliseconds>(
- job.queuedTime.time_since_epoch()).count();
- auto startTime = std::chrono::duration_cast<std::chrono::milliseconds>(
- job.startTime.time_since_epoch()).count();
- auto endTime = std::chrono::duration_cast<std::chrono::milliseconds>(
- job.endTime.time_since_epoch()).count();
- jobsJson.push_back({
- {"id", job.id},
- {"status", statusStr},
- {"prompt", job.prompt},
- {"queued_time", queuedTime},
- {"start_time", startTime > 0 ? json(startTime) : json(nullptr)},
- {"end_time", endTime > 0 ? json(endTime) : json(nullptr)},
- {"position", job.position},
- {"progress", job.progress}
- });
- }
- json response = {
- {"queue", {
- {"size", m_generationQueue->getQueueSize()},
- {"active_generations", m_generationQueue->getActiveGenerations()},
- {"running", m_generationQueue->isRunning()},
- {"jobs", jobsJson}
- }}
- };
- sendJsonResponse(res, response);
- } catch (const std::exception& e) {
- sendErrorResponse(res, std::string("Queue status check failed: ") + e.what(), 500);
- }
- }
- void Server::handleJobStatus(const httplib::Request& req, httplib::Response& res) {
- try {
- if (!m_generationQueue) {
- sendErrorResponse(res, "Generation queue not available", 500);
- return;
- }
- // Extract job ID from URL path
- std::string jobId = req.matches[1].str();
- if (jobId.empty()) {
- sendErrorResponse(res, "Missing job ID", 400);
- return;
- }
- // Get job information
- auto jobInfo = m_generationQueue->getJobInfo(jobId);
- if (jobInfo.id.empty()) {
- sendErrorResponse(res, "Job not found", 404);
- return;
- }
- // Convert status to string
- std::string statusStr;
- switch (jobInfo.status) {
- case GenerationStatus::QUEUED: statusStr = "queued"; break;
- case GenerationStatus::PROCESSING: statusStr = "processing"; break;
- case GenerationStatus::COMPLETED: statusStr = "completed"; break;
- case GenerationStatus::FAILED: statusStr = "failed"; break;
- }
- // Convert time points to timestamps
- auto queuedTime = std::chrono::duration_cast<std::chrono::milliseconds>(
- jobInfo.queuedTime.time_since_epoch()).count();
- auto startTime = std::chrono::duration_cast<std::chrono::milliseconds>(
- jobInfo.startTime.time_since_epoch()).count();
- auto endTime = std::chrono::duration_cast<std::chrono::milliseconds>(
- jobInfo.endTime.time_since_epoch()).count();
- // Create download URLs for output files
- json outputUrls = json::array();
- for (const auto& filePath : jobInfo.outputFiles) {
- // Extract filename from full path
- std::filesystem::path p(filePath);
- std::string filename = p.filename().string();
- // Create download URL
- std::string url = "/api/queue/job/" + jobInfo.id + "/output/" + filename;
- json fileInfo = {
- {"filename", filename},
- {"url", url},
- {"path", filePath}
- };
- outputUrls.push_back(fileInfo);
- }
- json response = {
- {"job", {
- {"id", jobInfo.id},
- {"status", statusStr},
- {"prompt", jobInfo.prompt},
- {"queued_time", queuedTime},
- {"start_time", startTime > 0 ? json(startTime) : json(nullptr)},
- {"end_time", endTime > 0 ? json(endTime) : json(nullptr)},
- {"position", jobInfo.position},
- {"outputs", outputUrls},
- {"error_message", jobInfo.errorMessage},
- {"progress", jobInfo.progress}
- }}
- };
- sendJsonResponse(res, response);
- } catch (const std::exception& e) {
- sendErrorResponse(res, std::string("Job status check failed: ") + e.what(), 500);
- }
- }
- void Server::handleCancelJob(const httplib::Request& req, httplib::Response& res) {
- try {
- if (!m_generationQueue) {
- sendErrorResponse(res, "Generation queue not available", 500);
- return;
- }
- // Parse JSON request body
- json requestJson = json::parse(req.body);
- // Validate required fields
- if (!requestJson.contains("job_id") || !requestJson["job_id"].is_string()) {
- sendErrorResponse(res, "Missing or invalid 'job_id' field", 400);
- return;
- }
- std::string jobId = requestJson["job_id"];
- // Try to cancel the job
- bool cancelled = m_generationQueue->cancelJob(jobId);
- if (cancelled) {
- json response = {
- {"status", "success"},
- {"message", "Job cancelled successfully"},
- {"job_id", jobId}
- };
- sendJsonResponse(res, response);
- } else {
- json response = {
- {"status", "error"},
- {"message", "Job not found or already processing"},
- {"job_id", jobId}
- };
- sendJsonResponse(res, response, 404);
- }
- } catch (const json::parse_error& e) {
- sendErrorResponse(res, std::string("Invalid JSON: ") + e.what(), 400);
- } catch (const std::exception& e) {
- sendErrorResponse(res, std::string("Job cancellation failed: ") + e.what(), 500);
- }
- }
- void Server::handleClearQueue(const httplib::Request& req, httplib::Response& res) {
- try {
- if (!m_generationQueue) {
- sendErrorResponse(res, "Generation queue not available", 500);
- return;
- }
- // Clear the queue
- m_generationQueue->clearQueue();
- json response = {
- {"status", "success"},
- {"message", "Queue cleared successfully"}
- };
- sendJsonResponse(res, response);
- } catch (const std::exception& e) {
- sendErrorResponse(res, std::string("Queue clear failed: ") + e.what(), 500);
- }
- }
- void Server::handleDownloadOutput(const httplib::Request& req, httplib::Response& res) {
- try {
- // Extract job ID and filename from URL path
- if (req.matches.size() < 3) {
- sendErrorResponse(res, "Invalid request: job ID and filename required", 400);
- return;
- }
- std::string jobId = req.matches[1];
- std::string filename = req.matches[2];
- // Construct file path using the same logic as when saving:
- // {outputDir}/{jobId}/{filename}
- std::string fullPath = m_outputDir + "/" + jobId + "/" + filename;
- // Check if file exists
- if (!std::filesystem::exists(fullPath)) {
- sendErrorResponse(res, "Output file not found: " + fullPath, 404);
- return;
- }
- // Check if file exists on filesystem
- std::ifstream file(fullPath, std::ios::binary);
- if (!file.is_open()) {
- sendErrorResponse(res, "Output file not accessible", 404);
- return;
- }
- // Read file contents
- std::ostringstream fileContent;
- fileContent << file.rdbuf();
- file.close();
- // Determine content type based on file extension
- std::string contentType = "application/octet-stream";
- if (Utils::endsWith(filename, ".png")) {
- contentType = "image/png";
- } else if (Utils::endsWith(filename, ".jpg") || Utils::endsWith(filename, ".jpeg")) {
- contentType = "image/jpeg";
- } else if (Utils::endsWith(filename, ".mp4")) {
- contentType = "video/mp4";
- } else if (Utils::endsWith(filename, ".gif")) {
- contentType = "image/gif";
- }
- // Set response headers
- res.set_header("Content-Type", contentType);
- //res.set_header("Content-Disposition", "attachment; filename=\"" + filename + "\"");
- res.set_content(fileContent.str(), contentType);
- res.status = 200;
- } catch (const std::exception& e) {
- sendErrorResponse(res, std::string("Failed to download file: ") + e.what(), 500);
- }
- }
- void Server::sendJsonResponse(httplib::Response& res, const nlohmann::json& json, int status_code) {
- res.set_header("Content-Type", "application/json");
- res.status = status_code;
- res.body = json.dump();
- }
- void Server::sendErrorResponse(httplib::Response& res, const std::string& message, int status_code,
- const std::string& error_code, const std::string& request_id) {
- json errorResponse = {
- {"error", {
- {"message", message},
- {"status_code", status_code},
- {"error_code", error_code},
- {"request_id", request_id},
- {"timestamp", std::chrono::duration_cast<std::chrono::seconds>(
- std::chrono::system_clock::now().time_since_epoch()).count()}
- }}
- };
- sendJsonResponse(res, errorResponse, status_code);
- }
- std::pair<bool, std::string> Server::validateGenerationParameters(const nlohmann::json& params) {
- // Validate required fields
- if (!params.contains("prompt") || !params["prompt"].is_string()) {
- return {false, "Missing or invalid 'prompt' field"};
- }
- const std::string& prompt = params["prompt"];
- if (prompt.empty()) {
- return {false, "Prompt cannot be empty"};
- }
- if (prompt.length() > 10000) {
- return {false, "Prompt too long (max 10000 characters)"};
- }
- // Validate negative prompt if present
- if (params.contains("negative_prompt")) {
- if (!params["negative_prompt"].is_string()) {
- return {false, "Invalid 'negative_prompt' field, must be string"};
- }
- if (params["negative_prompt"].get<std::string>().length() > 10000) {
- return {false, "Negative prompt too long (max 10000 characters)"};
- }
- }
- // Validate width
- if (params.contains("width")) {
- if (!params["width"].is_number_integer()) {
- return {false, "Invalid 'width' field, must be integer"};
- }
- int width = params["width"];
- if (width < 64 || width > 2048 || width % 64 != 0) {
- return {false, "Width must be between 64 and 2048 and divisible by 64"};
- }
- }
- // Validate height
- if (params.contains("height")) {
- if (!params["height"].is_number_integer()) {
- return {false, "Invalid 'height' field, must be integer"};
- }
- int height = params["height"];
- if (height < 64 || height > 2048 || height % 64 != 0) {
- return {false, "Height must be between 64 and 2048 and divisible by 64"};
- }
- }
- // Validate batch count
- if (params.contains("batch_count")) {
- if (!params["batch_count"].is_number_integer()) {
- return {false, "Invalid 'batch_count' field, must be integer"};
- }
- int batchCount = params["batch_count"];
- if (batchCount < 1 || batchCount > 100) {
- return {false, "Batch count must be between 1 and 100"};
- }
- }
- // Validate steps
- if (params.contains("steps")) {
- if (!params["steps"].is_number_integer()) {
- return {false, "Invalid 'steps' field, must be integer"};
- }
- int steps = params["steps"];
- if (steps < 1 || steps > 150) {
- return {false, "Steps must be between 1 and 150"};
- }
- }
- // Validate CFG scale
- if (params.contains("cfg_scale")) {
- if (!params["cfg_scale"].is_number()) {
- return {false, "Invalid 'cfg_scale' field, must be number"};
- }
- float cfgScale = params["cfg_scale"];
- if (cfgScale < 1.0f || cfgScale > 30.0f) {
- return {false, "CFG scale must be between 1.0 and 30.0"};
- }
- }
- // Validate seed
- if (params.contains("seed")) {
- if (!params["seed"].is_string() && !params["seed"].is_number_integer()) {
- return {false, "Invalid 'seed' field, must be string or integer"};
- }
- }
- // Validate sampling method
- if (params.contains("sampling_method")) {
- if (!params["sampling_method"].is_string()) {
- return {false, "Invalid 'sampling_method' field, must be string"};
- }
- std::string method = params["sampling_method"];
- std::vector<std::string> validMethods = {
- "euler", "euler_a", "heun", "dpm2", "dpm++2s_a", "dpm++2m",
- "dpm++2mv2", "ipndm", "ipndm_v", "lcm", "ddim_trailing", "tcd", "default"
- };
- if (std::find(validMethods.begin(), validMethods.end(), method) == validMethods.end()) {
- return {false, "Invalid sampling method"};
- }
- }
- // Validate scheduler
- if (params.contains("scheduler")) {
- if (!params["scheduler"].is_string()) {
- return {false, "Invalid 'scheduler' field, must be string"};
- }
- std::string scheduler = params["scheduler"];
- std::vector<std::string> validSchedulers = {
- "discrete", "karras", "exponential", "ays", "gits",
- "smoothstep", "sgm_uniform", "simple", "default"
- };
- if (std::find(validSchedulers.begin(), validSchedulers.end(), scheduler) == validSchedulers.end()) {
- return {false, "Invalid scheduler"};
- }
- }
- // Validate strength
- if (params.contains("strength")) {
- if (!params["strength"].is_number()) {
- return {false, "Invalid 'strength' field, must be number"};
- }
- float strength = params["strength"];
- if (strength < 0.0f || strength > 1.0f) {
- return {false, "Strength must be between 0.0 and 1.0"};
- }
- }
- // Validate control strength
- if (params.contains("control_strength")) {
- if (!params["control_strength"].is_number()) {
- return {false, "Invalid 'control_strength' field, must be number"};
- }
- float controlStrength = params["control_strength"];
- if (controlStrength < 0.0f || controlStrength > 1.0f) {
- return {false, "Control strength must be between 0.0 and 1.0"};
- }
- }
- // Validate clip skip
- if (params.contains("clip_skip")) {
- if (!params["clip_skip"].is_number_integer()) {
- return {false, "Invalid 'clip_skip' field, must be integer"};
- }
- int clipSkip = params["clip_skip"];
- if (clipSkip < -1 || clipSkip > 12) {
- return {false, "Clip skip must be between -1 and 12"};
- }
- }
- // Validate threads
- if (params.contains("threads")) {
- if (!params["threads"].is_number_integer()) {
- return {false, "Invalid 'threads' field, must be integer"};
- }
- int threads = params["threads"];
- if (threads < -1 || threads > 32) {
- return {false, "Threads must be between -1 (auto) and 32"};
- }
- }
- return {true, ""};
- }
- SamplingMethod Server::parseSamplingMethod(const std::string& method) {
- if (method == "euler") return SamplingMethod::EULER;
- else if (method == "euler_a") return SamplingMethod::EULER_A;
- else if (method == "heun") return SamplingMethod::HEUN;
- else if (method == "dpm2") return SamplingMethod::DPM2;
- else if (method == "dpm++2s_a") return SamplingMethod::DPMPP2S_A;
- else if (method == "dpm++2m") return SamplingMethod::DPMPP2M;
- else if (method == "dpm++2mv2") return SamplingMethod::DPMPP2MV2;
- else if (method == "ipndm") return SamplingMethod::IPNDM;
- else if (method == "ipndm_v") return SamplingMethod::IPNDM_V;
- else if (method == "lcm") return SamplingMethod::LCM;
- else if (method == "ddim_trailing") return SamplingMethod::DDIM_TRAILING;
- else if (method == "tcd") return SamplingMethod::TCD;
- else return SamplingMethod::DEFAULT;
- }
- Scheduler Server::parseScheduler(const std::string& scheduler) {
- if (scheduler == "discrete") return Scheduler::DISCRETE;
- else if (scheduler == "karras") return Scheduler::KARRAS;
- else if (scheduler == "exponential") return Scheduler::EXPONENTIAL;
- else if (scheduler == "ays") return Scheduler::AYS;
- else if (scheduler == "gits") return Scheduler::GITS;
- else if (scheduler == "smoothstep") return Scheduler::SMOOTHSTEP;
- else if (scheduler == "sgm_uniform") return Scheduler::SGM_UNIFORM;
- else if (scheduler == "simple") return Scheduler::SIMPLE;
- else return Scheduler::DEFAULT;
- }
- std::string Server::generateRequestId() {
- std::random_device rd;
- std::mt19937 gen(rd());
- std::uniform_int_distribution<> dis(100000, 999999);
- return "req_" + std::to_string(dis(gen));
- }
- std::tuple<std::vector<uint8_t>, int, int, int, bool, std::string>
- Server::loadImageFromInput(const std::string& input) {
- std::vector<uint8_t> imageData;
- int width = 0, height = 0, channels = 0;
- // Auto-detect input source type
- // 1. Check if input is a URL (starts with http:// or https://)
- if (Utils::startsWith(input, "http://") || Utils::startsWith(input, "https://")) {
- // Parse URL to extract host and path
- std::string url = input;
- std::string scheme, host, path;
- int port = 80;
- // Determine scheme and port
- if (Utils::startsWith(url, "https://")) {
- scheme = "https";
- port = 443;
- url = url.substr(8); // Remove "https://"
- } else {
- scheme = "http";
- port = 80;
- url = url.substr(7); // Remove "http://"
- }
- // Extract host and path
- size_t slashPos = url.find('/');
- if (slashPos != std::string::npos) {
- host = url.substr(0, slashPos);
- path = url.substr(slashPos);
- } else {
- host = url;
- path = "/";
- }
- // Check for custom port
- size_t colonPos = host.find(':');
- if (colonPos != std::string::npos) {
- try {
- port = std::stoi(host.substr(colonPos + 1));
- host = host.substr(0, colonPos);
- } catch (...) {
- return {imageData, 0, 0, 0, false, "Invalid port in URL"};
- }
- }
- // Download image using httplib
- try {
- httplib::Result res;
- if (scheme == "https") {
- #ifdef CPPHTTPLIB_OPENSSL_SUPPORT
- httplib::SSLClient client(host, port);
- client.set_follow_location(true);
- client.set_connection_timeout(30, 0); // 30 seconds
- client.set_read_timeout(60, 0); // 60 seconds
- res = client.Get(path.c_str());
- #else
- return {imageData, 0, 0, 0, false, "HTTPS not supported (OpenSSL not available)"};
- #endif
- } else {
- httplib::Client client(host, port);
- client.set_follow_location(true);
- client.set_connection_timeout(30, 0); // 30 seconds
- client.set_read_timeout(60, 0); // 60 seconds
- res = client.Get(path.c_str());
- }
- if (!res) {
- return {imageData, 0, 0, 0, false, "Failed to download image from URL: Connection error"};
- }
- if (res->status != 200) {
- return {imageData, 0, 0, 0, false, "Failed to download image from URL: HTTP " + std::to_string(res->status)};
- }
- // Convert response body to vector
- std::vector<uint8_t> downloadedData(res->body.begin(), res->body.end());
- // Load image from memory
- int w, h, c;
- unsigned char* pixels = stbi_load_from_memory(
- downloadedData.data(),
- downloadedData.size(),
- &w, &h, &c,
- 3 // Force RGB
- );
- if (!pixels) {
- return {imageData, 0, 0, 0, false, "Failed to decode image from URL"};
- }
- width = w;
- height = h;
- channels = 3;
- size_t dataSize = width * height * channels;
- imageData.resize(dataSize);
- std::memcpy(imageData.data(), pixels, dataSize);
- stbi_image_free(pixels);
- } catch (const std::exception& e) {
- return {imageData, 0, 0, 0, false, "Failed to download image from URL: " + std::string(e.what())};
- }
- }
- // 2. Check if input is base64 encoded data URI (starts with "data:image")
- else if (Utils::startsWith(input, "data:image")) {
- // Extract base64 data after the comma
- size_t commaPos = input.find(',');
- if (commaPos == std::string::npos) {
- return {imageData, 0, 0, 0, false, "Invalid data URI format"};
- }
- std::string base64Data = input.substr(commaPos + 1);
- std::vector<uint8_t> decodedData = Utils::base64Decode(base64Data);
- // Load image from memory using stb_image
- int w, h, c;
- unsigned char* pixels = stbi_load_from_memory(
- decodedData.data(),
- decodedData.size(),
- &w, &h, &c,
- 3 // Force RGB
- );
- if (!pixels) {
- return {imageData, 0, 0, 0, false, "Failed to decode image from base64 data URI"};
- }
- width = w;
- height = h;
- channels = 3; // We forced RGB
- // Copy pixel data
- size_t dataSize = width * height * channels;
- imageData.resize(dataSize);
- std::memcpy(imageData.data(), pixels, dataSize);
- stbi_image_free(pixels);
- }
- // 3. Check if input is raw base64 (long string without slashes, likely base64)
- else if (input.length() > 100 && input.find('/') == std::string::npos && input.find('.') == std::string::npos) {
- // Likely raw base64 without data URI prefix
- std::vector<uint8_t> decodedData = Utils::base64Decode(input);
- int w, h, c;
- unsigned char* pixels = stbi_load_from_memory(
- decodedData.data(),
- decodedData.size(),
- &w, &h, &c,
- 3 // Force RGB
- );
- if (!pixels) {
- return {imageData, 0, 0, 0, false, "Failed to decode image from base64"};
- }
- width = w;
- height = h;
- channels = 3;
- size_t dataSize = width * height * channels;
- imageData.resize(dataSize);
- std::memcpy(imageData.data(), pixels, dataSize);
- stbi_image_free(pixels);
- }
- // 4. Treat as local file path
- else {
- int w, h, c;
- unsigned char* pixels = stbi_load(input.c_str(), &w, &h, &c, 3);
- if (!pixels) {
- return {imageData, 0, 0, 0, false, "Failed to load image from file: " + input};
- }
- width = w;
- height = h;
- channels = 3;
- size_t dataSize = width * height * channels;
- imageData.resize(dataSize);
- std::memcpy(imageData.data(), pixels, dataSize);
- stbi_image_free(pixels);
- }
- return {imageData, width, height, channels, true, ""};
- }
- std::string Server::samplingMethodToString(SamplingMethod method) {
- switch (method) {
- case SamplingMethod::EULER: return "euler";
- case SamplingMethod::EULER_A: return "euler_a";
- case SamplingMethod::HEUN: return "heun";
- case SamplingMethod::DPM2: return "dpm2";
- case SamplingMethod::DPMPP2S_A: return "dpm++2s_a";
- case SamplingMethod::DPMPP2M: return "dpm++2m";
- case SamplingMethod::DPMPP2MV2: return "dpm++2mv2";
- case SamplingMethod::IPNDM: return "ipndm";
- case SamplingMethod::IPNDM_V: return "ipndm_v";
- case SamplingMethod::LCM: return "lcm";
- case SamplingMethod::DDIM_TRAILING: return "ddim_trailing";
- case SamplingMethod::TCD: return "tcd";
- default: return "default";
- }
- }
- std::string Server::schedulerToString(Scheduler scheduler) {
- switch (scheduler) {
- case Scheduler::DISCRETE: return "discrete";
- case Scheduler::KARRAS: return "karras";
- case Scheduler::EXPONENTIAL: return "exponential";
- case Scheduler::AYS: return "ays";
- case Scheduler::GITS: return "gits";
- case Scheduler::SMOOTHSTEP: return "smoothstep";
- case Scheduler::SGM_UNIFORM: return "sgm_uniform";
- case Scheduler::SIMPLE: return "simple";
- default: return "default";
- }
- }
- uint64_t Server::estimateGenerationTime(const GenerationRequest& request) {
- // Basic estimation based on parameters
- uint64_t baseTime = 1000; // 1 second base time
- // Factor in steps
- baseTime *= request.steps;
- // Factor in resolution
- double resolutionFactor = (request.width * request.height) / (512.0 * 512.0);
- baseTime = static_cast<uint64_t>(baseTime * resolutionFactor);
- // Factor in batch count
- baseTime *= request.batchCount;
- // Adjust for sampling method (some are faster than others)
- switch (request.samplingMethod) {
- case SamplingMethod::LCM:
- baseTime /= 4; // LCM is much faster
- break;
- case SamplingMethod::EULER:
- case SamplingMethod::EULER_A:
- baseTime *= 0.8; // Euler methods are faster
- break;
- case SamplingMethod::DPM2:
- case SamplingMethod::DPMPP2S_A:
- baseTime *= 1.2; // DPM methods are slower
- break;
- default:
- break;
- }
- return baseTime;
- }
- size_t Server::estimateMemoryUsage(const GenerationRequest& request) {
- // Basic memory estimation in bytes
- size_t baseMemory = 1024 * 1024 * 1024; // 1GB base
- // Factor in resolution
- double resolutionFactor = (request.width * request.height) / (512.0 * 512.0);
- baseMemory = static_cast<size_t>(baseMemory * resolutionFactor);
- // Factor in batch count
- baseMemory *= request.batchCount;
- // Additional memory for certain features
- if (request.diffusionFlashAttn) {
- baseMemory += 512 * 1024 * 1024; // Extra 512MB for flash attention
- }
- if (!request.controlNetPath.empty()) {
- baseMemory += 1024 * 1024 * 1024; // Extra 1GB for ControlNet
- }
- return baseMemory;
- }
- // Specialized generation endpoints
- void Server::handleText2Img(const httplib::Request& req, httplib::Response& res) {
- std::string requestId = generateRequestId();
- try {
- if (!m_generationQueue) {
- sendErrorResponse(res, "Generation queue not available", 500, "QUEUE_UNAVAILABLE", requestId);
- return;
- }
- json requestJson = json::parse(req.body);
- // Validate required fields for text2img
- if (!requestJson.contains("prompt") || !requestJson["prompt"].is_string()) {
- sendErrorResponse(res, "Missing or invalid 'prompt' field", 400, "INVALID_PARAMETERS", requestId);
- return;
- }
- // Validate all parameters
- auto [isValid, errorMessage] = validateGenerationParameters(requestJson);
- if (!isValid) {
- sendErrorResponse(res, errorMessage, 400, "INVALID_PARAMETERS", requestId);
- return;
- }
- // Check if any model is loaded
- if (!m_modelManager) {
- sendErrorResponse(res, "Model manager not available", 500, "MODEL_MANAGER_UNAVAILABLE", requestId);
- return;
- }
- // Get currently loaded checkpoint model
- auto allModels = m_modelManager->getAllModels();
- std::string loadedModelName;
- for (const auto& [modelName, modelInfo] : allModels) {
- if (modelInfo.type == ModelType::CHECKPOINT && modelInfo.isLoaded) {
- loadedModelName = modelName;
- break;
- }
- }
- if (loadedModelName.empty()) {
- sendErrorResponse(res, "No checkpoint model loaded. Please load a checkpoint model first using POST /api/models/{hash}/load", 400, "NO_CHECKPOINT_LOADED", requestId);
- return;
- }
- // Create generation request specifically for text2img
- GenerationRequest genRequest;
- genRequest.id = requestId;
- genRequest.modelName = loadedModelName; // Use the currently loaded model
- genRequest.prompt = requestJson["prompt"];
- genRequest.negativePrompt = requestJson.value("negative_prompt", "");
- genRequest.width = requestJson.value("width", 512);
- genRequest.height = requestJson.value("height", 512);
- genRequest.batchCount = requestJson.value("batch_count", 1);
- genRequest.steps = requestJson.value("steps", 20);
- genRequest.cfgScale = requestJson.value("cfg_scale", 7.5f);
- genRequest.seed = requestJson.value("seed", "random");
- // Parse optional parameters
- if (requestJson.contains("sampling_method")) {
- genRequest.samplingMethod = parseSamplingMethod(requestJson["sampling_method"]);
- }
- if (requestJson.contains("scheduler")) {
- genRequest.scheduler = parseScheduler(requestJson["scheduler"]);
- }
- // Set text2img specific defaults
- genRequest.strength = 1.0f; // Full strength for text2img
- // Optional VAE model
- if (requestJson.contains("vae_model") && requestJson["vae_model"].is_string()) {
- std::string vaeModelId = requestJson["vae_model"];
- if (!vaeModelId.empty()) {
- auto vaeInfo = m_modelManager->getModelInfo(vaeModelId);
- if (!vaeInfo.name.empty() && vaeInfo.type == ModelType::VAE) {
- genRequest.vaePath = vaeInfo.path;
- } else {
- sendErrorResponse(res, "VAE model not found or invalid: " + vaeModelId, 400, "INVALID_VAE_MODEL", requestId);
- return;
- }
- }
- }
- // Optional TAESD model
- if (requestJson.contains("taesd_model") && requestJson["taesd_model"].is_string()) {
- std::string taesdModelId = requestJson["taesd_model"];
- if (!taesdModelId.empty()) {
- auto taesdInfo = m_modelManager->getModelInfo(taesdModelId);
- if (!taesdInfo.name.empty() && taesdInfo.type == ModelType::TAESD) {
- genRequest.taesdPath = taesdInfo.path;
- } else {
- sendErrorResponse(res, "TAESD model not found or invalid: " + taesdModelId, 400, "INVALID_TAESD_MODEL", requestId);
- return;
- }
- }
- }
- // Enqueue request
- auto future = m_generationQueue->enqueueRequest(genRequest);
- json params = {
- {"prompt", genRequest.prompt},
- {"negative_prompt", genRequest.negativePrompt},
- {"model", genRequest.modelName},
- {"width", genRequest.width},
- {"height", genRequest.height},
- {"batch_count", genRequest.batchCount},
- {"steps", genRequest.steps},
- {"cfg_scale", genRequest.cfgScale},
- {"seed", genRequest.seed},
- {"sampling_method", samplingMethodToString(genRequest.samplingMethod)},
- {"scheduler", schedulerToString(genRequest.scheduler)}
- };
- // Add VAE/TAESD if specified
- if (!genRequest.vaePath.empty()) {
- params["vae_model"] = requestJson.value("vae_model", "");
- }
- if (!genRequest.taesdPath.empty()) {
- params["taesd_model"] = requestJson.value("taesd_model", "");
- }
- json response = {
- {"request_id", requestId},
- {"status", "queued"},
- {"message", "Text-to-image generation request queued successfully"},
- {"queue_position", m_generationQueue->getQueueSize()},
- {"estimated_time_seconds", estimateGenerationTime(genRequest) / 1000},
- {"estimated_memory_mb", estimateMemoryUsage(genRequest) / (1024 * 1024)},
- {"type", "text2img"},
- {"parameters", params}
- };
- sendJsonResponse(res, response, 202);
- } catch (const json::parse_error& e) {
- sendErrorResponse(res, std::string("Invalid JSON: ") + e.what(), 400, "JSON_PARSE_ERROR", requestId);
- } catch (const std::exception& e) {
- sendErrorResponse(res, std::string("Text-to-image request failed: ") + e.what(), 500, "INTERNAL_ERROR", requestId);
- }
- }
- void Server::handleImg2Img(const httplib::Request& req, httplib::Response& res) {
- std::string requestId = generateRequestId();
- try {
- if (!m_generationQueue) {
- sendErrorResponse(res, "Generation queue not available", 500, "QUEUE_UNAVAILABLE", requestId);
- return;
- }
- json requestJson = json::parse(req.body);
- // Validate required fields for img2img
- if (!requestJson.contains("prompt") || !requestJson["prompt"].is_string()) {
- sendErrorResponse(res, "Missing or invalid 'prompt' field", 400, "INVALID_PARAMETERS", requestId);
- return;
- }
- if (!requestJson.contains("init_image") || !requestJson["init_image"].is_string()) {
- sendErrorResponse(res, "Missing or invalid 'init_image' field", 400, "INVALID_PARAMETERS", requestId);
- return;
- }
- // Validate all parameters
- auto [isValid, errorMessage] = validateGenerationParameters(requestJson);
- if (!isValid) {
- sendErrorResponse(res, errorMessage, 400, "INVALID_PARAMETERS", requestId);
- return;
- }
- // Check if any model is loaded
- if (!m_modelManager) {
- sendErrorResponse(res, "Model manager not available", 500, "MODEL_MANAGER_UNAVAILABLE", requestId);
- return;
- }
- // Get currently loaded checkpoint model
- auto allModels = m_modelManager->getAllModels();
- std::string loadedModelName;
- for (const auto& [modelName, modelInfo] : allModels) {
- if (modelInfo.type == ModelType::CHECKPOINT && modelInfo.isLoaded) {
- loadedModelName = modelName;
- break;
- }
- }
- if (loadedModelName.empty()) {
- sendErrorResponse(res, "No checkpoint model loaded. Please load a checkpoint model first using POST /api/models/{hash}/load", 400, "NO_CHECKPOINT_LOADED", requestId);
- return;
- }
- // Load the init image
- std::string initImageInput = requestJson["init_image"];
- auto [imageData, imgWidth, imgHeight, imgChannels, success, loadError] = loadImageFromInput(initImageInput);
- if (!success) {
- sendErrorResponse(res, "Failed to load init image: " + loadError, 400, "IMAGE_LOAD_ERROR", requestId);
- return;
- }
- // Create generation request specifically for img2img
- GenerationRequest genRequest;
- genRequest.id = requestId;
- genRequest.requestType = GenerationRequest::RequestType::IMG2IMG;
- genRequest.modelName = loadedModelName; // Use the currently loaded model
- genRequest.prompt = requestJson["prompt"];
- genRequest.negativePrompt = requestJson.value("negative_prompt", "");
- genRequest.width = requestJson.value("width", imgWidth); // Default to input image dimensions
- genRequest.height = requestJson.value("height", imgHeight);
- genRequest.batchCount = requestJson.value("batch_count", 1);
- genRequest.steps = requestJson.value("steps", 20);
- genRequest.cfgScale = requestJson.value("cfg_scale", 7.5f);
- genRequest.seed = requestJson.value("seed", "random");
- genRequest.strength = requestJson.value("strength", 0.75f);
- // Set init image data
- genRequest.initImageData = imageData;
- genRequest.initImageWidth = imgWidth;
- genRequest.initImageHeight = imgHeight;
- genRequest.initImageChannels = imgChannels;
- // Parse optional parameters
- if (requestJson.contains("sampling_method")) {
- genRequest.samplingMethod = parseSamplingMethod(requestJson["sampling_method"]);
- }
- if (requestJson.contains("scheduler")) {
- genRequest.scheduler = parseScheduler(requestJson["scheduler"]);
- }
- // Optional VAE model
- if (requestJson.contains("vae_model") && requestJson["vae_model"].is_string()) {
- std::string vaeModelId = requestJson["vae_model"];
- if (!vaeModelId.empty()) {
- auto vaeInfo = m_modelManager->getModelInfo(vaeModelId);
- if (!vaeInfo.name.empty() && vaeInfo.type == ModelType::VAE) {
- genRequest.vaePath = vaeInfo.path;
- } else {
- sendErrorResponse(res, "VAE model not found or invalid: " + vaeModelId, 400, "INVALID_VAE_MODEL", requestId);
- return;
- }
- }
- }
- // Optional TAESD model
- if (requestJson.contains("taesd_model") && requestJson["taesd_model"].is_string()) {
- std::string taesdModelId = requestJson["taesd_model"];
- if (!taesdModelId.empty()) {
- auto taesdInfo = m_modelManager->getModelInfo(taesdModelId);
- if (!taesdInfo.name.empty() && taesdInfo.type == ModelType::TAESD) {
- genRequest.taesdPath = taesdInfo.path;
- } else {
- sendErrorResponse(res, "TAESD model not found or invalid: " + taesdModelId, 400, "INVALID_TAESD_MODEL", requestId);
- return;
- }
- }
- }
- // Enqueue request
- auto future = m_generationQueue->enqueueRequest(genRequest);
- json params = {
- {"prompt", genRequest.prompt},
- {"negative_prompt", genRequest.negativePrompt},
- {"init_image", requestJson["init_image"]},
- {"model", genRequest.modelName},
- {"width", genRequest.width},
- {"height", genRequest.height},
- {"batch_count", genRequest.batchCount},
- {"steps", genRequest.steps},
- {"cfg_scale", genRequest.cfgScale},
- {"seed", genRequest.seed},
- {"strength", genRequest.strength},
- {"sampling_method", samplingMethodToString(genRequest.samplingMethod)},
- {"scheduler", schedulerToString(genRequest.scheduler)}
- };
- // Add VAE/TAESD if specified
- if (!genRequest.vaePath.empty()) {
- params["vae_model"] = requestJson.value("vae_model", "");
- }
- if (!genRequest.taesdPath.empty()) {
- params["taesd_model"] = requestJson.value("taesd_model", "");
- }
- json response = {
- {"request_id", requestId},
- {"status", "queued"},
- {"message", "Image-to-image generation request queued successfully"},
- {"queue_position", m_generationQueue->getQueueSize()},
- {"estimated_time_seconds", estimateGenerationTime(genRequest) / 1000},
- {"estimated_memory_mb", estimateMemoryUsage(genRequest) / (1024 * 1024)},
- {"type", "img2img"},
- {"parameters", params}
- };
- sendJsonResponse(res, response, 202);
- } catch (const json::parse_error& e) {
- sendErrorResponse(res, std::string("Invalid JSON: ") + e.what(), 400, "JSON_PARSE_ERROR", requestId);
- } catch (const std::exception& e) {
- sendErrorResponse(res, std::string("Image-to-image request failed: ") + e.what(), 500, "INTERNAL_ERROR", requestId);
- }
- }
- void Server::handleControlNet(const httplib::Request& req, httplib::Response& res) {
- std::string requestId = generateRequestId();
- try {
- if (!m_generationQueue) {
- sendErrorResponse(res, "Generation queue not available", 500, "QUEUE_UNAVAILABLE", requestId);
- return;
- }
- json requestJson = json::parse(req.body);
- // Validate required fields for ControlNet
- if (!requestJson.contains("prompt") || !requestJson["prompt"].is_string()) {
- sendErrorResponse(res, "Missing or invalid 'prompt' field", 400, "INVALID_PARAMETERS", requestId);
- return;
- }
- if (!requestJson.contains("control_image") || !requestJson["control_image"].is_string()) {
- sendErrorResponse(res, "Missing or invalid 'control_image' field", 400, "INVALID_PARAMETERS", requestId);
- return;
- }
- // Validate all parameters
- auto [isValid, errorMessage] = validateGenerationParameters(requestJson);
- if (!isValid) {
- sendErrorResponse(res, errorMessage, 400, "INVALID_PARAMETERS", requestId);
- return;
- }
- // Check if any model is loaded
- if (!m_modelManager) {
- sendErrorResponse(res, "Model manager not available", 500, "MODEL_MANAGER_UNAVAILABLE", requestId);
- return;
- }
- // Get currently loaded checkpoint model
- auto allModels = m_modelManager->getAllModels();
- std::string loadedModelName;
- for (const auto& [modelName, modelInfo] : allModels) {
- if (modelInfo.type == ModelType::CHECKPOINT && modelInfo.isLoaded) {
- loadedModelName = modelName;
- break;
- }
- }
- if (loadedModelName.empty()) {
- sendErrorResponse(res, "No checkpoint model loaded. Please load a checkpoint model first using POST /api/models/{hash}/load", 400, "NO_CHECKPOINT_LOADED", requestId);
- return;
- }
- // Create generation request specifically for ControlNet
- GenerationRequest genRequest;
- genRequest.id = requestId;
- genRequest.modelName = loadedModelName; // Use the currently loaded model
- genRequest.prompt = requestJson["prompt"];
- genRequest.negativePrompt = requestJson.value("negative_prompt", "");
- genRequest.width = requestJson.value("width", 512);
- genRequest.height = requestJson.value("height", 512);
- genRequest.batchCount = requestJson.value("batch_count", 1);
- genRequest.steps = requestJson.value("steps", 20);
- genRequest.cfgScale = requestJson.value("cfg_scale", 7.5f);
- genRequest.seed = requestJson.value("seed", "random");
- genRequest.controlStrength = requestJson.value("control_strength", 0.9f);
- genRequest.controlNetPath = requestJson.value("control_net_model", "");
- // Parse optional parameters
- if (requestJson.contains("sampling_method")) {
- genRequest.samplingMethod = parseSamplingMethod(requestJson["sampling_method"]);
- }
- if (requestJson.contains("scheduler")) {
- genRequest.scheduler = parseScheduler(requestJson["scheduler"]);
- }
- // Optional VAE model
- if (requestJson.contains("vae_model") && requestJson["vae_model"].is_string()) {
- std::string vaeModelId = requestJson["vae_model"];
- if (!vaeModelId.empty()) {
- auto vaeInfo = m_modelManager->getModelInfo(vaeModelId);
- if (!vaeInfo.name.empty() && vaeInfo.type == ModelType::VAE) {
- genRequest.vaePath = vaeInfo.path;
- } else {
- sendErrorResponse(res, "VAE model not found or invalid: " + vaeModelId, 400, "INVALID_VAE_MODEL", requestId);
- return;
- }
- }
- }
- // Optional TAESD model
- if (requestJson.contains("taesd_model") && requestJson["taesd_model"].is_string()) {
- std::string taesdModelId = requestJson["taesd_model"];
- if (!taesdModelId.empty()) {
- auto taesdInfo = m_modelManager->getModelInfo(taesdModelId);
- if (!taesdInfo.name.empty() && taesdInfo.type == ModelType::TAESD) {
- genRequest.taesdPath = taesdInfo.path;
- } else {
- sendErrorResponse(res, "TAESD model not found or invalid: " + taesdModelId, 400, "INVALID_TAESD_MODEL", requestId);
- return;
- }
- }
- }
- // Store control image path (would be handled in actual implementation)
- genRequest.outputPath = requestJson.value("control_image", "");
- // Enqueue request
- auto future = m_generationQueue->enqueueRequest(genRequest);
- json params = {
- {"prompt", genRequest.prompt},
- {"negative_prompt", genRequest.negativePrompt},
- {"control_image", requestJson["control_image"]},
- {"control_net_model", genRequest.controlNetPath},
- {"model", genRequest.modelName},
- {"width", genRequest.width},
- {"height", genRequest.height},
- {"batch_count", genRequest.batchCount},
- {"steps", genRequest.steps},
- {"cfg_scale", genRequest.cfgScale},
- {"seed", genRequest.seed},
- {"control_strength", genRequest.controlStrength},
- {"sampling_method", samplingMethodToString(genRequest.samplingMethod)},
- {"scheduler", schedulerToString(genRequest.scheduler)}
- };
- // Add VAE/TAESD if specified
- if (!genRequest.vaePath.empty()) {
- params["vae_model"] = requestJson.value("vae_model", "");
- }
- if (!genRequest.taesdPath.empty()) {
- params["taesd_model"] = requestJson.value("taesd_model", "");
- }
- json response = {
- {"request_id", requestId},
- {"status", "queued"},
- {"message", "ControlNet generation request queued successfully"},
- {"queue_position", m_generationQueue->getQueueSize()},
- {"estimated_time_seconds", estimateGenerationTime(genRequest) / 1000},
- {"estimated_memory_mb", estimateMemoryUsage(genRequest) / (1024 * 1024)},
- {"type", "controlnet"},
- {"parameters", params}
- };
- sendJsonResponse(res, response, 202);
- } catch (const json::parse_error& e) {
- sendErrorResponse(res, std::string("Invalid JSON: ") + e.what(), 400, "JSON_PARSE_ERROR", requestId);
- } catch (const std::exception& e) {
- sendErrorResponse(res, std::string("ControlNet request failed: ") + e.what(), 500, "INTERNAL_ERROR", requestId);
- }
- }
- void Server::handleUpscale(const httplib::Request& req, httplib::Response& res) {
- std::string requestId = generateRequestId();
- try {
- if (!m_generationQueue) {
- sendErrorResponse(res, "Generation queue not available", 500, "QUEUE_UNAVAILABLE", requestId);
- return;
- }
- json requestJson = json::parse(req.body);
- // Validate required fields for upscaler
- if (!requestJson.contains("image") || !requestJson["image"].is_string()) {
- sendErrorResponse(res, "Missing or invalid 'image' field", 400, "INVALID_PARAMETERS", requestId);
- return;
- }
- if (!requestJson.contains("esrgan_model") || !requestJson["esrgan_model"].is_string()) {
- sendErrorResponse(res, "Missing or invalid 'esrgan_model' field (model hash or name)", 400, "INVALID_PARAMETERS", requestId);
- return;
- }
- // Check if model manager is available
- if (!m_modelManager) {
- sendErrorResponse(res, "Model manager not available", 500, "MODEL_MANAGER_UNAVAILABLE", requestId);
- return;
- }
- // Get the ESRGAN/upscaler model
- std::string esrganModelId = requestJson["esrgan_model"];
- auto modelInfo = m_modelManager->getModelInfo(esrganModelId);
- if (modelInfo.name.empty()) {
- sendErrorResponse(res, "ESRGAN model not found: " + esrganModelId, 404, "MODEL_NOT_FOUND", requestId);
- return;
- }
- if (modelInfo.type != ModelType::ESRGAN && modelInfo.type != ModelType::UPSCALER) {
- sendErrorResponse(res, "Model is not an ESRGAN/upscaler model", 400, "INVALID_MODEL_TYPE", requestId);
- return;
- }
- // Load the input image
- std::string imageInput = requestJson["image"];
- auto [imageData, imgWidth, imgHeight, imgChannels, success, loadError] = loadImageFromInput(imageInput);
- if (!success) {
- sendErrorResponse(res, "Failed to load image: " + loadError, 400, "IMAGE_LOAD_ERROR", requestId);
- return;
- }
- // Create upscaler request
- GenerationRequest genRequest;
- genRequest.id = requestId;
- genRequest.requestType = GenerationRequest::RequestType::UPSCALER;
- genRequest.esrganPath = modelInfo.path;
- genRequest.upscaleFactor = requestJson.value("upscale_factor", 4);
- genRequest.nThreads = requestJson.value("threads", -1);
- genRequest.offloadParamsToCpu = requestJson.value("offload_to_cpu", false);
- genRequest.diffusionConvDirect = requestJson.value("direct", false);
- // Set input image data
- genRequest.initImageData = imageData;
- genRequest.initImageWidth = imgWidth;
- genRequest.initImageHeight = imgHeight;
- genRequest.initImageChannels = imgChannels;
- // Enqueue request
- auto future = m_generationQueue->enqueueRequest(genRequest);
- json response = {
- {"request_id", requestId},
- {"status", "queued"},
- {"message", "Upscale request queued successfully"},
- {"queue_position", m_generationQueue->getQueueSize()},
- {"type", "upscale"},
- {"parameters", {
- {"esrgan_model", esrganModelId},
- {"upscale_factor", genRequest.upscaleFactor},
- {"input_width", imgWidth},
- {"input_height", imgHeight},
- {"output_width", imgWidth * genRequest.upscaleFactor},
- {"output_height", imgHeight * genRequest.upscaleFactor}
- }}
- };
- sendJsonResponse(res, response, 202);
- } catch (const json::parse_error& e) {
- sendErrorResponse(res, std::string("Invalid JSON: ") + e.what(), 400, "JSON_PARSE_ERROR", requestId);
- } catch (const std::exception& e) {
- sendErrorResponse(res, std::string("Upscale request failed: ") + e.what(), 500, "INTERNAL_ERROR", requestId);
- }
- }
- // Utility endpoints
- void Server::handleSamplers(const httplib::Request& req, httplib::Response& res) {
- try {
- json samplers = {
- {"samplers", {
- {
- {"name", "euler"},
- {"description", "Euler sampler - fast and simple"},
- {"recommended_steps", 20}
- },
- {
- {"name", "euler_a"},
- {"description", "Euler ancestral sampler - adds randomness"},
- {"recommended_steps", 20}
- },
- {
- {"name", "heun"},
- {"description", "Heun sampler - more accurate but slower"},
- {"recommended_steps", 20}
- },
- {
- {"name", "dpm2"},
- {"description", "DPM2 sampler - second-order DPM"},
- {"recommended_steps", 20}
- },
- {
- {"name", "dpm++2s_a"},
- {"description", "DPM++ 2s ancestral sampler"},
- {"recommended_steps", 20}
- },
- {
- {"name", "dpm++2m"},
- {"description", "DPM++ 2m sampler - multistep"},
- {"recommended_steps", 20}
- },
- {
- {"name", "dpm++2mv2"},
- {"description", "DPM++ 2m v2 sampler - improved multistep"},
- {"recommended_steps", 20}
- },
- {
- {"name", "ipndm"},
- {"description", "IPNDM sampler - improved noise prediction"},
- {"recommended_steps", 20}
- },
- {
- {"name", "ipndm_v"},
- {"description", "IPNDM v sampler - variant of IPNDM"},
- {"recommended_steps", 20}
- },
- {
- {"name", "lcm"},
- {"description", "LCM sampler - Latent Consistency Model, very fast"},
- {"recommended_steps", 4}
- },
- {
- {"name", "ddim_trailing"},
- {"description", "DDIM trailing sampler - deterministic"},
- {"recommended_steps", 20}
- },
- {
- {"name", "tcd"},
- {"description", "TCD sampler - Trajectory Consistency Distillation"},
- {"recommended_steps", 8}
- },
- {
- {"name", "default"},
- {"description", "Use model's default sampler"},
- {"recommended_steps", 20}
- }
- }}
- };
- sendJsonResponse(res, samplers);
- } catch (const std::exception& e) {
- sendErrorResponse(res, std::string("Failed to get samplers: ") + e.what(), 500);
- }
- }
- void Server::handleSchedulers(const httplib::Request& req, httplib::Response& res) {
- try {
- json schedulers = {
- {"schedulers", {
- {
- {"name", "discrete"},
- {"description", "Discrete scheduler - standard noise schedule"}
- },
- {
- {"name", "karras"},
- {"description", "Karras scheduler - improved noise schedule"}
- },
- {
- {"name", "exponential"},
- {"description", "Exponential scheduler - exponential noise decay"}
- },
- {
- {"name", "ays"},
- {"description", "AYS scheduler - Adaptive Your Scheduler"}
- },
- {
- {"name", "gits"},
- {"description", "GITS scheduler - Generalized Iterative Time Steps"}
- },
- {
- {"name", "smoothstep"},
- {"description", "Smoothstep scheduler - smooth transition function"}
- },
- {
- {"name", "sgm_uniform"},
- {"description", "SGM uniform scheduler - uniform noise schedule"}
- },
- {
- {"name", "simple"},
- {"description", "Simple scheduler - basic linear schedule"}
- },
- {
- {"name", "default"},
- {"description", "Use model's default scheduler"}
- }
- }}
- };
- sendJsonResponse(res, schedulers);
- } catch (const std::exception& e) {
- sendErrorResponse(res, std::string("Failed to get schedulers: ") + e.what(), 500);
- }
- }
- void Server::handleParameters(const httplib::Request& req, httplib::Response& res) {
- try {
- json parameters = {
- {"parameters", {
- {
- {"name", "prompt"},
- {"type", "string"},
- {"required", true},
- {"description", "Text prompt for image generation"},
- {"min_length", 1},
- {"max_length", 10000},
- {"example", "a beautiful landscape with mountains"}
- },
- {
- {"name", "negative_prompt"},
- {"type", "string"},
- {"required", false},
- {"description", "Negative prompt to guide generation away from"},
- {"min_length", 0},
- {"max_length", 10000},
- {"example", "blurry, low quality, distorted"}
- },
- {
- {"name", "width"},
- {"type", "integer"},
- {"required", false},
- {"description", "Image width in pixels"},
- {"min", 64},
- {"max", 2048},
- {"multiple_of", 64},
- {"default", 512}
- },
- {
- {"name", "height"},
- {"type", "integer"},
- {"required", false},
- {"description", "Image height in pixels"},
- {"min", 64},
- {"max", 2048},
- {"multiple_of", 64},
- {"default", 512}
- },
- {
- {"name", "steps"},
- {"type", "integer"},
- {"required", false},
- {"description", "Number of diffusion steps"},
- {"min", 1},
- {"max", 150},
- {"default", 20}
- },
- {
- {"name", "cfg_scale"},
- {"type", "number"},
- {"required", false},
- {"description", "Classifier-Free Guidance scale"},
- {"min", 1.0},
- {"max", 30.0},
- {"default", 7.5}
- },
- {
- {"name", "seed"},
- {"type", "string|integer"},
- {"required", false},
- {"description", "Seed for generation (use 'random' for random seed)"},
- {"example", "42"}
- },
- {
- {"name", "sampling_method"},
- {"type", "string"},
- {"required", false},
- {"description", "Sampling method to use"},
- {"enum", {"euler", "euler_a", "heun", "dpm2", "dpm++2s_a", "dpm++2m", "dpm++2mv2", "ipndm", "ipndm_v", "lcm", "ddim_trailing", "tcd", "default"}},
- {"default", "default"}
- },
- {
- {"name", "scheduler"},
- {"type", "string"},
- {"required", false},
- {"description", "Scheduler to use"},
- {"enum", {"discrete", "karras", "exponential", "ays", "gits", "smoothstep", "sgm_uniform", "simple", "default"}},
- {"default", "default"}
- },
- {
- {"name", "batch_count"},
- {"type", "integer"},
- {"required", false},
- {"description", "Number of images to generate"},
- {"min", 1},
- {"max", 100},
- {"default", 1}
- },
- {
- {"name", "strength"},
- {"type", "number"},
- {"required", false},
- {"description", "Strength for img2img (0.0-1.0)"},
- {"min", 0.0},
- {"max", 1.0},
- {"default", 0.75}
- },
- {
- {"name", "control_strength"},
- {"type", "number"},
- {"required", false},
- {"description", "ControlNet strength (0.0-1.0)"},
- {"min", 0.0},
- {"max", 1.0},
- {"default", 0.9}
- }
- }},
- {"openapi", {
- {"version", "3.0.0"},
- {"info", {
- {"title", "Stable Diffusion REST API"},
- {"version", "1.0.0"},
- {"description", "Comprehensive REST API for stable-diffusion.cpp functionality"}
- }},
- {"components", {
- {"schemas", {
- {"GenerationRequest", {
- {"type", "object"},
- {"required", {"prompt"}},
- {"properties", {
- {"prompt", {{"type", "string"}, {"description", "Text prompt for generation"}}},
- {"negative_prompt", {{"type", "string"}, {"description", "Negative prompt"}}},
- {"width", {{"type", "integer"}, {"minimum", 64}, {"maximum", 2048}, {"default", 512}}},
- {"height", {{"type", "integer"}, {"minimum", 64}, {"maximum", 2048}, {"default", 512}}},
- {"steps", {{"type", "integer"}, {"minimum", 1}, {"maximum", 150}, {"default", 20}}},
- {"cfg_scale", {{"type", "number"}, {"minimum", 1.0}, {"maximum", 30.0}, {"default", 7.5}}}
- }}
- }}
- }}
- }}
- }}
- };
- sendJsonResponse(res, parameters);
- } catch (const std::exception& e) {
- sendErrorResponse(res, std::string("Failed to get parameters: ") + e.what(), 500);
- }
- }
- void Server::handleValidate(const httplib::Request& req, httplib::Response& res) {
- std::string requestId = generateRequestId();
- try {
- json requestJson = json::parse(req.body);
- // Validate parameters
- auto [isValid, errorMessage] = validateGenerationParameters(requestJson);
- json response = {
- {"request_id", requestId},
- {"valid", isValid},
- {"message", isValid ? "Parameters are valid" : errorMessage},
- {"errors", isValid ? json::array() : json::array({errorMessage})}
- };
- sendJsonResponse(res, response, isValid ? 200 : 400);
- } catch (const json::parse_error& e) {
- sendErrorResponse(res, std::string("Invalid JSON: ") + e.what(), 400, "JSON_PARSE_ERROR", requestId);
- } catch (const std::exception& e) {
- sendErrorResponse(res, std::string("Validation failed: ") + e.what(), 500, "INTERNAL_ERROR", requestId);
- }
- }
- void Server::handleEstimate(const httplib::Request& req, httplib::Response& res) {
- std::string requestId = generateRequestId();
- try {
- json requestJson = json::parse(req.body);
- // Validate parameters first
- auto [isValid, errorMessage] = validateGenerationParameters(requestJson);
- if (!isValid) {
- sendErrorResponse(res, errorMessage, 400, "INVALID_PARAMETERS", requestId);
- return;
- }
- // Create a temporary request to estimate
- GenerationRequest genRequest;
- genRequest.prompt = requestJson["prompt"];
- genRequest.width = requestJson.value("width", 512);
- genRequest.height = requestJson.value("height", 512);
- genRequest.batchCount = requestJson.value("batch_count", 1);
- genRequest.steps = requestJson.value("steps", 20);
- genRequest.diffusionFlashAttn = requestJson.value("diffusion_flash_attn", false);
- genRequest.controlNetPath = requestJson.value("control_net_path", "");
- if (requestJson.contains("sampling_method")) {
- genRequest.samplingMethod = parseSamplingMethod(requestJson["sampling_method"]);
- }
- // Calculate estimates
- uint64_t estimatedTime = estimateGenerationTime(genRequest);
- size_t estimatedMemory = estimateMemoryUsage(genRequest);
- json response = {
- {"request_id", requestId},
- {"estimated_time_seconds", estimatedTime / 1000},
- {"estimated_memory_mb", estimatedMemory / (1024 * 1024)},
- {"parameters", {
- {"resolution", std::to_string(genRequest.width) + "x" + std::to_string(genRequest.height)},
- {"steps", genRequest.steps},
- {"batch_count", genRequest.batchCount},
- {"sampling_method", samplingMethodToString(genRequest.samplingMethod)}
- }}
- };
- sendJsonResponse(res, response);
- } catch (const json::parse_error& e) {
- sendErrorResponse(res, std::string("Invalid JSON: ") + e.what(), 400, "JSON_PARSE_ERROR", requestId);
- } catch (const std::exception& e) {
- sendErrorResponse(res, std::string("Estimation failed: ") + e.what(), 500, "INTERNAL_ERROR", requestId);
- }
- }
- void Server::handleConfig(const httplib::Request& req, httplib::Response& res) {
- std::string requestId = generateRequestId();
- try {
- // Get current configuration
- json config = {
- {"request_id", requestId},
- {"config", {
- {"server", {
- {"host", m_host},
- {"port", m_port},
- {"max_concurrent_generations", 1}
- }},
- {"generation", {
- {"default_width", 512},
- {"default_height", 512},
- {"default_steps", 20},
- {"default_cfg_scale", 7.5},
- {"max_batch_count", 100},
- {"max_steps", 150},
- {"max_resolution", 2048}
- }},
- {"rate_limiting", {
- {"requests_per_minute", 60},
- {"enabled", true}
- }}
- }}
- };
- sendJsonResponse(res, config);
- } catch (const std::exception& e) {
- sendErrorResponse(res, std::string("Config operation failed: ") + e.what(), 500, "INTERNAL_ERROR", requestId);
- }
- }
- void Server::handleSystem(const httplib::Request& req, httplib::Response& res) {
- try {
- json system = {
- {"system", {
- {"version", "1.0.0"},
- {"build", "stable-diffusion.cpp-rest"},
- {"uptime", std::chrono::duration_cast<std::chrono::seconds>(
- std::chrono::steady_clock::now().time_since_epoch()).count()},
- {"capabilities", {
- {"text2img", true},
- {"img2img", true},
- {"controlnet", true},
- {"batch_generation", true},
- {"parameter_validation", true},
- {"estimation", true}
- }},
- {"supported_formats", {
- {"input", {"png", "jpg", "jpeg", "webp"}},
- {"output", {"png", "jpg", "jpeg", "webp"}}
- }},
- {"limits", {
- {"max_resolution", 2048},
- {"max_steps", 150},
- {"max_batch_count", 100},
- {"max_prompt_length", 10000}
- }}
- }},
- {"hardware", {
- {"cpu_threads", std::thread::hardware_concurrency()}
- }}
- };
- sendJsonResponse(res, system);
- } catch (const std::exception& e) {
- sendErrorResponse(res, std::string("System info failed: ") + e.what(), 500);
- }
- }
- void Server::handleSystemRestart(const httplib::Request& req, httplib::Response& res) {
- try {
- json response = {
- {"message", "Server restart initiated. The server will shut down gracefully and exit. Please use a process manager to automatically restart it."},
- {"status", "restarting"}
- };
- sendJsonResponse(res, response);
- // Schedule server stop after response is sent
- // Using a separate thread to allow the response to be sent first
- std::thread([this]() {
- std::this_thread::sleep_for(std::chrono::seconds(1));
- this->stop();
- // Exit with code 42 to signal restart intent to process manager
- std::exit(42);
- }).detach();
- } catch (const std::exception& e) {
- sendErrorResponse(res, std::string("Restart failed: ") + e.what(), 500);
- }
- }
- // Helper methods for model management
- json Server::getModelCapabilities(ModelType type) {
- json capabilities = json::object();
- switch (type) {
- case ModelType::CHECKPOINT:
- capabilities = {
- {"text2img", true},
- {"img2img", true},
- {"inpainting", true},
- {"outpainting", true},
- {"controlnet", true},
- {"lora", true},
- {"vae", true},
- {"sampling_methods", {"euler", "euler_a", "heun", "dpm2", "dpm++2s_a", "dpm++2m", "dpm++2mv2", "ipndm", "ipndm_v", "lcm", "ddim_trailing", "tcd"}},
- {"schedulers", {"discrete", "karras", "exponential", "ays", "gits", "smoothstep", "sgm_uniform", "simple"}},
- {"recommended_resolution", "512x512"},
- {"max_resolution", "2048x2048"},
- {"supports_batch", true}
- };
- break;
- case ModelType::LORA:
- capabilities = {
- {"text2img", true},
- {"img2img", true},
- {"inpainting", true},
- {"controlnet", false},
- {"lora", true},
- {"vae", false},
- {"requires_checkpoint", true},
- {"strength_range", {0.0, 2.0}},
- {"recommended_strength", 1.0}
- };
- break;
- case ModelType::CONTROLNET:
- capabilities = {
- {"text2img", false},
- {"img2img", true},
- {"inpainting", true},
- {"controlnet", true},
- {"requires_checkpoint", true},
- {"control_modes", {"canny", "depth", "pose", "scribble", "hed", "mlsd", "normal", "seg"}},
- {"strength_range", {0.0, 1.0}},
- {"recommended_strength", 0.9}
- };
- break;
- case ModelType::VAE:
- capabilities = {
- {"text2img", false},
- {"img2img", false},
- {"inpainting", false},
- {"vae", true},
- {"requires_checkpoint", true},
- {"encoding", true},
- {"decoding", true},
- {"precision", {"fp16", "fp32"}}
- };
- break;
- case ModelType::EMBEDDING:
- capabilities = {
- {"text2img", true},
- {"img2img", true},
- {"inpainting", true},
- {"embedding", true},
- {"requires_checkpoint", true},
- {"token_count", 1},
- {"compatible_with", {"checkpoint", "lora"}}
- };
- break;
- case ModelType::TAESD:
- capabilities = {
- {"text2img", false},
- {"img2img", false},
- {"inpainting", false},
- {"vae", true},
- {"requires_checkpoint", true},
- {"fast_decoding", true},
- {"real_time", true},
- {"precision", {"fp16", "fp32"}}
- };
- break;
- case ModelType::ESRGAN:
- capabilities = {
- {"text2img", false},
- {"img2img", false},
- {"inpainting", false},
- {"upscaling", true},
- {"scale_factors", {2, 4}},
- {"models", {"ESRGAN", "RealESRGAN", "SwinIR"}},
- {"supports_alpha", false}
- };
- break;
- default:
- capabilities = {
- {"text2img", false},
- {"img2img", false},
- {"inpainting", false},
- {"capabilities", {}}
- };
- break;
- }
- return capabilities;
- }
- json Server::getModelTypeStatistics() {
- if (!m_modelManager) return json::object();
- json stats = json::object();
- auto allModels = m_modelManager->getAllModels();
- // Initialize counters for each type
- std::map<ModelType, int> typeCounts;
- std::map<ModelType, int> loadedCounts;
- std::map<ModelType, size_t> sizeByType;
- for (const auto& pair : allModels) {
- ModelType type = pair.second.type;
- typeCounts[type]++;
- if (pair.second.isLoaded) {
- loadedCounts[type]++;
- }
- sizeByType[type] += pair.second.fileSize;
- }
- // Build statistics JSON
- for (const auto& count : typeCounts) {
- std::string typeName = ModelManager::modelTypeToString(count.first);
- stats[typeName] = {
- {"total_count", count.second},
- {"loaded_count", loadedCounts[count.first]},
- {"total_size_bytes", sizeByType[count.first]},
- {"total_size_mb", sizeByType[count.first] / (1024.0 * 1024.0)},
- {"average_size_mb", count.second > 0 ? (sizeByType[count.first] / (1024.0 * 1024.0)) / count.second : 0.0}
- };
- }
- return stats;
- }
- // Additional helper methods for model management
- json Server::getModelCompatibility(const ModelManager::ModelInfo& modelInfo) {
- json compatibility = {
- {"is_compatible", true},
- {"compatibility_score", 100},
- {"issues", json::array()},
- {"warnings", json::array()},
- {"requirements", {
- {"min_memory_mb", 1024},
- {"recommended_memory_mb", 2048},
- {"supported_formats", {"safetensors", "ckpt", "gguf"}},
- {"required_dependencies", {}}
- }}
- };
- // Check for specific compatibility issues based on model type
- if (modelInfo.type == ModelType::LORA) {
- compatibility["requirements"]["required_dependencies"] = {"checkpoint"};
- } else if (modelInfo.type == ModelType::CONTROLNET) {
- compatibility["requirements"]["required_dependencies"] = {"checkpoint"};
- } else if (modelInfo.type == ModelType::VAE) {
- compatibility["requirements"]["required_dependencies"] = {"checkpoint"};
- }
- return compatibility;
- }
- json Server::getModelRequirements(ModelType type) {
- json requirements = {
- {"min_memory_mb", 1024},
- {"recommended_memory_mb", 2048},
- {"min_disk_space_mb", 1024},
- {"supported_formats", {"safetensors", "ckpt", "gguf"}},
- {"required_dependencies", json::array()},
- {"optional_dependencies", json::array()},
- {"system_requirements", {
- {"cpu_cores", 4},
- {"cpu_architecture", "x86_64"},
- {"os", "Linux/Windows/macOS"},
- {"gpu_memory_mb", 2048},
- {"gpu_compute_capability", "3.5+"}
- }}
- };
- switch (type) {
- case ModelType::CHECKPOINT:
- requirements["min_memory_mb"] = 2048;
- requirements["recommended_memory_mb"] = 4096;
- requirements["min_disk_space_mb"] = 2048;
- requirements["supported_formats"] = {"safetensors", "ckpt", "gguf"};
- break;
- case ModelType::LORA:
- requirements["min_memory_mb"] = 512;
- requirements["recommended_memory_mb"] = 1024;
- requirements["min_disk_space_mb"] = 100;
- requirements["supported_formats"] = {"safetensors", "ckpt"};
- requirements["required_dependencies"] = {"checkpoint"};
- break;
- case ModelType::CONTROLNET:
- requirements["min_memory_mb"] = 1024;
- requirements["recommended_memory_mb"] = 2048;
- requirements["min_disk_space_mb"] = 500;
- requirements["supported_formats"] = {"safetensors", "pth"};
- requirements["required_dependencies"] = {"checkpoint"};
- break;
- case ModelType::VAE:
- requirements["min_memory_mb"] = 512;
- requirements["recommended_memory_mb"] = 1024;
- requirements["min_disk_space_mb"] = 200;
- requirements["supported_formats"] = {"safetensors", "pt", "ckpt", "gguf"};
- requirements["required_dependencies"] = {"checkpoint"};
- break;
- case ModelType::EMBEDDING:
- requirements["min_memory_mb"] = 64;
- requirements["recommended_memory_mb"] = 256;
- requirements["min_disk_space_mb"] = 10;
- requirements["supported_formats"] = {"safetensors", "pt"};
- requirements["required_dependencies"] = {"checkpoint"};
- break;
- case ModelType::TAESD:
- requirements["min_memory_mb"] = 256;
- requirements["recommended_memory_mb"] = 512;
- requirements["min_disk_space_mb"] = 100;
- requirements["supported_formats"] = {"safetensors", "pth", "gguf"};
- requirements["required_dependencies"] = {"checkpoint"};
- break;
- case ModelType::ESRGAN:
- requirements["min_memory_mb"] = 1024;
- requirements["recommended_memory_mb"] = 2048;
- requirements["min_disk_space_mb"] = 500;
- requirements["supported_formats"] = {"pth", "pt"};
- requirements["optional_dependencies"] = {"checkpoint"};
- break;
- default:
- break;
- }
- return requirements;
- }
- json Server::getRecommendedUsage(ModelType type) {
- json usage = {
- {"text2img", false},
- {"img2img", false},
- {"inpainting", false},
- {"controlnet", false},
- {"lora", false},
- {"vae", false},
- {"recommended_resolution", "512x512"},
- {"recommended_steps", 20},
- {"recommended_cfg_scale", 7.5},
- {"recommended_batch_size", 1}
- };
- switch (type) {
- case ModelType::CHECKPOINT:
- usage = {
- {"text2img", true},
- {"img2img", true},
- {"inpainting", true},
- {"controlnet", true},
- {"lora", true},
- {"vae", true},
- {"recommended_resolution", "512x512"},
- {"recommended_steps", 20},
- {"recommended_cfg_scale", 7.5},
- {"recommended_batch_size", 1}
- };
- break;
- case ModelType::LORA:
- usage = {
- {"text2img", true},
- {"img2img", true},
- {"inpainting", true},
- {"controlnet", false},
- {"lora", true},
- {"vae", false},
- {"recommended_strength", 1.0},
- {"recommended_usage", "Style transfer, character customization"}
- };
- break;
- case ModelType::CONTROLNET:
- usage = {
- {"text2img", false},
- {"img2img", true},
- {"inpainting", true},
- {"controlnet", true},
- {"lora", false},
- {"vae", false},
- {"recommended_strength", 0.9},
- {"recommended_usage", "Precise control over output"}
- };
- break;
- case ModelType::VAE:
- usage = {
- {"text2img", false},
- {"img2img", false},
- {"inpainting", false},
- {"controlnet", false},
- {"lora", false},
- {"vae", true},
- {"recommended_usage", "Improved encoding/decoding quality"}
- };
- break;
- case ModelType::EMBEDDING:
- usage = {
- {"text2img", true},
- {"img2img", true},
- {"inpainting", true},
- {"controlnet", false},
- {"lora", false},
- {"vae", false},
- {"embedding", true},
- {"recommended_usage", "Concept control, style words"}
- };
- break;
- case ModelType::TAESD:
- usage = {
- {"text2img", false},
- {"img2img", false},
- {"inpainting", false},
- {"controlnet", false},
- {"lora", false},
- {"vae", true},
- {"recommended_usage", "Real-time decoding"}
- };
- break;
- case ModelType::ESRGAN:
- usage = {
- {"text2img", false},
- {"img2img", false},
- {"inpainting", false},
- {"controlnet", false},
- {"lora", false},
- {"vae", false},
- {"upscaling", true},
- {"recommended_usage", "Image upscaling and quality enhancement"}
- };
- break;
- default:
- break;
- }
- return usage;
- }
- std::string Server::getModelTypeFromDirectoryName(const std::string& dirName) {
- if (dirName == "stable-diffusion" || dirName == "checkpoints") {
- return "checkpoint";
- } else if (dirName == "lora") {
- return "lora";
- } else if (dirName == "controlnet") {
- return "controlnet";
- } else if (dirName == "vae") {
- return "vae";
- } else if (dirName == "taesd") {
- return "taesd";
- } else if (dirName == "esrgan" || dirName == "upscaler") {
- return "esrgan";
- } else if (dirName == "embeddings" || dirName == "textual-inversion") {
- return "embedding";
- } else {
- return "unknown";
- }
- }
- std::string Server::getDirectoryDescription(const std::string& dirName) {
- if (dirName == "stable-diffusion" || dirName == "checkpoints") {
- return "Main stable diffusion model files";
- } else if (dirName == "lora") {
- return "LoRA adapter models for style transfer";
- } else if (dirName == "controlnet") {
- return "ControlNet models for precise control";
- } else if (dirName == "vae") {
- return "VAE models for improved encoding/decoding";
- } else if (dirName == "taesd") {
- return "TAESD models for real-time decoding";
- } else if (dirName == "esrgan" || dirName == "upscaler") {
- return "ESRGAN models for image upscaling";
- } else if (dirName == "embeddings" || dirName == "textual-inversion") {
- return "Text embeddings for concept control";
- } else {
- return "Unknown model directory";
- }
- }
- json Server::getDirectoryContents(const std::string& dirPath) {
- json contents = json::array();
- try {
- if (std::filesystem::exists(dirPath) && std::filesystem::is_directory(dirPath)) {
- for (const auto& entry : std::filesystem::directory_iterator(dirPath)) {
- if (entry.is_regular_file()) {
- json file = {
- {"name", entry.path().filename().string()},
- {"path", entry.path().string()},
- {"size", std::filesystem::file_size(entry.path())},
- {"size_mb", std::filesystem::file_size(entry.path()) / (1024.0 * 1024.0)},
- {"last_modified", std::chrono::duration_cast<std::chrono::seconds>(
- std::filesystem::last_write_time(entry.path()).time_since_epoch()).count()}
- };
- contents.push_back(file);
- }
- }
- }
- } catch (const std::exception& e) {
- // Return empty array if directory access fails
- }
- return contents;
- }
- json Server::getLargestModel(const std::map<std::string, ModelManager::ModelInfo>& allModels) {
- json largest = json::object();
- size_t maxSize = 0;
- std::string largestName;
- for (const auto& pair : allModels) {
- if (pair.second.fileSize > maxSize) {
- maxSize = pair.second.fileSize;
- largestName = pair.second.name;
- }
- }
- if (!largestName.empty()) {
- largest = {
- {"name", largestName},
- {"size", maxSize},
- {"size_mb", maxSize / (1024.0 * 1024.0)},
- {"type", ModelManager::modelTypeToString(allModels.at(largestName).type)}
- };
- }
- return largest;
- }
- json Server::getSmallestModel(const std::map<std::string, ModelManager::ModelInfo>& allModels) {
- json smallest = json::object();
- size_t minSize = SIZE_MAX;
- std::string smallestName;
- for (const auto& pair : allModels) {
- if (pair.second.fileSize < minSize) {
- minSize = pair.second.fileSize;
- smallestName = pair.second.name;
- }
- }
- if (!smallestName.empty()) {
- smallest = {
- {"name", smallestName},
- {"size", minSize},
- {"size_mb", minSize / (1024.0 * 1024.0)},
- {"type", ModelManager::modelTypeToString(allModels.at(smallestName).type)}
- };
- }
- return smallest;
- }
- json Server::validateModelFile(const std::string& modelPath, const std::string& modelType) {
- json validation = {
- {"is_valid", false},
- {"errors", json::array()},
- {"warnings", json::array()},
- {"file_info", json::object()},
- {"compatibility", json::object()},
- {"recommendations", json::array()}
- };
- try {
- if (!std::filesystem::exists(modelPath)) {
- validation["errors"].push_back("File does not exist");
- return validation;
- }
- if (!std::filesystem::is_regular_file(modelPath)) {
- validation["errors"].push_back("Path is not a regular file");
- return validation;
- }
- // Check file extension
- std::string extension = std::filesystem::path(modelPath).extension().string();
- if (extension.empty()) {
- validation["errors"].push_back("Missing file extension");
- return validation;
- }
- // Remove dot and convert to lowercase
- if (extension[0] == '.') {
- extension = extension.substr(1);
- }
- std::transform(extension.begin(), extension.end(), extension.begin(), ::tolower);
- // Validate extension based on model type
- ModelType type = ModelManager::stringToModelType(modelType);
- bool validExtension = false;
- switch (type) {
- case ModelType::CHECKPOINT:
- validExtension = (extension == "safetensors" || extension == "ckpt" || extension == "gguf");
- break;
- case ModelType::LORA:
- validExtension = (extension == "safetensors" || extension == "ckpt");
- break;
- case ModelType::CONTROLNET:
- validExtension = (extension == "safetensors" || extension == "pth");
- break;
- case ModelType::VAE:
- validExtension = (extension == "safetensors" || extension == "pt" || extension == "ckpt" || extension == "gguf");
- break;
- case ModelType::EMBEDDING:
- validExtension = (extension == "safetensors" || extension == "pt");
- break;
- case ModelType::TAESD:
- validExtension = (extension == "safetensors" || extension == "pth" || extension == "gguf");
- break;
- case ModelType::ESRGAN:
- validExtension = (extension == "pth" || extension == "pt");
- break;
- default:
- break;
- }
- if (!validExtension) {
- validation["errors"].push_back("Invalid file extension for model type: " + extension);
- }
- // Check file size
- size_t fileSize = std::filesystem::file_size(modelPath);
- if (fileSize == 0) {
- validation["errors"].push_back("File is empty");
- } else if (fileSize > 8ULL * 1024 * 1024 * 1024) { // 8GB
- validation["warnings"].push_back("Very large file may cause performance issues");
- }
- // Build file info
- validation["file_info"] = {
- {"path", modelPath},
- {"size", fileSize},
- {"size_mb", fileSize / (1024.0 * 1024.0)},
- {"extension", extension},
- {"last_modified", std::chrono::duration_cast<std::chrono::seconds>(
- std::filesystem::last_write_time(modelPath).time_since_epoch()).count()}
- };
- // Check compatibility
- validation["compatibility"] = {
- {"extension_valid", validExtension},
- {"size_appropriate", fileSize <= 4ULL * 1024 * 1024 * 1024}, // 4GB
- {"recommended_format", "safetensors"}
- };
- // Add recommendations
- if (!validExtension) {
- validation["recommendations"].push_back("Convert to SafeTensors format for better security and performance");
- }
- if (fileSize > 2ULL * 1024 * 1024 * 1024) { // 2GB
- validation["recommendations"].push_back("Consider using a smaller model for better performance");
- }
- // If no errors found, mark as valid
- if (validation["errors"].empty()) {
- validation["is_valid"] = true;
- }
- } catch (const std::exception& e) {
- validation["errors"].push_back("Validation failed: " + std::string(e.what()));
- }
- return validation;
- }
- json Server::checkModelCompatibility(const ModelManager::ModelInfo& modelInfo, const std::string& systemInfo) {
- json compatibility = {
- {"is_compatible", true},
- {"compatibility_score", 100},
- {"issues", json::array()},
- {"warnings", json::array()},
- {"requirements", json::object()},
- {"recommendations", json::array()},
- {"system_info", json::object()}
- };
- // Check system compatibility
- if (systemInfo == "auto") {
- compatibility["system_info"] = {
- {"cpu_cores", std::thread::hardware_concurrency()}
- };
- }
- // Check model-specific compatibility issues
- if (modelInfo.type == ModelType::CHECKPOINT) {
- if (modelInfo.fileSize > 4ULL * 1024 * 1024 * 1024) { // 4GB
- compatibility["warnings"].push_back("Large checkpoint model may require significant memory");
- compatibility["compatibility_score"] = 80;
- }
- if (modelInfo.fileSize < 500 * 1024 * 1024) { // 500MB
- compatibility["warnings"].push_back("Small checkpoint model may have limited capabilities");
- compatibility["compatibility_score"] = 85;
- }
- } else if (modelInfo.type == ModelType::LORA) {
- if (modelInfo.fileSize > 500 * 1024 * 1024) { // 500MB
- compatibility["warnings"].push_back("Large LoRA may impact performance");
- compatibility["compatibility_score"] = 75;
- }
- }
- return compatibility;
- }
- json Server::calculateSpecificRequirements(const std::string& modelType, const std::string& resolution, const std::string& batchSize) {
- json specific = {
- {"memory_requirements", json::object()},
- {"performance_impact", json::object()},
- {"quality_expectations", json::object()}
- };
- // Parse resolution
- int width = 512, height = 512;
- try {
- size_t xPos = resolution.find('x');
- if (xPos != std::string::npos) {
- width = std::stoi(resolution.substr(0, xPos));
- height = std::stoi(resolution.substr(xPos + 1));
- }
- } catch (...) {
- // Use defaults if parsing fails
- }
- // Parse batch size
- int batch = 1;
- try {
- batch = std::stoi(batchSize);
- } catch (...) {
- // Use default if parsing fails
- }
- // Calculate memory requirements based on resolution and batch
- size_t pixels = width * height;
- size_t baseMemory = 1024 * 1024 * 1024; // 1GB base
- size_t resolutionMemory = (pixels * 4) / (512 * 512); // Scale based on 512x512
- size_t batchMemory = (batch - 1) * baseMemory * 0.5; // Additional memory for batch
- specific["memory_requirements"] = {
- {"base_memory_mb", baseMemory / (1024 * 1024)},
- {"resolution_memory_mb", resolutionMemory / (1024 * 1024)},
- {"batch_memory_mb", batchMemory / (1024 * 1024)},
- {"total_memory_mb", (baseMemory + resolutionMemory + batchMemory) / (1024 * 1024)}
- };
- // Calculate performance impact
- double performanceFactor = 1.0;
- if (pixels > 512 * 512) {
- performanceFactor = 1.5;
- }
- if (batch > 1) {
- performanceFactor *= 1.2;
- }
- specific["performance_impact"] = {
- {"resolution_factor", pixels > 512 * 512 ? 1.5 : 1.0},
- {"batch_factor", batch > 1 ? 1.2 : 1.0},
- {"overall_factor", performanceFactor}
- };
- return specific;
- }
- // Enhanced model management endpoint implementations
- void Server::handleModelInfo(const httplib::Request& req, httplib::Response& res) {
- std::string requestId = generateRequestId();
- try {
- if (!m_modelManager) {
- sendErrorResponse(res, "Model manager not available", 500, "MODEL_MANAGER_UNAVAILABLE", requestId);
- return;
- }
- // Extract model ID from URL path
- std::string modelId = req.matches[1].str();
- if (modelId.empty()) {
- sendErrorResponse(res, "Missing model ID", 400, "MISSING_MODEL_ID", requestId);
- return;
- }
- // Get model information
- auto modelInfo = m_modelManager->getModelInfo(modelId);
- if (modelInfo.name.empty()) {
- sendErrorResponse(res, "Model not found", 404, "MODEL_NOT_FOUND", requestId);
- return;
- }
- // Build comprehensive model information
- json response = {
- {"model", {
- {"name", modelInfo.name},
- {"path", modelInfo.path},
- {"type", ModelManager::modelTypeToString(modelInfo.type)},
- {"is_loaded", modelInfo.isLoaded},
- {"file_size", modelInfo.fileSize},
- {"file_size_mb", modelInfo.fileSize / (1024.0 * 1024.0)},
- {"description", modelInfo.description},
- {"metadata", modelInfo.metadata},
- {"capabilities", getModelCapabilities(modelInfo.type)},
- {"compatibility", getModelCompatibility(modelInfo)},
- {"requirements", getModelRequirements(modelInfo.type)},
- {"recommended_usage", getRecommendedUsage(modelInfo.type)},
- {"last_modified", std::chrono::duration_cast<std::chrono::seconds>(
- modelInfo.modifiedAt.time_since_epoch()).count()}
- }},
- {"request_id", requestId}
- };
- sendJsonResponse(res, response);
- } catch (const std::exception& e) {
- sendErrorResponse(res, std::string("Failed to get model info: ") + e.what(), 500, "MODEL_INFO_ERROR", requestId);
- }
- }
- void Server::handleLoadModelById(const httplib::Request& req, httplib::Response& res) {
- std::string requestId = generateRequestId();
- try {
- if (!m_modelManager) {
- sendErrorResponse(res, "Model manager not available", 500, "MODEL_MANAGER_UNAVAILABLE", requestId);
- return;
- }
- // Extract model ID from URL path (could be hash or name)
- std::string modelIdentifier = req.matches[1].str();
- if (modelIdentifier.empty()) {
- sendErrorResponse(res, "Missing model identifier", 400, "MISSING_MODEL_ID", requestId);
- return;
- }
- // Try to find by hash first (if it looks like a hash - 10+ hex chars)
- std::string modelId = modelIdentifier;
- if (modelIdentifier.length() >= 10 &&
- std::all_of(modelIdentifier.begin(), modelIdentifier.end(),
- [](char c) { return std::isxdigit(c); })) {
- std::string foundName = m_modelManager->findModelByHash(modelIdentifier);
- if (!foundName.empty()) {
- modelId = foundName;
- std::cout << "Resolved hash " << modelIdentifier << " to model: " << modelId << std::endl;
- }
- }
- // Parse optional parameters from request body
- json requestJson;
- if (!req.body.empty()) {
- try {
- requestJson = json::parse(req.body);
- } catch (const json::parse_error& e) {
- sendErrorResponse(res, std::string("Invalid JSON: ") + e.what(), 400, "JSON_PARSE_ERROR", requestId);
- return;
- }
- }
- // Unload previous model if one is loaded
- std::string previousModel;
- {
- std::lock_guard<std::mutex> lock(m_currentModelMutex);
- previousModel = m_currentlyLoadedModel;
- }
- if (!previousModel.empty() && previousModel != modelId) {
- std::cout << "Unloading previous model: " << previousModel << std::endl;
- m_modelManager->unloadModel(previousModel);
- }
- // Load model
- bool success = m_modelManager->loadModel(modelId);
- if (success) {
- // Update currently loaded model
- {
- std::lock_guard<std::mutex> lock(m_currentModelMutex);
- m_currentlyLoadedModel = modelId;
- }
- auto modelInfo = m_modelManager->getModelInfo(modelId);
- json response = {
- {"status", "success"},
- {"model", {
- {"name", modelInfo.name},
- {"path", modelInfo.path},
- {"type", ModelManager::modelTypeToString(modelInfo.type)},
- {"is_loaded", modelInfo.isLoaded}
- }},
- {"request_id", requestId}
- };
- sendJsonResponse(res, response);
- } else {
- sendErrorResponse(res, "Failed to load model", 400, "MODEL_LOAD_FAILED", requestId);
- }
- } catch (const std::exception& e) {
- sendErrorResponse(res, std::string("Model load failed: ") + e.what(), 500, "MODEL_LOAD_ERROR", requestId);
- }
- }
- void Server::handleUnloadModelById(const httplib::Request& req, httplib::Response& res) {
- std::string requestId = generateRequestId();
- try {
- if (!m_modelManager) {
- sendErrorResponse(res, "Model manager not available", 500, "MODEL_MANAGER_UNAVAILABLE", requestId);
- return;
- }
- // Extract model ID from URL path
- std::string modelId = req.matches[1].str();
- if (modelId.empty()) {
- sendErrorResponse(res, "Missing model ID", 400, "MISSING_MODEL_ID", requestId);
- return;
- }
- // Unload model
- bool success = m_modelManager->unloadModel(modelId);
- if (success) {
- // Clear currently loaded model if it matches
- {
- std::lock_guard<std::mutex> lock(m_currentModelMutex);
- if (m_currentlyLoadedModel == modelId) {
- m_currentlyLoadedModel = "";
- }
- }
- json response = {
- {"status", "success"},
- {"model", {
- {"name", modelId},
- {"is_loaded", false}
- }},
- {"request_id", requestId}
- };
- sendJsonResponse(res, response);
- } else {
- sendErrorResponse(res, "Failed to unload model or model not found", 404, "MODEL_UNLOAD_FAILED", requestId);
- }
- } catch (const std::exception& e) {
- sendErrorResponse(res, std::string("Model unload failed: ") + e.what(), 500, "MODEL_UNLOAD_ERROR", requestId);
- }
- }
- void Server::handleModelTypes(const httplib::Request& req, httplib::Response& res) {
- std::string requestId = generateRequestId();
- try {
- json types = {
- {"model_types", {
- {
- {"type", "checkpoint"},
- {"description", "Main stable diffusion model files for text-to-image, image-to-image, and inpainting"},
- {"extensions", {"safetensors", "ckpt", "gguf"}},
- {"capabilities", {"text2img", "img2img", "inpainting", "controlnet", "lora", "vae"}},
- {"recommended_for", "General purpose image generation"}
- },
- {
- {"type", "lora"},
- {"description", "LoRA adapter models for style transfer and character customization"},
- {"extensions", {"safetensors", "ckpt"}},
- {"capabilities", {"style_transfer", "character_customization"}},
- {"requires", {"checkpoint"}},
- {"recommended_for", "Style modification and character-specific generation"}
- },
- {
- {"type", "controlnet"},
- {"description", "ControlNet models for precise control over output composition"},
- {"extensions", {"safetensors", "pth"}},
- {"capabilities", {"precise_control", "composition_control"}},
- {"requires", {"checkpoint"}},
- {"recommended_for", "Precise control over image generation"}
- },
- {
- {"type", "vae"},
- {"description", "VAE models for improved encoding and decoding quality"},
- {"extensions", {"safetensors", "pt", "ckpt", "gguf"}},
- {"capabilities", {"encoding", "decoding", "quality_improvement"}},
- {"requires", {"checkpoint"}},
- {"recommended_for", "Improved image quality and encoding"}
- },
- {
- {"type", "embedding"},
- {"description", "Text embeddings for concept control and style words"},
- {"extensions", {"safetensors", "pt"}},
- {"capabilities", {"concept_control", "style_words"}},
- {"requires", {"checkpoint"}},
- {"recommended_for", "Concept control and specific styles"}
- },
- {
- {"type", "taesd"},
- {"description", "TAESD models for real-time decoding"},
- {"extensions", {"safetensors", "pth", "gguf"}},
- {"capabilities", {"real_time_decoding", "fast_preview"}},
- {"requires", {"checkpoint"}},
- {"recommended_for", "Real-time applications and fast previews"}
- },
- {
- {"type", "esrgan"},
- {"description", "ESRGAN models for image upscaling and enhancement"},
- {"extensions", {"pth", "pt"}},
- {"capabilities", {"upscaling", "enhancement", "quality_improvement"}},
- {"recommended_for", "Image upscaling and quality enhancement"}
- }
- }},
- {"request_id", requestId}
- };
- sendJsonResponse(res, types);
- } catch (const std::exception& e) {
- sendErrorResponse(res, std::string("Failed to get model types: ") + e.what(), 500, "MODEL_TYPES_ERROR", requestId);
- }
- }
- void Server::handleModelDirectories(const httplib::Request& req, httplib::Response& res) {
- std::string requestId = generateRequestId();
- try {
- if (!m_modelManager) {
- sendErrorResponse(res, "Model manager not available", 500, "MODEL_MANAGER_UNAVAILABLE", requestId);
- return;
- }
- std::string modelsDir = m_modelManager->getModelsDirectory();
- json directories = json::array();
- // Define expected model directories
- std::vector<std::string> modelDirs = {
- "stable-diffusion", "checkpoints", "lora", "controlnet",
- "vae", "taesd", "esrgan", "embeddings"
- };
- for (const auto& dirName : modelDirs) {
- std::string dirPath = modelsDir + "/" + dirName;
- std::string type = getModelTypeFromDirectoryName(dirName);
- std::string description = getDirectoryDescription(dirName);
- json dirInfo = {
- {"name", dirName},
- {"path", dirPath},
- {"type", type},
- {"description", description},
- {"exists", std::filesystem::exists(dirPath) && std::filesystem::is_directory(dirPath)},
- {"contents", getDirectoryContents(dirPath)}
- };
- directories.push_back(dirInfo);
- }
- json response = {
- {"models_directory", modelsDir},
- {"directories", directories},
- {"request_id", requestId}
- };
- sendJsonResponse(res, response);
- } catch (const std::exception& e) {
- sendErrorResponse(res, std::string("Failed to get model directories: ") + e.what(), 500, "MODEL_DIRECTORIES_ERROR", requestId);
- }
- }
- void Server::handleRefreshModels(const httplib::Request& req, httplib::Response& res) {
- std::string requestId = generateRequestId();
- try {
- if (!m_modelManager) {
- sendErrorResponse(res, "Model manager not available", 500, "MODEL_MANAGER_UNAVAILABLE", requestId);
- return;
- }
- // Force refresh of model cache
- bool success = m_modelManager->scanModelsDirectory();
- if (success) {
- json response = {
- {"status", "success"},
- {"message", "Model cache refreshed successfully"},
- {"models_found", m_modelManager->getAvailableModelsCount()},
- {"models_loaded", m_modelManager->getLoadedModelsCount()},
- {"models_directory", m_modelManager->getModelsDirectory()},
- {"request_id", requestId}
- };
- sendJsonResponse(res, response);
- } else {
- sendErrorResponse(res, "Failed to refresh model cache", 500, "MODEL_REFRESH_FAILED", requestId);
- }
- } catch (const std::exception& e) {
- sendErrorResponse(res, std::string("Model refresh failed: ") + e.what(), 500, "MODEL_REFRESH_ERROR", requestId);
- }
- }
- void Server::handleHashModels(const httplib::Request& req, httplib::Response& res) {
- std::string requestId = generateRequestId();
- try {
- if (!m_generationQueue || !m_modelManager) {
- sendErrorResponse(res, "Services not available", 500, "SERVICE_UNAVAILABLE", requestId);
- return;
- }
- // Parse request body
- json requestJson;
- if (!req.body.empty()) {
- requestJson = json::parse(req.body);
- }
- HashRequest hashReq;
- hashReq.id = requestId;
- hashReq.forceRehash = requestJson.value("force_rehash", false);
- if (requestJson.contains("models") && requestJson["models"].is_array()) {
- for (const auto& model : requestJson["models"]) {
- hashReq.modelNames.push_back(model.get<std::string>());
- }
- }
- // Enqueue hash request
- auto future = m_generationQueue->enqueueHashRequest(hashReq);
- json response = {
- {"request_id", requestId},
- {"status", "queued"},
- {"message", "Hash job queued successfully"},
- {"models_to_hash", hashReq.modelNames.empty() ? "all_unhashed" : std::to_string(hashReq.modelNames.size())}
- };
- sendJsonResponse(res, response, 202);
- } catch (const json::parse_error& e) {
- sendErrorResponse(res, std::string("Invalid JSON: ") + e.what(), 400, "JSON_PARSE_ERROR", requestId);
- } catch (const std::exception& e) {
- sendErrorResponse(res, std::string("Hash request failed: ") + e.what(), 500, "INTERNAL_ERROR", requestId);
- }
- }
- void Server::handleConvertModel(const httplib::Request& req, httplib::Response& res) {
- std::string requestId = generateRequestId();
- try {
- if (!m_generationQueue || !m_modelManager) {
- sendErrorResponse(res, "Services not available", 500, "SERVICE_UNAVAILABLE", requestId);
- return;
- }
- // Parse request body
- json requestJson;
- try {
- requestJson = json::parse(req.body);
- } catch (const json::parse_error& e) {
- sendErrorResponse(res, std::string("Invalid JSON: ") + e.what(), 400, "JSON_PARSE_ERROR", requestId);
- return;
- }
- // Validate required fields
- if (!requestJson.contains("model_name")) {
- sendErrorResponse(res, "Missing required field: model_name", 400, "MISSING_FIELD", requestId);
- return;
- }
- if (!requestJson.contains("quantization_type")) {
- sendErrorResponse(res, "Missing required field: quantization_type", 400, "MISSING_FIELD", requestId);
- return;
- }
- std::string modelName = requestJson["model_name"].get<std::string>();
- std::string quantizationType = requestJson["quantization_type"].get<std::string>();
- // Validate quantization type
- const std::vector<std::string> validTypes = {"f32", "f16", "q4_0", "q4_1", "q5_0", "q5_1", "q8_0", "q2_K", "q3_K", "q4_K"};
- if (std::find(validTypes.begin(), validTypes.end(), quantizationType) == validTypes.end()) {
- sendErrorResponse(res, "Invalid quantization_type. Valid types: f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0, q2_K, q3_K, q4_K",
- 400, "INVALID_QUANTIZATION_TYPE", requestId);
- return;
- }
- // Get model info to find the full path
- auto modelInfo = m_modelManager->getModelInfo(modelName);
- if (modelInfo.name.empty()) {
- sendErrorResponse(res, "Model not found: " + modelName, 404, "MODEL_NOT_FOUND", requestId);
- return;
- }
- // Check if model is already GGUF
- if (modelInfo.fullPath.find(".gguf") != std::string::npos) {
- sendErrorResponse(res, "Model is already in GGUF format. Cannot convert GGUF to GGUF.",
- 400, "ALREADY_GGUF", requestId);
- return;
- }
- // Build output path
- std::string outputPath = requestJson.value("output_path", "");
- if (outputPath.empty()) {
- // Generate default output path: model_name_quantization.gguf
- namespace fs = std::filesystem;
- fs::path inputPath(modelInfo.fullPath);
- std::string baseName = inputPath.stem().string();
- std::string outputDir = inputPath.parent_path().string();
- outputPath = outputDir + "/" + baseName + "_" + quantizationType + ".gguf";
- }
- // Create conversion request
- ConversionRequest convReq;
- convReq.id = requestId;
- convReq.modelName = modelName;
- convReq.modelPath = modelInfo.fullPath;
- convReq.outputPath = outputPath;
- convReq.quantizationType = quantizationType;
- // Enqueue conversion request
- auto future = m_generationQueue->enqueueConversionRequest(convReq);
- json response = {
- {"request_id", requestId},
- {"status", "queued"},
- {"message", "Model conversion queued successfully"},
- {"model_name", modelName},
- {"input_path", modelInfo.fullPath},
- {"output_path", outputPath},
- {"quantization_type", quantizationType}
- };
- sendJsonResponse(res, response, 202);
- } catch (const std::exception& e) {
- sendErrorResponse(res, std::string("Conversion request failed: ") + e.what(), 500, "INTERNAL_ERROR", requestId);
- }
- }
- void Server::handleModelStats(const httplib::Request& req, httplib::Response& res) {
- std::string requestId = generateRequestId();
- try {
- if (!m_modelManager) {
- sendErrorResponse(res, "Model manager not available", 500, "MODEL_MANAGER_UNAVAILABLE", requestId);
- return;
- }
- auto allModels = m_modelManager->getAllModels();
- json response = {
- {"statistics", {
- {"total_models", allModels.size()},
- {"loaded_models", m_modelManager->getLoadedModelsCount()},
- {"available_models", m_modelManager->getAvailableModelsCount()},
- {"model_types", getModelTypeStatistics()},
- {"largest_model", getLargestModel(allModels)},
- {"smallest_model", getSmallestModel(allModels)}
- }},
- {"request_id", requestId}
- };
- sendJsonResponse(res, response);
- } catch (const std::exception& e) {
- sendErrorResponse(res, std::string("Failed to get model stats: ") + e.what(), 500, "MODEL_STATS_ERROR", requestId);
- }
- }
- void Server::handleBatchModels(const httplib::Request& req, httplib::Response& res) {
- std::string requestId = generateRequestId();
- try {
- if (!m_modelManager) {
- sendErrorResponse(res, "Model manager not available", 500, "MODEL_MANAGER_UNAVAILABLE", requestId);
- return;
- }
- // Parse JSON request body
- json requestJson = json::parse(req.body);
- if (!requestJson.contains("operation") || !requestJson["operation"].is_string()) {
- sendErrorResponse(res, "Missing or invalid 'operation' field", 400, "INVALID_OPERATION", requestId);
- return;
- }
- if (!requestJson.contains("models") || !requestJson["models"].is_array()) {
- sendErrorResponse(res, "Missing or invalid 'models' field", 400, "INVALID_MODELS", requestId);
- return;
- }
- std::string operation = requestJson["operation"];
- json models = requestJson["models"];
- json results = json::array();
- for (const auto& model : models) {
- if (!model.is_string()) {
- results.push_back({
- {"model", model},
- {"success", false},
- {"error", "Invalid model name"}
- });
- continue;
- }
- std::string modelName = model;
- bool success = false;
- std::string error = "";
- if (operation == "load") {
- success = m_modelManager->loadModel(modelName);
- if (!success) error = "Failed to load model";
- } else if (operation == "unload") {
- success = m_modelManager->unloadModel(modelName);
- if (!success) error = "Failed to unload model";
- } else {
- error = "Unsupported operation";
- }
- results.push_back({
- {"model", modelName},
- {"success", success},
- {"error", error.empty() ? json(nullptr) : json(error)}
- });
- }
- json response = {
- {"operation", operation},
- {"results", results},
- {"successful_count", std::count_if(results.begin(), results.end(),
- [](const json& result) { return result["success"].get<bool>(); })},
- {"failed_count", std::count_if(results.begin(), results.end(),
- [](const json& result) { return !result["success"].get<bool>(); })},
- {"request_id", requestId}
- };
- sendJsonResponse(res, response);
- } catch (const json::parse_error& e) {
- sendErrorResponse(res, std::string("Invalid JSON: ") + e.what(), 400, "JSON_PARSE_ERROR", requestId);
- } catch (const std::exception& e) {
- sendErrorResponse(res, std::string("Batch operation failed: ") + e.what(), 500, "BATCH_OPERATION_ERROR", requestId);
- }
- }
- void Server::handleValidateModel(const httplib::Request& req, httplib::Response& res) {
- std::string requestId = generateRequestId();
- try {
- // Parse JSON request body
- json requestJson = json::parse(req.body);
- if (!requestJson.contains("model_path") || !requestJson["model_path"].is_string()) {
- sendErrorResponse(res, "Missing or invalid 'model_path' field", 400, "INVALID_MODEL_PATH", requestId);
- return;
- }
- std::string modelPath = requestJson["model_path"];
- std::string modelType = requestJson.value("model_type", "checkpoint");
- // Validate model file
- json validation = validateModelFile(modelPath, modelType);
- json response = {
- {"validation", validation},
- {"request_id", requestId}
- };
- sendJsonResponse(res, response);
- } catch (const json::parse_error& e) {
- sendErrorResponse(res, std::string("Invalid JSON: ") + e.what(), 400, "JSON_PARSE_ERROR", requestId);
- } catch (const std::exception& e) {
- sendErrorResponse(res, std::string("Model validation failed: ") + e.what(), 500, "MODEL_VALIDATION_ERROR", requestId);
- }
- }
- void Server::handleCheckCompatibility(const httplib::Request& req, httplib::Response& res) {
- std::string requestId = generateRequestId();
- try {
- if (!m_modelManager) {
- sendErrorResponse(res, "Model manager not available", 500, "MODEL_MANAGER_UNAVAILABLE", requestId);
- return;
- }
- // Parse JSON request body
- json requestJson = json::parse(req.body);
- if (!requestJson.contains("model_name") || !requestJson["model_name"].is_string()) {
- sendErrorResponse(res, "Missing or invalid 'model_name' field", 400, "INVALID_MODEL_NAME", requestId);
- return;
- }
- std::string modelName = requestJson["model_name"];
- std::string systemInfo = requestJson.value("system_info", "auto");
- // Get model information
- auto modelInfo = m_modelManager->getModelInfo(modelName);
- if (modelInfo.name.empty()) {
- sendErrorResponse(res, "Model not found", 404, "MODEL_NOT_FOUND", requestId);
- return;
- }
- // Check compatibility
- json compatibility = checkModelCompatibility(modelInfo, systemInfo);
- json response = {
- {"model", modelName},
- {"compatibility", compatibility},
- {"request_id", requestId}
- };
- sendJsonResponse(res, response);
- } catch (const json::parse_error& e) {
- sendErrorResponse(res, std::string("Invalid JSON: ") + e.what(), 400, "JSON_PARSE_ERROR", requestId);
- } catch (const std::exception& e) {
- sendErrorResponse(res, std::string("Compatibility check failed: ") + e.what(), 500, "COMPATIBILITY_CHECK_ERROR", requestId);
- }
- }
- void Server::handleModelRequirements(const httplib::Request& req, httplib::Response& res) {
- std::string requestId = generateRequestId();
- try {
- // Parse JSON request body
- json requestJson = json::parse(req.body);
- std::string modelType = requestJson.value("model_type", "checkpoint");
- std::string resolution = requestJson.value("resolution", "512x512");
- std::string batchSize = requestJson.value("batch_size", "1");
- // Calculate specific requirements
- json requirements = calculateSpecificRequirements(modelType, resolution, batchSize);
- // Get general requirements for model type
- ModelType type = ModelManager::stringToModelType(modelType);
- json generalRequirements = getModelRequirements(type);
- json response = {
- {"model_type", modelType},
- {"configuration", {
- {"resolution", resolution},
- {"batch_size", batchSize}
- }},
- {"specific_requirements", requirements},
- {"general_requirements", generalRequirements},
- {"request_id", requestId}
- };
- sendJsonResponse(res, response);
- } catch (const json::parse_error& e) {
- sendErrorResponse(res, std::string("Invalid JSON: ") + e.what(), 400, "JSON_PARSE_ERROR", requestId);
- } catch (const std::exception& e) {
- sendErrorResponse(res, std::string("Requirements calculation failed: ") + e.what(), 500, "REQUIREMENTS_ERROR", requestId);
- }
- }
- void Server::serverThreadFunction(const std::string& host, int port) {
- try {
- std::cout << "Server thread starting, attempting to bind to " << host << ":" << port << std::endl;
- // Check if port is available before attempting to bind
- std::cout << "Checking if port " << port << " is available..." << std::endl;
- // Try to create a test socket to check if port is in use
- int test_socket = socket(AF_INET, SOCK_STREAM, 0);
- if (test_socket >= 0) {
- // Set SO_REUSEADDR to avoid TIME_WAIT issues
- int opt = 1;
- setsockopt(test_socket, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt));
- struct sockaddr_in addr;
- addr.sin_family = AF_INET;
- addr.sin_port = htons(port);
- addr.sin_addr.s_addr = INADDR_ANY;
- // Try to bind to the port
- if (bind(test_socket, (struct sockaddr*)&addr, sizeof(addr)) < 0) {
- close(test_socket);
- std::cerr << "ERROR: Port " << port << " is already in use! Cannot start server." << std::endl;
- std::cerr << "Please stop the existing instance or use a different port." << std::endl;
- m_isRunning.store(false);
- m_startupFailed.store(true);
- return;
- }
- close(test_socket);
- }
- std::cout << "Port " << port << " is available, proceeding with server startup..." << std::endl;
- std::cout << "Calling listen()..." << std::endl;
- // The listen() call will block until server is stopped
- // listen() returns true if it successfully binds and starts
- // Once it binds successfully, we set m_isRunning to true via a callback
- // Set up a flag to track if listen started successfully
- std::atomic<bool> listenStarted{false};
- // We need to set m_isRunning after successful bind but before blocking
- // cpp-httplib doesn't provide a callback, so we set it optimistically
- // and clear it if listen() returns false
- m_isRunning.store(true);
- bool listenResult = m_httpServer->listen(host.c_str(), port);
- std::cout << "listen() returned: " << (listenResult ? "true" : "false") << std::endl;
- // If we reach here, server has stopped (either normally or due to error)
- m_isRunning.store(false);
- if (!listenResult) {
- std::cerr << "Server listen failed! This usually means port is in use or permission denied." << std::endl;
- }
- } catch (const std::exception& e) {
- std::cerr << "Exception in server thread: " << e.what() << std::endl;
- m_isRunning.store(false);
- }
- }
|