| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
27712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642164316441645164616471648164916501651165216531654165516561657165816591660166116621663166416651666166716681669167016711672167316741675167616771678167916801681168216831684168516861687168816891690169116921693169416951696169716981699170017011702170317041705170617071708170917101711171217131714171517161717171817191720172117221723172417251726172717281729173017311732173317341735173617371738173917401741174217431744174517461747174817491750175117521753175417551756175717581759176017611762176317641765176617671768176917701771177217731774177517761
77717781779178017811782178317841785178617871788178917901791179217931794179517961797179817991800180118021803180418051806180718081809181018111812181318141815181618171818181918201821182218231824182518261827182818291830183118321833183418351836183718381839184018411842184318441845184618471848184918501851185218531854185518561857185818591860186118621863186418651866186718681869187018711872187318741875187618771878187918801881188218831884188518861887188818891890189118921893189418951896189718981899190019011902190319041905190619071908190919101911191219131914191519161917191819191920192119221923192419251926192719281929193019311932193319341935193619371938193919401941194219431944194519461947194819491950195119521953195419551956195719581959196019611962196319641965196619671968196919701971197219731974197519761977197819791980198119821983198419851986198719881989199019911992199319941995199619971998199920002001200220032004200520062007200820092010201120122013201420152016201720182019202020212022202320242025202620272028202920302031203220332034203520362037203820392040204120422043204420452046204720482049205020512052205320542055205620572058205920602061206220632064206520662067206820692070207120722073207420752076207720782079208020812082208320842085208620872088208920902091209220932094209520962097209820992100210121022103210421052106210721082109211021112112211321142115211621172118211921202121212221232124212521262127212821292130213121322133213421352136213721382139214021412142214321442145214621472148214921502151215221532154215521562157215821592160216121622163216421652166216721682169217021712172217321742175217621772178217921802181218221832184218521862187218821892190219121922193219421952196219721982199220022012202220322042205220622072208220922102211221222132214221522162217221822192220222122222223222422252226222722282229223022312232223322342235223622372238223922402241224222432244224522462247224822492250225122522253225422552256225722582259226022612262226322642265226622672268226922702271227222732274227522762
27722782279228022812282228322842285228622872288228922902291229222932294229522962297229822992300230123022303230423052306230723082309231023112312231323142315231623172318231923202321232223232324232523262327232823292330233123322333233423352336233723382339234023412342234323442345234623472348234923502351235223532354235523562357235823592360236123622363236423652366236723682369237023712372237323742375237623772378237923802381238223832384238523862387238823892390239123922393239423952396239723982399240024012402240324042405240624072408240924102411241224132414241524162417241824192420242124222423242424252426242724282429243024312432243324342435243624372438243924402441244224432444244524462447244824492450245124522453245424552456245724582459246024612462246324642465246624672468246924702471247224732474247524762477247824792480248124822483248424852486248724882489249024912492249324942495249624972498249925002501250225032504250525062507250825092510251125122513251425152516251725182519252025212522252325242525252625272528252925302531253225332534253525362537253825392540254125422543254425452546254725482549255025512552255325542555255625572558255925602561256225632564256525662567256825692570257125722573257425752576257725782579258025812582258325842585258625872588258925902591259225932594259525962597259825992600260126022603260426052606260726082609261026112612261326142615261626172618261926202621262226232624262526262627262826292630263126322633263426352636263726382639264026412642264326442645264626472648264926502651265226532654265526562657265826592660266126622663266426652666266726682669267026712672267326742675267626772678267926802681268226832684268526862687268826892690269126922693269426952696269726982699270027012702270327042705270627072708270927102711271227132714271527162717271827192720272127222723272427252726272727282729273027312732273327342735273627372738273927402741274227432744274527462747274827492750275127522753275427552756275727582759276027612762276327642765276627672768276927702771277227732774277527762
77727782779278027812782278327842785278627872788278927902791279227932794279527962797279827992800280128022803280428052806280728082809281028112812281328142815281628172818281928202821282228232824282528262827282828292830283128322833283428352836283728382839284028412842284328442845284628472848284928502851285228532854285528562857285828592860286128622863286428652866286728682869287028712872287328742875287628772878287928802881288228832884288528862887288828892890289128922893289428952896289728982899290029012902290329042905290629072908290929102911291229132914291529162917291829192920292129222923292429252926292729282929293029312932293329342935293629372938293929402941294229432944294529462947294829492950295129522953295429552956295729582959296029612962296329642965296629672968296929702971297229732974297529762977297829792980298129822983298429852986298729882989299029912992299329942995299629972998299930003001300230033004300530063007300830093010301130123013301430153016301730183019302030213022302330243025302630273028302930303031303230333034303530363037303830393040304130423043304430453046304730483049305030513052305330543055305630573058305930603061306230633064306530663067306830693070307130723073307430753076307730783079308030813082308330843085308630873088308930903091309230933094309530963097309830993100310131023103310431053106310731083109311031113112311331143115311631173118311931203121312231233124312531263127312831293130313131323133313431353136313731383139314031413142314331443145314631473148314931503151315231533154315531563157315831593160316131623163316431653166316731683169317031713172317331743175317631773178317931803181318231833184318531863187318831893190319131923193319431953196319731983199320032013202320332043205320632073208320932103211321232133214321532163217321832193220322132223223322432253226322732283229323032313232323332343235323632373238323932403241324232433244324532463247324832493250325132523253325432553256325732583259326032613262326332643265326632673268326932703271327232733274327532763
27732783279328032813282328332843285328632873288328932903291329232933294329532963297329832993300330133023303330433053306330733083309331033113312331333143315331633173318331933203321332233233324332533263327332833293330333133323333333433353336333733383339334033413342334333443345334633473348334933503351335233533354335533563357335833593360336133623363336433653366336733683369337033713372337333743375337633773378337933803381338233833384338533863387338833893390339133923393339433953396339733983399340034013402340334043405340634073408340934103411341234133414341534163417341834193420342134223423342434253426342734283429343034313432343334343435343634373438343934403441344234433444344534463447344834493450345134523453345434553456345734583459346034613462346334643465346634673468346934703471347234733474347534763477347834793480348134823483348434853486348734883489349034913492349334943495349634973498349935003501350235033504350535063507350835093510351135123513351435153516351735183519352035213522352335243525352635273528352935303531353235333534353535363537353835393540354135423543354435453546354735483549355035513552355335543555355635573558355935603561356235633564356535663567356835693570357135723573357435753576357735783579358035813582358335843585358635873588358935903591359235933594359535963597359835993600360136023603360436053606360736083609361036113612361336143615361636173618361936203621362236233624362536263627362836293630363136323633363436353636363736383639364036413642364336443645364636473648364936503651365236533654365536563657365836593660366136623663366436653666366736683669367036713672367336743675367636773678367936803681368236833684368536863687368836893690369136923693369436953696369736983699370037013702370337043705370637073708370937103711371237133714371537163717371837193720372137223723372437253726372737283729373037313732373337343735373637373738373937403741374237433744374537463747374837493750375137523753375437553756375737583759376037613762376337643765376637673768376937703771377237733774377537763
77737783779378037813782378337843785378637873788378937903791379237933794379537963797379837993800380138023803380438053806380738083809381038113812381338143815381638173818381938203821382238233824382538263827382838293830383138323833383438353836383738383839384038413842384338443845384638473848384938503851385238533854385538563857385838593860386138623863386438653866386738683869387038713872387338743875387638773878387938803881388238833884388538863887388838893890389138923893389438953896389738983899390039013902390339043905390639073908390939103911391239133914391539163917391839193920392139223923392439253926392739283929393039313932393339343935393639373938393939403941394239433944394539463947394839493950395139523953395439553956395739583959396039613962396339643965396639673968396939703971397239733974397539763977397839793980398139823983398439853986398739883989399039913992399339943995399639973998399940004001400240034004400540064007400840094010401140124013401440154016401740184019402040214022402340244025402640274028402940304031403240334034403540364037403840394040404140424043404440454046404740484049405040514052405340544055405640574058405940604061406240634064406540664067406840694070407140724073407440754076407740784079408040814082408340844085408640874088408940904091409240934094409540964097409840994100410141024103410441054106410741084109411041114112411341144115411641174118411941204121412241234124412541264127412841294130413141324133413441354136413741384139414041414142414341444145414641474148414941504151415241534154415541564157415841594160416141624163416441654166416741684169417041714172417341744175417641774178417941804181418241834184418541864187418841894190419141924193419441954196419741984199420042014202420342044205420642074208420942104211421242134214421542164217421842194220422142224223422442254226422742284229423042314232423342344235423642374238423942404241424242434244424542464247424842494250425142524253425442554256425742584259426042614262426342644265426642674268426942704271427242734274427542764
2774278427942804281428242834284428542864287428842894290429142924293429442954296429742984299430043014302430343044305430643074308430943104311431243134314431543164317431843194320432143224323432443254326432743284329433043314332433343344335433643374338433943404341434243434344434543464347434843494350435143524353435443554356435743584359436043614362436343644365436643674368436943704371437243734374437543764377437843794380438143824383438443854386438743884389439043914392439343944395439643974398439944004401440244034404440544064407440844094410441144124413441444154416441744184419442044214422442344244425442644274428442944304431443244334434443544364437443844394440444144424443444444454446444744484449445044514452445344544455445644574458445944604461446244634464446544664467446844694470447144724473447444754476447744784479448044814482448344844485448644874488448944904491449244934494449544964497449844994500450145024503450445054506450745084509451045114512451345144515451645174518451945204521452245234524452545264527452845294530453145324533453445354536 |
- #include "server.h"
- #include "model_manager.h"
- #include "generation_queue.h"
- #include "utils.h"
- #include "auth_middleware.h"
- #include "user_manager.h"
- #include <httplib.h>
- #include <nlohmann/json.hpp>
- #include <iostream>
- #include <sstream>
- #include <fstream>
- #include <chrono>
- #include <random>
- #include <iomanip>
- #include <algorithm>
- #include <thread>
- #include <filesystem>
- // Include stb_image for loading images (implementation is in generation_queue.cpp)
- #include "../stable-diffusion.cpp-src/thirdparty/stb_image.h"
- #include <sys/socket.h>
- #include <netinet/in.h>
- #include <unistd.h>
- #include <arpa/inet.h>
- using json = nlohmann::json;
- Server::Server(ModelManager* modelManager, GenerationQueue* generationQueue, const std::string& outputDir, const std::string& uiDir)
- : m_modelManager(modelManager)
- , m_generationQueue(generationQueue)
- , m_isRunning(false)
- , m_startupFailed(false)
- , m_port(8080)
- , m_outputDir(outputDir)
- , m_uiDir(uiDir)
- , m_userManager(nullptr)
- , m_authMiddleware(nullptr)
- {
- m_httpServer = std::make_unique<httplib::Server>();
- }
- Server::~Server() {
- stop();
- }
- bool Server::start(const std::string& host, int port) {
- if (m_isRunning.load()) {
- return false;
- }
- m_host = host;
- m_port = port;
- // Validate host and port
- if (host.empty() || (port < 1 || port > 65535)) {
- return false;
- }
- // Set up CORS headers
- setupCORS();
- // Register API endpoints
- registerEndpoints();
- // Reset startup flags
- m_startupFailed.store(false);
- // Start server in a separate thread
- m_serverThread = std::thread(&Server::serverThreadFunction, this, host, port);
- // Wait for server to actually start and bind to the port
- // Give more time for server to actually start and bind
- for (int i = 0; i < 100; i++) { // Wait up to 10 seconds
- std::this_thread::sleep_for(std::chrono::milliseconds(100));
- // Check if startup failed early
- if (m_startupFailed.load()) {
- if (m_serverThread.joinable()) {
- m_serverThread.join();
- }
- return false;
- }
- if (m_isRunning.load()) {
- // Give it a moment more to ensure server is fully started
- std::this_thread::sleep_for(std::chrono::milliseconds(500));
- if (m_isRunning.load()) {
- return true;
- }
- }
- }
- if (m_isRunning.load()) {
- return true;
- } else {
- if (m_serverThread.joinable()) {
- m_serverThread.join();
- }
- return false;
- }
- }
- void Server::stop() {
- // Use atomic check to ensure thread safety
- bool wasRunning = m_isRunning.exchange(false);
- if (!wasRunning) {
- return; // Already stopped
- }
- if (m_httpServer) {
- m_httpServer->stop();
- // Give the server a moment to stop the blocking listen call
- std::this_thread::sleep_for(std::chrono::milliseconds(100));
- // If server thread is still running, try to force unblock the listen call
- // by making a quick connection to the server port
- if (m_serverThread.joinable()) {
- try {
- // Create a quick connection to interrupt the blocking listen
- httplib::Client client("127.0.0.1", m_port);
- client.set_connection_timeout(0, 500000); // 0.5 seconds
- client.set_read_timeout(0, 500000); // 0.5 seconds
- client.set_write_timeout(0, 500000); // 0.5 seconds
- auto res = client.Get("/api/health");
- // We don't care about the response, just trying to unblock
- } catch (...) {
- // Ignore any connection errors - we're just trying to unblock
- }
- }
- }
- if (m_serverThread.joinable()) {
- m_serverThread.join();
- }
- }
- bool Server::isRunning() const {
- return m_isRunning.load();
- }
- void Server::waitForStop() {
- if (m_serverThread.joinable()) {
- m_serverThread.join();
- }
- }
- void Server::registerEndpoints() {
- // Register authentication endpoints first (before applying middleware)
- registerAuthEndpoints();
- // Health check endpoint (public)
- m_httpServer->Get("/api/health", [this](const httplib::Request& req, httplib::Response& res) {
- handleHealthCheck(req, res);
- });
- // API status endpoint (public)
- m_httpServer->Get("/api/status", [this](const httplib::Request& req, httplib::Response& res) {
- handleApiStatus(req, res);
- });
- // Apply authentication middleware to protected endpoints
- auto withAuth = [this](std::function<void(const httplib::Request&, httplib::Response&)> handler) {
- return [this, handler](const httplib::Request& req, httplib::Response& res) {
- if (m_authMiddleware) {
- AuthContext authContext = m_authMiddleware->authenticate(req, res);
- if (!authContext.authenticated) {
- m_authMiddleware->sendAuthError(res, authContext.errorMessage, authContext.errorCode);
- return;
- }
- }
- handler(req, res);
- };
- };
- // Specialized generation endpoints (protected)
- m_httpServer->Post("/api/generate/text2img", withAuth([this](const httplib::Request& req, httplib::Response& res) {
- handleText2Img(req, res);
- }));
- m_httpServer->Post("/api/generate/img2img", withAuth([this](const httplib::Request& req, httplib::Response& res) {
- handleImg2Img(req, res);
- }));
- m_httpServer->Post("/api/generate/controlnet", withAuth([this](const httplib::Request& req, httplib::Response& res) {
- handleControlNet(req, res);
- }));
- m_httpServer->Post("/api/generate/upscale", withAuth([this](const httplib::Request& req, httplib::Response& res) {
- handleUpscale(req, res);
- }));
- m_httpServer->Post("/api/generate/inpainting", withAuth([this](const httplib::Request& req, httplib::Response& res) {
- handleInpainting(req, res);
- }));
- // Utility endpoints (now protected - require authentication)
- m_httpServer->Get("/api/samplers", withAuth([this](const httplib::Request& req, httplib::Response& res) {
- handleSamplers(req, res);
- }));
- m_httpServer->Get("/api/schedulers", withAuth([this](const httplib::Request& req, httplib::Response& res) {
- handleSchedulers(req, res);
- }));
- m_httpServer->Get("/api/parameters", withAuth([this](const httplib::Request& req, httplib::Response& res) {
- handleParameters(req, res);
- }));
- m_httpServer->Post("/api/validate", [this](const httplib::Request& req, httplib::Response& res) {
- handleValidate(req, res);
- });
- m_httpServer->Post("/api/estimate", withAuth([this](const httplib::Request& req, httplib::Response& res) {
- handleEstimate(req, res);
- }));
- m_httpServer->Get("/api/config", withAuth([this](const httplib::Request& req, httplib::Response& res) {
- handleConfig(req, res);
- }));
- m_httpServer->Get("/api/system", withAuth([this](const httplib::Request& req, httplib::Response& res) {
- handleSystem(req, res);
- }));
- m_httpServer->Post("/api/system/restart", withAuth([this](const httplib::Request& req, httplib::Response& res) {
- handleSystemRestart(req, res);
- }));
- // Models list endpoint (now protected - require authentication)
- m_httpServer->Get("/api/models", withAuth([this](const httplib::Request& req, httplib::Response& res) {
- handleModelsList(req, res);
- }));
- // Model-specific endpoints
- m_httpServer->Get("/api/models/(.*)", [this](const httplib::Request& req, httplib::Response& res) {
- handleModelInfo(req, res);
- });
- m_httpServer->Post("/api/models/(.*)/load", withAuth([this](const httplib::Request& req, httplib::Response& res) {
- handleLoadModelById(req, res);
- }));
- m_httpServer->Post("/api/models/(.*)/unload", withAuth([this](const httplib::Request& req, httplib::Response& res) {
- handleUnloadModelById(req, res);
- }));
- // Model management endpoints (now protected - require authentication)
- m_httpServer->Get("/api/models/types", withAuth([this](const httplib::Request& req, httplib::Response& res) {
- handleModelTypes(req, res);
- }));
- m_httpServer->Get("/api/models/directories", withAuth([this](const httplib::Request& req, httplib::Response& res) {
- handleModelDirectories(req, res);
- }));
- m_httpServer->Post("/api/models/refresh", withAuth([this](const httplib::Request& req, httplib::Response& res) {
- handleRefreshModels(req, res);
- }));
- m_httpServer->Post("/api/models/hash", withAuth([this](const httplib::Request& req, httplib::Response& res) {
- handleHashModels(req, res);
- }));
- m_httpServer->Post("/api/models/convert", withAuth([this](const httplib::Request& req, httplib::Response& res) {
- handleConvertModel(req, res);
- }));
- m_httpServer->Get("/api/models/stats", withAuth([this](const httplib::Request& req, httplib::Response& res) {
- handleModelStats(req, res);
- }));
- m_httpServer->Post("/api/models/batch", withAuth([this](const httplib::Request& req, httplib::Response& res) {
- handleBatchModels(req, res);
- }));
- // Model validation endpoints (already protected with withAuth)
- m_httpServer->Post("/api/models/validate", withAuth([this](const httplib::Request& req, httplib::Response& res) {
- handleValidateModel(req, res);
- }));
- m_httpServer->Post("/api/models/compatible", withAuth([this](const httplib::Request& req, httplib::Response& res) {
- handleCheckCompatibility(req, res);
- }));
- m_httpServer->Post("/api/models/requirements", withAuth([this](const httplib::Request& req, httplib::Response& res) {
- handleModelRequirements(req, res);
- }));
- // Queue status endpoint (now protected - require authentication)
- m_httpServer->Get("/api/queue/status", withAuth([this](const httplib::Request& req, httplib::Response& res) {
- handleQueueStatus(req, res);
- }));
- // Download job output file endpoint (must be before job status endpoint to match more specific pattern first)
- // Note: This endpoint is public to allow frontend to display generated images without authentication
- m_httpServer->Get("/api/queue/job/(.*)/output/(.*)", [this](const httplib::Request& req, httplib::Response& res) {
- handleDownloadOutput(req, res);
- });
- // Job status endpoint (now protected - require authentication)
- m_httpServer->Get("/api/queue/job/(.*)", withAuth([this](const httplib::Request& req, httplib::Response& res) {
- handleJobStatus(req, res);
- }));
- // Cancel job endpoint (protected)
- m_httpServer->Post("/api/queue/cancel", withAuth([this](const httplib::Request& req, httplib::Response& res) {
- handleCancelJob(req, res);
- }));
- // Clear queue endpoint (protected)
- m_httpServer->Post("/api/queue/clear", withAuth([this](const httplib::Request& req, httplib::Response& res) {
- handleClearQueue(req, res);
- }));
- // Serve static web UI files if uiDir is configured
- if (!m_uiDir.empty() && std::filesystem::exists(m_uiDir)) {
- std::cout << "Serving static UI files from: " << m_uiDir << " at /ui" << std::endl;
- // Read UI version from version.json if available
- std::string uiVersion = "unknown";
- std::string versionFilePath = m_uiDir + "/version.json";
- if (std::filesystem::exists(versionFilePath)) {
- try {
- std::ifstream versionFile(versionFilePath);
- if (versionFile.is_open()) {
- nlohmann::json versionData = nlohmann::json::parse(versionFile);
- if (versionData.contains("version")) {
- uiVersion = versionData["version"].get<std::string>();
- }
- versionFile.close();
- }
- } catch (const std::exception& e) {
- std::cerr << "Failed to read UI version: " << e.what() << std::endl;
- }
- }
- std::cout << "UI version: " << uiVersion << std::endl;
- // Serve dynamic config.js that provides runtime configuration to the web UI
- m_httpServer->Get("/ui/config.js", [this, uiVersion](const httplib::Request& req, httplib::Response& res) {
- // Generate JavaScript configuration with current server settings
- std::ostringstream configJs;
- configJs << "// Auto-generated configuration\n"
- << "window.__SERVER_CONFIG__ = {\n"
- << " apiUrl: 'http://" << m_host << ":" << m_port << "',\n"
- << " apiBasePath: '/api',\n"
- << " host: '" << m_host << "',\n"
- << " port: " << m_port << ",\n"
- << " uiVersion: '" << uiVersion << "',\n";
- // Add authentication method information
- if (m_authMiddleware) {
- auto authConfig = m_authMiddleware->getConfig();
- std::string authMethod = "none";
- switch (authConfig.authMethod) {
- case AuthMethod::UNIX:
- authMethod = "unix";
- break;
- case AuthMethod::JWT:
- authMethod = "jwt";
- break;
- default:
- authMethod = "none";
- break;
- }
- configJs << " authMethod: '" << authMethod << "',\n"
- << " authEnabled: " << (authConfig.authMethod != AuthMethod::NONE ? "true" : "false") << "\n";
- } else {
- configJs << " authMethod: 'none',\n"
- << " authEnabled: false\n";
- }
- configJs << "};\n";
- // No cache for config.js - always fetch fresh
- res.set_header("Cache-Control", "no-cache, no-store, must-revalidate");
- res.set_header("Pragma", "no-cache");
- res.set_header("Expires", "0");
- res.set_content(configJs.str(), "application/javascript");
- });
- // Set up file request handler for caching static assets
- m_httpServer->set_file_request_handler([uiVersion](const httplib::Request& req, httplib::Response& res) {
- // Add cache headers based on file type and version
- std::string path = req.path;
- // For versioned static assets (.js, .css, images), use long cache
- if (path.find("/_next/") != std::string::npos ||
- path.find(".js") != std::string::npos ||
- path.find(".css") != std::string::npos ||
- path.find(".png") != std::string::npos ||
- path.find(".jpg") != std::string::npos ||
- path.find(".svg") != std::string::npos ||
- path.find(".ico") != std::string::npos ||
- path.find(".woff") != std::string::npos ||
- path.find(".woff2") != std::string::npos ||
- path.find(".ttf") != std::string::npos) {
- // Long cache (1 year) for static assets
- res.set_header("Cache-Control", "public, max-age=31536000, immutable");
- // Add ETag based on UI version for cache validation
- res.set_header("ETag", "\"" + uiVersion + "\"");
- // Check If-None-Match for conditional requests
- if (req.has_header("If-None-Match")) {
- std::string clientETag = req.get_header_value("If-None-Match");
- if (clientETag == "\"" + uiVersion + "\"") {
- res.status = 304; // Not Modified
- return;
- }
- }
- } else if (path.find(".html") != std::string::npos || path == "/ui/" || path == "/ui") {
- // HTML files should revalidate but can be cached briefly
- res.set_header("Cache-Control", "public, max-age=0, must-revalidate");
- res.set_header("ETag", "\"" + uiVersion + "\"");
- }
- });
- // Create a handler for UI routes with authentication check
- auto uiHandler = [this](const httplib::Request& req, httplib::Response& res) {
- // Check if authentication is enabled
- if (m_authMiddleware) {
- auto authConfig = m_authMiddleware->getConfig();
- if (authConfig.authMethod != AuthMethod::NONE) {
- // Authentication is enabled, check if user is authenticated
- AuthContext authContext = m_authMiddleware->authenticate(req, res);
- // For Unix auth, we need to check if the user is authenticated
- // The authenticateUnix function will return a guest context for UI requests
- // when no Authorization header is present, but we still need to show the login page
- if (!authContext.authenticated) {
- // Check if this is a request for a static asset (JS, CSS, images)
- // These should be served even without authentication to allow the login page to work
- bool isStaticAsset = false;
- std::string path = req.path;
- if (path.find(".js") != std::string::npos ||
- path.find(".css") != std::string::npos ||
- path.find(".png") != std::string::npos ||
- path.find(".jpg") != std::string::npos ||
- path.find(".jpeg") != std::string::npos ||
- path.find(".svg") != std::string::npos ||
- path.find(".ico") != std::string::npos ||
- path.find("/_next/") != std::string::npos) {
- isStaticAsset = true;
- }
- // For static assets, allow them to be served without authentication
- if (isStaticAsset) {
- // Continue to serve the file
- } else {
- // For HTML requests, redirect to login page
- if (req.path.find(".html") != std::string::npos ||
- req.path == "/ui/" || req.path == "/ui") {
- // Serve the login page instead of the requested page
- std::string loginPagePath = m_uiDir + "/login.html";
- if (std::filesystem::exists(loginPagePath)) {
- std::ifstream loginFile(loginPagePath);
- if (loginFile.is_open()) {
- std::string content((std::istreambuf_iterator<char>(loginFile)),
- std::istreambuf_iterator<char>());
- res.set_content(content, "text/html");
- return;
- }
- }
- // If login.html doesn't exist, serve a simple login page
- std::string simpleLoginPage = R"(
- <!DOCTYPE html>
- <html>
- <head>
- <title>Login Required</title>
- <style>
- body { font-family: Arial, sans-serif; max-width: 500px; margin: 100px auto; padding: 20px; }
- .form-group { margin-bottom: 15px; }
- label { display: block; margin-bottom: 5px; }
- input { width: 100%; padding: 8px; box-sizing: border-box; }
- button { background-color: #007bff; color: white; padding: 10px 15px; border: none; cursor: pointer; }
- .error { color: red; margin-top: 10px; }
- </style>
- </head>
- <body>
- <h1>Login Required</h1>
- <p>Please enter your username to continue.</p>
- <form id="loginForm">
- <div class="form-group">
- <label for="username">Username:</label>
- <input type="text" id="username" name="username" required>
- </div>
- <button type="submit">Login</button>
- </form>
- <div id="error" class="error"></div>
- <script>
- document.getElementById('loginForm').addEventListener('submit', async (e) => {
- e.preventDefault();
- const username = document.getElementById('username').value;
- const errorDiv = document.getElementById('error');
- try {
- const response = await fetch('/api/auth/login', {
- method: 'POST',
- headers: { 'Content-Type': 'application/json' },
- body: JSON.stringify({ username })
- });
- if (response.ok) {
- const data = await response.json();
- localStorage.setItem('auth_token', data.token);
- localStorage.setItem('unix_user', username);
- window.location.reload();
- } else {
- const error = await response.json();
- errorDiv.textContent = error.message || 'Login failed';
- }
- } catch (err) {
- errorDiv.textContent = 'Login failed: ' + err.message;
- }
- });
- </script>
- </body>
- </html>
- )";
- res.set_content(simpleLoginPage, "text/html");
- return;
- } else {
- // For non-HTML files, return unauthorized
- m_authMiddleware->sendAuthError(res, "Authentication required", "AUTH_REQUIRED");
- return;
- }
- }
- }
- }
- }
- // If we get here, either auth is disabled or user is authenticated
- // Serve the requested file
- std::string filePath = req.path.substr(3); // Remove "/ui" prefix
- if (filePath.empty() || filePath == "/") {
- filePath = "/index.html";
- }
- std::string fullPath = m_uiDir + filePath;
- if (std::filesystem::exists(fullPath) && std::filesystem::is_regular_file(fullPath)) {
- std::ifstream file(fullPath, std::ios::binary);
- if (file.is_open()) {
- std::string content((std::istreambuf_iterator<char>(file)),
- std::istreambuf_iterator<char>());
- // Determine content type based on file extension
- std::string contentType = "text/plain";
- if (filePath.find(".html") != std::string::npos) {
- contentType = "text/html";
- } else if (filePath.find(".js") != std::string::npos) {
- contentType = "application/javascript";
- } else if (filePath.find(".css") != std::string::npos) {
- contentType = "text/css";
- } else if (filePath.find(".png") != std::string::npos) {
- contentType = "image/png";
- } else if (filePath.find(".jpg") != std::string::npos || filePath.find(".jpeg") != std::string::npos) {
- contentType = "image/jpeg";
- } else if (filePath.find(".svg") != std::string::npos) {
- contentType = "image/svg+xml";
- }
- res.set_content(content, contentType);
- } else {
- res.status = 404;
- res.set_content("File not found", "text/plain");
- }
- } else {
- // For SPA routing, if the file doesn't exist, serve index.html
- // This allows Next.js to handle client-side routing
- std::string indexPath = m_uiDir + "/index.html";
- if (std::filesystem::exists(indexPath)) {
- std::ifstream indexFile(indexPath, std::ios::binary);
- if (indexFile.is_open()) {
- std::string content((std::istreambuf_iterator<char>(indexFile)),
- std::istreambuf_iterator<char>());
- res.set_content(content, "text/html");
- } else {
- res.status = 404;
- res.set_content("File not found", "text/plain");
- }
- } else {
- res.status = 404;
- res.set_content("File not found", "text/plain");
- }
- }
- };
- // Set up UI routes with authentication
- m_httpServer->Get("/ui/.*", uiHandler);
- // Redirect /ui to /ui/ to ensure proper routing
- m_httpServer->Get("/ui", [](const httplib::Request& req, httplib::Response& res) {
- res.set_redirect("/ui/");
- });
- }
- }
- void Server::setAuthComponents(std::shared_ptr<UserManager> userManager, std::shared_ptr<AuthMiddleware> authMiddleware) {
- m_userManager = userManager;
- m_authMiddleware = authMiddleware;
- }
- void Server::registerAuthEndpoints() {
- // Login endpoint
- m_httpServer->Post("/api/auth/login", [this](const httplib::Request& req, httplib::Response& res) {
- handleLogin(req, res);
- });
- // Logout endpoint
- m_httpServer->Post("/api/auth/logout", [this](const httplib::Request& req, httplib::Response& res) {
- handleLogout(req, res);
- });
- // Token validation endpoint
- m_httpServer->Get("/api/auth/validate", [this](const httplib::Request& req, httplib::Response& res) {
- handleValidateToken(req, res);
- });
- // Refresh token endpoint
- m_httpServer->Post("/api/auth/refresh", [this](const httplib::Request& req, httplib::Response& res) {
- handleRefreshToken(req, res);
- });
- // Get current user endpoint
- m_httpServer->Get("/api/auth/me", [this](const httplib::Request& req, httplib::Response& res) {
- handleGetCurrentUser(req, res);
- });
- }
- void Server::handleLogin(const httplib::Request& req, httplib::Response& res) {
- std::string requestId = generateRequestId();
- try {
- if (!m_userManager || !m_authMiddleware) {
- sendErrorResponse(res, "Authentication system not available", 500, "AUTH_UNAVAILABLE", requestId);
- return;
- }
- // Parse request body
- json requestJson;
- try {
- requestJson = json::parse(req.body);
- } catch (const json::parse_error& e) {
- sendErrorResponse(res, std::string("Invalid JSON: ") + e.what(), 400, "JSON_PARSE_ERROR", requestId);
- return;
- }
- // Check if using Unix authentication
- if (m_authMiddleware->getConfig().authMethod == AuthMethod::UNIX) {
- // For Unix auth, get username and password from request body
- std::string username = requestJson.value("username", "");
- std::string password = requestJson.value("password", "");
- if (username.empty()) {
- sendErrorResponse(res, "Missing username", 400, "MISSING_USERNAME", requestId);
- return;
- }
- // Check if PAM is enabled - if so, password is required
- if (m_userManager->isPamAuthEnabled() && password.empty()) {
- sendErrorResponse(res, "Password is required for Unix authentication", 400, "MISSING_PASSWORD", requestId);
- return;
- }
- // Authenticate Unix user (with or without password depending on PAM)
- auto result = m_userManager->authenticateUnix(username, password);
- if (!result.success) {
- sendErrorResponse(res, result.errorMessage, 401, "UNIX_AUTH_FAILED", requestId);
- return;
- }
- // Generate simple token for Unix auth
- std::string token = "unix_token_" + std::to_string(std::chrono::duration_cast<std::chrono::seconds>(
- std::chrono::system_clock::now().time_since_epoch()).count()) + "_" + username;
- json response = {
- {"token", token},
- {"user", {
- {"id", result.userId},
- {"username", result.username},
- {"role", result.role},
- {"permissions", result.permissions}
- }},
- {"message", "Unix authentication successful"}
- };
- sendJsonResponse(res, response);
- return;
- }
- // For non-Unix auth, validate required fields
- if (!requestJson.contains("username") || !requestJson.contains("password")) {
- sendErrorResponse(res, "Missing username or password", 400, "MISSING_CREDENTIALS", requestId);
- return;
- }
- std::string username = requestJson["username"];
- std::string password = requestJson["password"];
- // Authenticate user
- auto result = m_userManager->authenticateUser(username, password);
- if (!result.success) {
- sendErrorResponse(res, result.errorMessage, 401, "INVALID_CREDENTIALS", requestId);
- return;
- }
- // Generate JWT token if using JWT auth
- std::string token;
- if (m_authMiddleware->getConfig().authMethod == AuthMethod::JWT) {
- // For now, create a simple token (in a real implementation, use JWT)
- token = "token_" + std::to_string(std::chrono::duration_cast<std::chrono::seconds>(
- std::chrono::system_clock::now().time_since_epoch()).count()) + "_" + username;
- }
- json response = {
- {"token", token},
- {"user", {
- {"id", result.userId},
- {"username", result.username},
- {"role", result.role},
- {"permissions", result.permissions}
- }},
- {"message", "Login successful"}
- };
- sendJsonResponse(res, response);
- } catch (const std::exception& e) {
- sendErrorResponse(res, std::string("Login failed: ") + e.what(), 500, "LOGIN_ERROR", requestId);
- }
- }
- void Server::handleLogout(const httplib::Request& req, httplib::Response& res) {
- std::string requestId = generateRequestId();
- try {
- // For now, just return success (in a real implementation, invalidate the token)
- json response = {
- {"message", "Logout successful"}
- };
- sendJsonResponse(res, response);
- } catch (const std::exception& e) {
- sendErrorResponse(res, std::string("Logout failed: ") + e.what(), 500, "LOGOUT_ERROR", requestId);
- }
- }
- void Server::handleValidateToken(const httplib::Request& req, httplib::Response& res) {
- std::string requestId = generateRequestId();
- try {
- if (!m_userManager || !m_authMiddleware) {
- sendErrorResponse(res, "Authentication system not available", 500, "AUTH_UNAVAILABLE", requestId);
- return;
- }
- // Extract token from header
- std::string authHeader = req.get_header_value("Authorization");
- if (authHeader.empty()) {
- sendErrorResponse(res, "Missing authorization token", 401, "MISSING_TOKEN", requestId);
- return;
- }
- // Simple token validation (in a real implementation, validate JWT)
- // For now, just check if it starts with "token_"
- if (authHeader.find("Bearer ") != 0) {
- sendErrorResponse(res, "Invalid authorization header format", 401, "INVALID_HEADER", requestId);
- return;
- }
- std::string token = authHeader.substr(7); // Remove "Bearer "
- if (token.find("token_") != 0) {
- sendErrorResponse(res, "Invalid token", 401, "INVALID_TOKEN", requestId);
- return;
- }
- // Extract username from token (simple format: token_timestamp_username)
- size_t last_underscore = token.find_last_of('_');
- if (last_underscore == std::string::npos) {
- sendErrorResponse(res, "Invalid token format", 401, "INVALID_TOKEN", requestId);
- return;
- }
- std::string username = token.substr(last_underscore + 1);
- // Get user info
- auto userInfo = m_userManager->getUserInfoByUsername(username);
- if (userInfo.id.empty()) {
- sendErrorResponse(res, "User not found", 401, "USER_NOT_FOUND", requestId);
- return;
- }
- json response = {
- {"user", {
- {"id", userInfo.id},
- {"username", userInfo.username},
- {"role", userInfo.role},
- {"permissions", userInfo.permissions}
- }},
- {"valid", true}
- };
- sendJsonResponse(res, response);
- } catch (const std::exception& e) {
- sendErrorResponse(res, std::string("Token validation failed: ") + e.what(), 500, "VALIDATION_ERROR", requestId);
- }
- }
- void Server::handleRefreshToken(const httplib::Request& req, httplib::Response& res) {
- std::string requestId = generateRequestId();
- try {
- // For now, just return a new token (in a real implementation, refresh JWT)
- json response = {
- {"token", "new_token_" + std::to_string(std::chrono::duration_cast<std::chrono::seconds>(
- std::chrono::system_clock::now().time_since_epoch()).count())},
- {"message", "Token refreshed successfully"}
- };
- sendJsonResponse(res, response);
- } catch (const std::exception& e) {
- sendErrorResponse(res, std::string("Token refresh failed: ") + e.what(), 500, "REFRESH_ERROR", requestId);
- }
- }
- void Server::handleGetCurrentUser(const httplib::Request& req, httplib::Response& res) {
- std::string requestId = generateRequestId();
- try {
- if (!m_userManager || !m_authMiddleware) {
- sendErrorResponse(res, "Authentication system not available", 500, "AUTH_UNAVAILABLE", requestId);
- return;
- }
- // Authenticate the request
- AuthContext authContext = m_authMiddleware->authenticate(req, res);
- if (!authContext.authenticated) {
- sendErrorResponse(res, "Authentication required", 401, "AUTH_REQUIRED", requestId);
- return;
- }
- json response = {
- {"user", {
- {"id", authContext.userId},
- {"username", authContext.username},
- {"role", authContext.role},
- {"permissions", authContext.permissions}
- }}
- };
- sendJsonResponse(res, response);
- } catch (const std::exception& e) {
- sendErrorResponse(res, std::string("Get current user failed: ") + e.what(), 500, "USER_ERROR", requestId);
- }
- }
- void Server::setupCORS() {
- // Use post-routing handler to set CORS headers after the response is generated
- // This ensures we don't duplicate headers that may be set by other handlers
- m_httpServer->set_post_routing_handler([](const httplib::Request& req, httplib::Response& res) {
- // Only add CORS headers if they haven't been set already
- if (!res.has_header("Access-Control-Allow-Origin")) {
- res.set_header("Access-Control-Allow-Origin", "*");
- }
- if (!res.has_header("Access-Control-Allow-Methods")) {
- res.set_header("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS");
- }
- if (!res.has_header("Access-Control-Allow-Headers")) {
- res.set_header("Access-Control-Allow-Headers", "Content-Type, Authorization");
- }
- });
- // Handle OPTIONS requests for CORS preflight (API endpoints only)
- m_httpServer->Options("/api/.*", [](const httplib::Request&, httplib::Response& res) {
- res.set_header("Access-Control-Allow-Origin", "*");
- res.set_header("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS");
- res.set_header("Access-Control-Allow-Headers", "Content-Type, Authorization");
- res.status = 200;
- });
- }
- void Server::handleHealthCheck(const httplib::Request& req, httplib::Response& res) {
- try {
- json response = {
- {"status", "healthy"},
- {"timestamp", std::chrono::duration_cast<std::chrono::seconds>(
- std::chrono::system_clock::now().time_since_epoch()).count()},
- {"version", "1.0.0"}
- };
- sendJsonResponse(res, response);
- } catch (const std::exception& e) {
- sendErrorResponse(res, std::string("Health check failed: ") + e.what(), 500);
- }
- }
- void Server::handleApiStatus(const httplib::Request& req, httplib::Response& res) {
- try {
- json response = {
- {"server", {
- {"running", m_isRunning.load()},
- {"host", m_host},
- {"port", m_port}
- }},
- {"generation_queue", {
- {"running", m_generationQueue ? m_generationQueue->isRunning() : false},
- {"queue_size", m_generationQueue ? m_generationQueue->getQueueSize() : 0},
- {"active_generations", m_generationQueue ? m_generationQueue->getActiveGenerations() : 0}
- }},
- {"models", {
- {"loaded_count", m_modelManager ? m_modelManager->getLoadedModelsCount() : 0},
- {"available_count", m_modelManager ? m_modelManager->getAvailableModelsCount() : 0}
- }}
- };
- sendJsonResponse(res, response);
- } catch (const std::exception& e) {
- sendErrorResponse(res, std::string("Status check failed: ") + e.what(), 500);
- }
- }
- void Server::handleModelsList(const httplib::Request& req, httplib::Response& res) {
- std::string requestId = generateRequestId();
- try {
- if (!m_modelManager) {
- sendErrorResponse(res, "Model manager not available", 500, "MODEL_MANAGER_UNAVAILABLE", requestId);
- return;
- }
- // Parse query parameters for enhanced filtering
- std::string typeFilter = req.get_param_value("type");
- std::string searchQuery = req.get_param_value("search");
- std::string sortBy = req.get_param_value("sort_by");
- std::string sortOrder = req.get_param_value("sort_order");
- std::string dateFilter = req.get_param_value("date");
- std::string sizeFilter = req.get_param_value("size");
- // Pagination parameters
- int page = 1;
- int limit = 50;
- try {
- if (!req.get_param_value("page").empty()) {
- page = std::stoi(req.get_param_value("page"));
- if (page < 1) page = 1;
- }
- if (!req.get_param_value("limit").empty()) {
- limit = std::stoi(req.get_param_value("limit"));
- if (limit < 1) limit = 1;
- if (limit > 200) limit = 200; // Max limit to prevent performance issues
- }
- } catch (const std::exception& e) {
- sendErrorResponse(res, "Invalid pagination parameters", 400, "INVALID_PAGINATION", requestId);
- return;
- }
- // Filter parameters
- bool includeLoaded = req.get_param_value("loaded") == "true";
- bool includeUnloaded = req.get_param_value("unloaded") == "true";
- bool includeMetadata = req.get_param_value("include_metadata") == "true";
- bool includeThumbnails = req.get_param_value("include_thumbnails") == "true";
- // Get all models
- auto allModels = m_modelManager->getAllModels();
- json models = json::array();
- // Apply filters and build response
- for (const auto& pair : allModels) {
- const auto& modelInfo = pair.second;
- // Apply type filter
- if (!typeFilter.empty()) {
- ModelType filterType = ModelManager::stringToModelType(typeFilter);
- if (modelInfo.type != filterType) continue;
- }
- // Apply loaded/unloaded filters
- if (includeLoaded && !modelInfo.isLoaded) continue;
- if (includeUnloaded && modelInfo.isLoaded) continue;
- // Apply search filter (case-insensitive search in name and description)
- if (!searchQuery.empty()) {
- std::string searchLower = searchQuery;
- std::transform(searchLower.begin(), searchLower.end(), searchLower.begin(), ::tolower);
- std::string nameLower = modelInfo.name;
- std::transform(nameLower.begin(), nameLower.end(), nameLower.begin(), ::tolower);
- std::string descLower = modelInfo.description;
- std::transform(descLower.begin(), descLower.end(), descLower.begin(), ::tolower);
- if (nameLower.find(searchLower) == std::string::npos &&
- descLower.find(searchLower) == std::string::npos) {
- continue;
- }
- }
- // Apply date filter (simplified - expects "recent", "old", or YYYY-MM-DD)
- if (!dateFilter.empty()) {
- auto now = std::filesystem::file_time_type::clock::now();
- auto modelTime = modelInfo.modifiedAt;
- auto duration = std::chrono::duration_cast<std::chrono::hours>(now - modelTime).count();
- if (dateFilter == "recent" && duration > 24 * 7) continue; // Older than 1 week
- if (dateFilter == "old" && duration < 24 * 30) continue; // Newer than 1 month
- }
- // Apply size filter (expects "small", "medium", "large", or size in MB)
- if (!sizeFilter.empty()) {
- double sizeMB = modelInfo.fileSize / (1024.0 * 1024.0);
- if (sizeFilter == "small" && sizeMB > 1024) continue; // > 1GB
- if (sizeFilter == "medium" && (sizeMB < 1024 || sizeMB > 4096)) continue; // < 1GB or > 4GB
- if (sizeFilter == "large" && sizeMB < 4096) continue; // < 4GB
- // Try to parse as specific size in MB
- try {
- double maxSizeMB = std::stod(sizeFilter);
- if (sizeMB > maxSizeMB) continue;
- } catch (...) {
- // Ignore if parsing fails
- }
- }
- // Build model JSON with only essential information
- json modelJson = {
- {"name", modelInfo.name},
- {"type", ModelManager::modelTypeToString(modelInfo.type)},
- {"file_size", modelInfo.fileSize},
- {"file_size_mb", modelInfo.fileSize / (1024.0 * 1024.0)},
- {"sha256", modelInfo.sha256.empty() ? nullptr : json(modelInfo.sha256)},
- {"sha256_short", (modelInfo.sha256.empty() || modelInfo.sha256.length() < 10) ? nullptr : json(modelInfo.sha256.substr(0, 10))}
- };
- // Add architecture information if available (checkpoints only)
- if (!modelInfo.architecture.empty()) {
- modelJson["architecture"] = modelInfo.architecture;
- modelJson["recommended_vae"] = modelInfo.recommendedVAE.empty() ? nullptr : json(modelInfo.recommendedVAE);
- if (modelInfo.recommendedWidth > 0) {
- modelJson["recommended_width"] = modelInfo.recommendedWidth;
- }
- if (modelInfo.recommendedHeight > 0) {
- modelJson["recommended_height"] = modelInfo.recommendedHeight;
- }
- if (modelInfo.recommendedSteps > 0) {
- modelJson["recommended_steps"] = modelInfo.recommendedSteps;
- }
- if (!modelInfo.recommendedSampler.empty()) {
- modelJson["recommended_sampler"] = modelInfo.recommendedSampler;
- }
- if (!modelInfo.requiredModels.empty()) {
- modelJson["required_models"] = modelInfo.requiredModels;
- }
- if (!modelInfo.missingModels.empty()) {
- modelJson["missing_models"] = modelInfo.missingModels;
- modelJson["has_missing_dependencies"] = true;
- } else {
- modelJson["has_missing_dependencies"] = false;
- }
- }
- models.push_back(modelJson);
- }
- // Apply sorting
- if (!sortBy.empty()) {
- std::sort(models.begin(), models.end(), [&sortBy, &sortOrder](const json& a, const json& b) {
- bool ascending = sortOrder != "desc";
- if (sortBy == "name") {
- return ascending ? a["name"] < b["name"] : a["name"] > b["name"];
- } else if (sortBy == "size") {
- return ascending ? a["file_size"] < b["file_size"] : a["file_size"] > b["file_size"];
- } else if (sortBy == "date") {
- return ascending ? a["last_modified"] < b["last_modified"] : a["last_modified"] > b["last_modified"];
- } else if (sortBy == "type") {
- return ascending ? a["type"] < b["type"] : a["type"] > b["type"];
- } else if (sortBy == "loaded") {
- return ascending ? a["is_loaded"] < b["is_loaded"] : a["is_loaded"] > b["is_loaded"];
- }
- return false;
- });
- }
- // Apply pagination
- int totalCount = models.size();
- int totalPages = (totalCount + limit - 1) / limit;
- int startIndex = (page - 1) * limit;
- int endIndex = std::min(startIndex + limit, totalCount);
- json paginatedModels = json::array();
- for (int i = startIndex; i < endIndex; ++i) {
- paginatedModels.push_back(models[i]);
- }
- // Build comprehensive response
- json response = {
- {"models", paginatedModels},
- {"pagination", {
- {"page", page},
- {"limit", limit},
- {"total_count", totalCount},
- {"total_pages", totalPages},
- {"has_next", page < totalPages},
- {"has_prev", page > 1}
- }},
- {"filters_applied", {
- {"type", typeFilter.empty() ? json(nullptr) : json(typeFilter)},
- {"search", searchQuery.empty() ? json(nullptr) : json(searchQuery)},
- {"date", dateFilter.empty() ? json(nullptr) : json(dateFilter)},
- {"size", sizeFilter.empty() ? json(nullptr) : json(sizeFilter)},
- {"loaded", includeLoaded ? json(true) : json(nullptr)},
- {"unloaded", includeUnloaded ? json(true) : json(nullptr)}
- }},
- {"sorting", {
- {"sort_by", sortBy.empty() ? "name" : json(sortBy)},
- {"sort_order", sortOrder.empty() ? "asc" : json(sortOrder)}
- }},
- {"statistics", {
- {"loaded_count", m_modelManager->getLoadedModelsCount()},
- {"available_count", m_modelManager->getAvailableModelsCount()}
- }},
- {"request_id", requestId}
- };
- sendJsonResponse(res, response);
- } catch (const std::exception& e) {
- sendErrorResponse(res, std::string("Failed to list models: ") + e.what(), 500, "MODEL_LIST_ERROR", requestId);
- }
- }
- void Server::handleQueueStatus(const httplib::Request& req, httplib::Response& res) {
- try {
- if (!m_generationQueue) {
- sendErrorResponse(res, "Generation queue not available", 500);
- return;
- }
- // Get detailed queue status
- auto jobs = m_generationQueue->getQueueStatus();
- // Convert jobs to JSON
- json jobsJson = json::array();
- for (const auto& job : jobs) {
- std::string statusStr;
- switch (job.status) {
- case GenerationStatus::QUEUED: statusStr = "queued"; break;
- case GenerationStatus::PROCESSING: statusStr = "processing"; break;
- case GenerationStatus::COMPLETED: statusStr = "completed"; break;
- case GenerationStatus::FAILED: statusStr = "failed"; break;
- }
- // Convert time points to timestamps
- auto queuedTime = std::chrono::duration_cast<std::chrono::milliseconds>(
- job.queuedTime.time_since_epoch()).count();
- auto startTime = std::chrono::duration_cast<std::chrono::milliseconds>(
- job.startTime.time_since_epoch()).count();
- auto endTime = std::chrono::duration_cast<std::chrono::milliseconds>(
- job.endTime.time_since_epoch()).count();
- jobsJson.push_back({
- {"id", job.id},
- {"status", statusStr},
- {"prompt", job.prompt},
- {"queued_time", queuedTime},
- {"start_time", startTime > 0 ? json(startTime) : json(nullptr)},
- {"end_time", endTime > 0 ? json(endTime) : json(nullptr)},
- {"position", job.position},
- {"progress", job.progress}
- });
- }
- json response = {
- {"queue", {
- {"size", m_generationQueue->getQueueSize()},
- {"active_generations", m_generationQueue->getActiveGenerations()},
- {"running", m_generationQueue->isRunning()},
- {"jobs", jobsJson}
- }}
- };
- sendJsonResponse(res, response);
- } catch (const std::exception& e) {
- sendErrorResponse(res, std::string("Queue status check failed: ") + e.what(), 500);
- }
- }
- void Server::handleJobStatus(const httplib::Request& req, httplib::Response& res) {
- try {
- if (!m_generationQueue) {
- sendErrorResponse(res, "Generation queue not available", 500);
- return;
- }
- // Extract job ID from URL path
- std::string jobId = req.matches[1].str();
- if (jobId.empty()) {
- sendErrorResponse(res, "Missing job ID", 400);
- return;
- }
- // Get job information
- auto jobInfo = m_generationQueue->getJobInfo(jobId);
- if (jobInfo.id.empty()) {
- sendErrorResponse(res, "Job not found", 404);
- return;
- }
- // Convert status to string
- std::string statusStr;
- switch (jobInfo.status) {
- case GenerationStatus::QUEUED: statusStr = "queued"; break;
- case GenerationStatus::PROCESSING: statusStr = "processing"; break;
- case GenerationStatus::COMPLETED: statusStr = "completed"; break;
- case GenerationStatus::FAILED: statusStr = "failed"; break;
- }
- // Convert time points to timestamps
- auto queuedTime = std::chrono::duration_cast<std::chrono::milliseconds>(
- jobInfo.queuedTime.time_since_epoch()).count();
- auto startTime = std::chrono::duration_cast<std::chrono::milliseconds>(
- jobInfo.startTime.time_since_epoch()).count();
- auto endTime = std::chrono::duration_cast<std::chrono::milliseconds>(
- jobInfo.endTime.time_since_epoch()).count();
- // Create download URLs for output files
- json outputUrls = json::array();
- for (const auto& filePath : jobInfo.outputFiles) {
- // Extract filename from full path
- std::filesystem::path p(filePath);
- std::string filename = p.filename().string();
- // Create download URL
- std::string url = "/api/queue/job/" + jobInfo.id + "/output/" + filename;
- json fileInfo = {
- {"filename", filename},
- {"url", url},
- {"path", filePath}
- };
- outputUrls.push_back(fileInfo);
- }
- json response = {
- {"job", {
- {"id", jobInfo.id},
- {"status", statusStr},
- {"prompt", jobInfo.prompt},
- {"queued_time", queuedTime},
- {"start_time", startTime > 0 ? json(startTime) : json(nullptr)},
- {"end_time", endTime > 0 ? json(endTime) : json(nullptr)},
- {"position", jobInfo.position},
- {"outputs", outputUrls},
- {"error_message", jobInfo.errorMessage},
- {"progress", jobInfo.progress}
- }}
- };
- sendJsonResponse(res, response);
- } catch (const std::exception& e) {
- sendErrorResponse(res, std::string("Job status check failed: ") + e.what(), 500);
- }
- }
- void Server::handleCancelJob(const httplib::Request& req, httplib::Response& res) {
- try {
- if (!m_generationQueue) {
- sendErrorResponse(res, "Generation queue not available", 500);
- return;
- }
- // Parse JSON request body
- json requestJson = json::parse(req.body);
- // Validate required fields
- if (!requestJson.contains("job_id") || !requestJson["job_id"].is_string()) {
- sendErrorResponse(res, "Missing or invalid 'job_id' field", 400);
- return;
- }
- std::string jobId = requestJson["job_id"];
- // Try to cancel the job
- bool cancelled = m_generationQueue->cancelJob(jobId);
- if (cancelled) {
- json response = {
- {"status", "success"},
- {"message", "Job cancelled successfully"},
- {"job_id", jobId}
- };
- sendJsonResponse(res, response);
- } else {
- json response = {
- {"status", "error"},
- {"message", "Job not found or already processing"},
- {"job_id", jobId}
- };
- sendJsonResponse(res, response, 404);
- }
- } catch (const json::parse_error& e) {
- sendErrorResponse(res, std::string("Invalid JSON: ") + e.what(), 400);
- } catch (const std::exception& e) {
- sendErrorResponse(res, std::string("Job cancellation failed: ") + e.what(), 500);
- }
- }
- void Server::handleClearQueue(const httplib::Request& req, httplib::Response& res) {
- try {
- if (!m_generationQueue) {
- sendErrorResponse(res, "Generation queue not available", 500);
- return;
- }
- // Clear the queue
- m_generationQueue->clearQueue();
- json response = {
- {"status", "success"},
- {"message", "Queue cleared successfully"}
- };
- sendJsonResponse(res, response);
- } catch (const std::exception& e) {
- sendErrorResponse(res, std::string("Queue clear failed: ") + e.what(), 500);
- }
- }
- void Server::handleDownloadOutput(const httplib::Request& req, httplib::Response& res) {
- try {
- // Extract job ID and filename from URL path
- if (req.matches.size() < 3) {
- sendErrorResponse(res, "Invalid request: job ID and filename required", 400, "INVALID_REQUEST", "");
- return;
- }
- std::string jobId = req.matches[1];
- std::string filename = req.matches[2];
- // Validate inputs
- if (jobId.empty() || filename.empty()) {
- sendErrorResponse(res, "Job ID and filename cannot be empty", 400, "INVALID_PARAMETERS", "");
- return;
- }
- // Construct absolute file path using the same logic as when saving:
- // {outputDir}/{jobId}/{filename}
- std::string fullPath = std::filesystem::absolute(m_outputDir + "/" + jobId + "/" + filename).string();
- // Log the request for debugging
- std::cout << "Image download request: jobId=" << jobId << ", filename=" << filename
- << ", fullPath=" << fullPath << std::endl;
- // Check if file exists
- if (!std::filesystem::exists(fullPath)) {
- std::cerr << "Output file not found: " << fullPath << std::endl;
- sendErrorResponse(res, "Output file not found: " + filename, 404, "FILE_NOT_FOUND", "");
- return;
- }
- // Check file size to detect zero-byte files
- auto fileSize = std::filesystem::file_size(fullPath);
- if (fileSize == 0) {
- std::cerr << "Output file is zero bytes: " << fullPath << std::endl;
- sendErrorResponse(res, "Output file is empty (corrupted generation)", 500, "EMPTY_FILE", "");
- return;
- }
- // Check if file is accessible
- std::ifstream file(fullPath, std::ios::binary);
- if (!file.is_open()) {
- std::cerr << "Failed to open output file: " << fullPath << std::endl;
- sendErrorResponse(res, "Output file not accessible", 500, "FILE_ACCESS_ERROR", "");
- return;
- }
- // Read file contents
- std::string fileContent;
- try {
- fileContent = std::string(
- std::istreambuf_iterator<char>(file),
- std::istreambuf_iterator<char>()
- );
- file.close();
- } catch (const std::exception& e) {
- std::cerr << "Failed to read file content: " << e.what() << std::endl;
- sendErrorResponse(res, "Failed to read file content", 500, "FILE_READ_ERROR", "");
- return;
- }
- // Verify we actually read data
- if (fileContent.empty()) {
- std::cerr << "File content is empty after read: " << fullPath << std::endl;
- sendErrorResponse(res, "File content is empty after read", 500, "EMPTY_CONTENT", "");
- return;
- }
- // Determine content type based on file extension
- std::string contentType = "application/octet-stream";
- if (Utils::endsWith(filename, ".png")) {
- contentType = "image/png";
- } else if (Utils::endsWith(filename, ".jpg") || Utils::endsWith(filename, ".jpeg")) {
- contentType = "image/jpeg";
- } else if (Utils::endsWith(filename, ".mp4")) {
- contentType = "video/mp4";
- } else if (Utils::endsWith(filename, ".gif")) {
- contentType = "image/gif";
- } else if (Utils::endsWith(filename, ".webp")) {
- contentType = "image/webp";
- }
- // Set response headers for proper browser handling
- res.set_header("Content-Type", contentType);
- res.set_header("Content-Length", std::to_string(fileContent.length()));
- res.set_header("Cache-Control", "public, max-age=3600"); // Cache for 1 hour
- res.set_header("Access-Control-Allow-Origin", "*"); // CORS for image access
- // Uncomment if you want to force download instead of inline display:
- // res.set_header("Content-Disposition", "attachment; filename=\"" + filename + "\"");
- // Set the content
- res.set_content(fileContent, contentType);
- res.status = 200;
- std::cout << "Successfully served image: " << filename << " (" << fileContent.length() << " bytes)" << std::endl;
- } catch (const std::exception& e) {
- std::cerr << "Exception in handleDownloadOutput: " << e.what() << std::endl;
- sendErrorResponse(res, std::string("Failed to download file: ") + e.what(), 500, "DOWNLOAD_ERROR", "");
- }
- }
- void Server::sendJsonResponse(httplib::Response& res, const nlohmann::json& json, int status_code) {
- res.set_header("Content-Type", "application/json");
- res.status = status_code;
- res.body = json.dump();
- }
- void Server::sendErrorResponse(httplib::Response& res, const std::string& message, int status_code,
- const std::string& error_code, const std::string& request_id) {
- json errorResponse = {
- {"error", {
- {"message", message},
- {"status_code", status_code},
- {"error_code", error_code},
- {"request_id", request_id},
- {"timestamp", std::chrono::duration_cast<std::chrono::seconds>(
- std::chrono::system_clock::now().time_since_epoch()).count()}
- }}
- };
- sendJsonResponse(res, errorResponse, status_code);
- }
- std::pair<bool, std::string> Server::validateGenerationParameters(const nlohmann::json& params) {
- // Validate required fields
- if (!params.contains("prompt") || !params["prompt"].is_string()) {
- return {false, "Missing or invalid 'prompt' field"};
- }
- const std::string& prompt = params["prompt"];
- if (prompt.empty()) {
- return {false, "Prompt cannot be empty"};
- }
- if (prompt.length() > 10000) {
- return {false, "Prompt too long (max 10000 characters)"};
- }
- // Validate negative prompt if present
- if (params.contains("negative_prompt")) {
- if (!params["negative_prompt"].is_string()) {
- return {false, "Invalid 'negative_prompt' field, must be string"};
- }
- if (params["negative_prompt"].get<std::string>().length() > 10000) {
- return {false, "Negative prompt too long (max 10000 characters)"};
- }
- }
- // Validate width
- if (params.contains("width")) {
- if (!params["width"].is_number_integer()) {
- return {false, "Invalid 'width' field, must be integer"};
- }
- int width = params["width"];
- if (width < 64 || width > 2048 || width % 64 != 0) {
- return {false, "Width must be between 64 and 2048 and divisible by 64"};
- }
- }
- // Validate height
- if (params.contains("height")) {
- if (!params["height"].is_number_integer()) {
- return {false, "Invalid 'height' field, must be integer"};
- }
- int height = params["height"];
- if (height < 64 || height > 2048 || height % 64 != 0) {
- return {false, "Height must be between 64 and 2048 and divisible by 64"};
- }
- }
- // Validate batch count
- if (params.contains("batch_count")) {
- if (!params["batch_count"].is_number_integer()) {
- return {false, "Invalid 'batch_count' field, must be integer"};
- }
- int batchCount = params["batch_count"];
- if (batchCount < 1 || batchCount > 100) {
- return {false, "Batch count must be between 1 and 100"};
- }
- }
- // Validate steps
- if (params.contains("steps")) {
- if (!params["steps"].is_number_integer()) {
- return {false, "Invalid 'steps' field, must be integer"};
- }
- int steps = params["steps"];
- if (steps < 1 || steps > 150) {
- return {false, "Steps must be between 1 and 150"};
- }
- }
- // Validate CFG scale
- if (params.contains("cfg_scale")) {
- if (!params["cfg_scale"].is_number()) {
- return {false, "Invalid 'cfg_scale' field, must be number"};
- }
- float cfgScale = params["cfg_scale"];
- if (cfgScale < 1.0f || cfgScale > 30.0f) {
- return {false, "CFG scale must be between 1.0 and 30.0"};
- }
- }
- // Validate seed
- if (params.contains("seed")) {
- if (!params["seed"].is_string() && !params["seed"].is_number_integer()) {
- return {false, "Invalid 'seed' field, must be string or integer"};
- }
- }
- // Validate sampling method
- if (params.contains("sampling_method")) {
- if (!params["sampling_method"].is_string()) {
- return {false, "Invalid 'sampling_method' field, must be string"};
- }
- std::string method = params["sampling_method"];
- std::vector<std::string> validMethods = {
- "euler", "euler_a", "heun", "dpm2", "dpm++2s_a", "dpm++2m",
- "dpm++2mv2", "ipndm", "ipndm_v", "lcm", "ddim_trailing", "tcd", "default"
- };
- if (std::find(validMethods.begin(), validMethods.end(), method) == validMethods.end()) {
- return {false, "Invalid sampling method"};
- }
- }
- // Validate scheduler
- if (params.contains("scheduler")) {
- if (!params["scheduler"].is_string()) {
- return {false, "Invalid 'scheduler' field, must be string"};
- }
- std::string scheduler = params["scheduler"];
- std::vector<std::string> validSchedulers = {
- "discrete", "karras", "exponential", "ays", "gits",
- "smoothstep", "sgm_uniform", "simple", "default"
- };
- if (std::find(validSchedulers.begin(), validSchedulers.end(), scheduler) == validSchedulers.end()) {
- return {false, "Invalid scheduler"};
- }
- }
- // Validate strength
- if (params.contains("strength")) {
- if (!params["strength"].is_number()) {
- return {false, "Invalid 'strength' field, must be number"};
- }
- float strength = params["strength"];
- if (strength < 0.0f || strength > 1.0f) {
- return {false, "Strength must be between 0.0 and 1.0"};
- }
- }
- // Validate control strength
- if (params.contains("control_strength")) {
- if (!params["control_strength"].is_number()) {
- return {false, "Invalid 'control_strength' field, must be number"};
- }
- float controlStrength = params["control_strength"];
- if (controlStrength < 0.0f || controlStrength > 1.0f) {
- return {false, "Control strength must be between 0.0 and 1.0"};
- }
- }
- // Validate clip skip
- if (params.contains("clip_skip")) {
- if (!params["clip_skip"].is_number_integer()) {
- return {false, "Invalid 'clip_skip' field, must be integer"};
- }
- int clipSkip = params["clip_skip"];
- if (clipSkip < -1 || clipSkip > 12) {
- return {false, "Clip skip must be between -1 and 12"};
- }
- }
- // Validate threads
- if (params.contains("threads")) {
- if (!params["threads"].is_number_integer()) {
- return {false, "Invalid 'threads' field, must be integer"};
- }
- int threads = params["threads"];
- if (threads < -1 || threads > 32) {
- return {false, "Threads must be between -1 (auto) and 32"};
- }
- }
- return {true, ""};
- }
- SamplingMethod Server::parseSamplingMethod(const std::string& method) {
- if (method == "euler") return SamplingMethod::EULER;
- else if (method == "euler_a") return SamplingMethod::EULER_A;
- else if (method == "heun") return SamplingMethod::HEUN;
- else if (method == "dpm2") return SamplingMethod::DPM2;
- else if (method == "dpm++2s_a") return SamplingMethod::DPMPP2S_A;
- else if (method == "dpm++2m") return SamplingMethod::DPMPP2M;
- else if (method == "dpm++2mv2") return SamplingMethod::DPMPP2MV2;
- else if (method == "ipndm") return SamplingMethod::IPNDM;
- else if (method == "ipndm_v") return SamplingMethod::IPNDM_V;
- else if (method == "lcm") return SamplingMethod::LCM;
- else if (method == "ddim_trailing") return SamplingMethod::DDIM_TRAILING;
- else if (method == "tcd") return SamplingMethod::TCD;
- else return SamplingMethod::DEFAULT;
- }
- Scheduler Server::parseScheduler(const std::string& scheduler) {
- if (scheduler == "discrete") return Scheduler::DISCRETE;
- else if (scheduler == "karras") return Scheduler::KARRAS;
- else if (scheduler == "exponential") return Scheduler::EXPONENTIAL;
- else if (scheduler == "ays") return Scheduler::AYS;
- else if (scheduler == "gits") return Scheduler::GITS;
- else if (scheduler == "smoothstep") return Scheduler::SMOOTHSTEP;
- else if (scheduler == "sgm_uniform") return Scheduler::SGM_UNIFORM;
- else if (scheduler == "simple") return Scheduler::SIMPLE;
- else return Scheduler::DEFAULT;
- }
- std::string Server::generateRequestId() {
- std::random_device rd;
- std::mt19937 gen(rd());
- std::uniform_int_distribution<> dis(100000, 999999);
- return "req_" + std::to_string(dis(gen));
- }
- std::tuple<std::vector<uint8_t>, int, int, int, bool, std::string>
- Server::loadImageFromInput(const std::string& input) {
- std::vector<uint8_t> imageData;
- int width = 0, height = 0, channels = 0;
- // Auto-detect input source type
- // 1. Check if input is a URL (starts with http:// or https://)
- if (Utils::startsWith(input, "http://") || Utils::startsWith(input, "https://")) {
- // Parse URL to extract host and path
- std::string url = input;
- std::string scheme, host, path;
- int port = 80;
- // Determine scheme and port
- if (Utils::startsWith(url, "https://")) {
- scheme = "https";
- port = 443;
- url = url.substr(8); // Remove "https://"
- } else {
- scheme = "http";
- port = 80;
- url = url.substr(7); // Remove "http://"
- }
- // Extract host and path
- size_t slashPos = url.find('/');
- if (slashPos != std::string::npos) {
- host = url.substr(0, slashPos);
- path = url.substr(slashPos);
- } else {
- host = url;
- path = "/";
- }
- // Check for custom port
- size_t colonPos = host.find(':');
- if (colonPos != std::string::npos) {
- try {
- port = std::stoi(host.substr(colonPos + 1));
- host = host.substr(0, colonPos);
- } catch (...) {
- return {imageData, 0, 0, 0, false, "Invalid port in URL"};
- }
- }
- // Download image using httplib
- try {
- httplib::Result res;
- if (scheme == "https") {
- #ifdef CPPHTTPLIB_OPENSSL_SUPPORT
- httplib::SSLClient client(host, port);
- client.set_follow_location(true);
- client.set_connection_timeout(30, 0); // 30 seconds
- client.set_read_timeout(60, 0); // 60 seconds
- res = client.Get(path.c_str());
- #else
- return {imageData, 0, 0, 0, false, "HTTPS not supported (OpenSSL not available)"};
- #endif
- } else {
- httplib::Client client(host, port);
- client.set_follow_location(true);
- client.set_connection_timeout(30, 0); // 30 seconds
- client.set_read_timeout(60, 0); // 60 seconds
- res = client.Get(path.c_str());
- }
- if (!res) {
- return {imageData, 0, 0, 0, false, "Failed to download image from URL: Connection error"};
- }
- if (res->status != 200) {
- return {imageData, 0, 0, 0, false, "Failed to download image from URL: HTTP " + std::to_string(res->status)};
- }
- // Convert response body to vector
- std::vector<uint8_t> downloadedData(res->body.begin(), res->body.end());
- // Load image from memory
- int w, h, c;
- unsigned char* pixels = stbi_load_from_memory(
- downloadedData.data(),
- downloadedData.size(),
- &w, &h, &c,
- 3 // Force RGB
- );
- if (!pixels) {
- return {imageData, 0, 0, 0, false, "Failed to decode image from URL"};
- }
- width = w;
- height = h;
- channels = 3;
- size_t dataSize = width * height * channels;
- imageData.resize(dataSize);
- std::memcpy(imageData.data(), pixels, dataSize);
- stbi_image_free(pixels);
- } catch (const std::exception& e) {
- return {imageData, 0, 0, 0, false, "Failed to download image from URL: " + std::string(e.what())};
- }
- }
- // 2. Check if input is base64 encoded data URI (starts with "data:image")
- else if (Utils::startsWith(input, "data:image")) {
- // Extract base64 data after the comma
- size_t commaPos = input.find(',');
- if (commaPos == std::string::npos) {
- return {imageData, 0, 0, 0, false, "Invalid data URI format"};
- }
- std::string base64Data = input.substr(commaPos + 1);
- std::vector<uint8_t> decodedData = Utils::base64Decode(base64Data);
- // Load image from memory using stb_image
- int w, h, c;
- unsigned char* pixels = stbi_load_from_memory(
- decodedData.data(),
- decodedData.size(),
- &w, &h, &c,
- 3 // Force RGB
- );
- if (!pixels) {
- return {imageData, 0, 0, 0, false, "Failed to decode image from base64 data URI"};
- }
- width = w;
- height = h;
- channels = 3; // We forced RGB
- // Copy pixel data
- size_t dataSize = width * height * channels;
- imageData.resize(dataSize);
- std::memcpy(imageData.data(), pixels, dataSize);
- stbi_image_free(pixels);
- }
- // 3. Check if input is raw base64 (long string without slashes, likely base64)
- else if (input.length() > 100 && input.find('/') == std::string::npos && input.find('.') == std::string::npos) {
- // Likely raw base64 without data URI prefix
- std::vector<uint8_t> decodedData = Utils::base64Decode(input);
- int w, h, c;
- unsigned char* pixels = stbi_load_from_memory(
- decodedData.data(),
- decodedData.size(),
- &w, &h, &c,
- 3 // Force RGB
- );
- if (!pixels) {
- return {imageData, 0, 0, 0, false, "Failed to decode image from base64"};
- }
- width = w;
- height = h;
- channels = 3;
- size_t dataSize = width * height * channels;
- imageData.resize(dataSize);
- std::memcpy(imageData.data(), pixels, dataSize);
- stbi_image_free(pixels);
- }
- // 4. Treat as local file path
- else {
- int w, h, c;
- unsigned char* pixels = stbi_load(input.c_str(), &w, &h, &c, 3);
- if (!pixels) {
- return {imageData, 0, 0, 0, false, "Failed to load image from file: " + input};
- }
- width = w;
- height = h;
- channels = 3;
- size_t dataSize = width * height * channels;
- imageData.resize(dataSize);
- std::memcpy(imageData.data(), pixels, dataSize);
- stbi_image_free(pixels);
- }
- return {imageData, width, height, channels, true, ""};
- }
- std::string Server::samplingMethodToString(SamplingMethod method) {
- switch (method) {
- case SamplingMethod::EULER: return "euler";
- case SamplingMethod::EULER_A: return "euler_a";
- case SamplingMethod::HEUN: return "heun";
- case SamplingMethod::DPM2: return "dpm2";
- case SamplingMethod::DPMPP2S_A: return "dpm++2s_a";
- case SamplingMethod::DPMPP2M: return "dpm++2m";
- case SamplingMethod::DPMPP2MV2: return "dpm++2mv2";
- case SamplingMethod::IPNDM: return "ipndm";
- case SamplingMethod::IPNDM_V: return "ipndm_v";
- case SamplingMethod::LCM: return "lcm";
- case SamplingMethod::DDIM_TRAILING: return "ddim_trailing";
- case SamplingMethod::TCD: return "tcd";
- default: return "default";
- }
- }
- std::string Server::schedulerToString(Scheduler scheduler) {
- switch (scheduler) {
- case Scheduler::DISCRETE: return "discrete";
- case Scheduler::KARRAS: return "karras";
- case Scheduler::EXPONENTIAL: return "exponential";
- case Scheduler::AYS: return "ays";
- case Scheduler::GITS: return "gits";
- case Scheduler::SMOOTHSTEP: return "smoothstep";
- case Scheduler::SGM_UNIFORM: return "sgm_uniform";
- case Scheduler::SIMPLE: return "simple";
- default: return "default";
- }
- }
- uint64_t Server::estimateGenerationTime(const GenerationRequest& request) {
- // Basic estimation based on parameters
- uint64_t baseTime = 1000; // 1 second base time
- // Factor in steps
- baseTime *= request.steps;
- // Factor in resolution
- double resolutionFactor = (request.width * request.height) / (512.0 * 512.0);
- baseTime = static_cast<uint64_t>(baseTime * resolutionFactor);
- // Factor in batch count
- baseTime *= request.batchCount;
- // Adjust for sampling method (some are faster than others)
- switch (request.samplingMethod) {
- case SamplingMethod::LCM:
- baseTime /= 4; // LCM is much faster
- break;
- case SamplingMethod::EULER:
- case SamplingMethod::EULER_A:
- baseTime *= 0.8; // Euler methods are faster
- break;
- case SamplingMethod::DPM2:
- case SamplingMethod::DPMPP2S_A:
- baseTime *= 1.2; // DPM methods are slower
- break;
- default:
- break;
- }
- return baseTime;
- }
- size_t Server::estimateMemoryUsage(const GenerationRequest& request) {
- // Basic memory estimation in bytes
- size_t baseMemory = 1024 * 1024 * 1024; // 1GB base
- // Factor in resolution
- double resolutionFactor = (request.width * request.height) / (512.0 * 512.0);
- baseMemory = static_cast<size_t>(baseMemory * resolutionFactor);
- // Factor in batch count
- baseMemory *= request.batchCount;
- // Additional memory for certain features
- if (request.diffusionFlashAttn) {
- baseMemory += 512 * 1024 * 1024; // Extra 512MB for flash attention
- }
- if (!request.controlNetPath.empty()) {
- baseMemory += 1024 * 1024 * 1024; // Extra 1GB for ControlNet
- }
- return baseMemory;
- }
- // Specialized generation endpoints
- void Server::handleText2Img(const httplib::Request& req, httplib::Response& res) {
- std::string requestId = generateRequestId();
- try {
- if (!m_generationQueue) {
- sendErrorResponse(res, "Generation queue not available", 500, "QUEUE_UNAVAILABLE", requestId);
- return;
- }
- json requestJson = json::parse(req.body);
- // Validate required fields for text2img
- if (!requestJson.contains("prompt") || !requestJson["prompt"].is_string()) {
- sendErrorResponse(res, "Missing or invalid 'prompt' field", 400, "INVALID_PARAMETERS", requestId);
- return;
- }
- // Validate all parameters
- auto [isValid, errorMessage] = validateGenerationParameters(requestJson);
- if (!isValid) {
- sendErrorResponse(res, errorMessage, 400, "INVALID_PARAMETERS", requestId);
- return;
- }
- // Check if any model is loaded
- if (!m_modelManager) {
- sendErrorResponse(res, "Model manager not available", 500, "MODEL_MANAGER_UNAVAILABLE", requestId);
- return;
- }
- // Get currently loaded checkpoint model
- auto allModels = m_modelManager->getAllModels();
- std::string loadedModelName;
- for (const auto& [modelName, modelInfo] : allModels) {
- if (modelInfo.type == ModelType::CHECKPOINT && modelInfo.isLoaded) {
- loadedModelName = modelName;
- break;
- }
- }
- if (loadedModelName.empty()) {
- sendErrorResponse(res, "No checkpoint model loaded. Please load a checkpoint model first using POST /api/models/{hash}/load", 400, "NO_CHECKPOINT_LOADED", requestId);
- return;
- }
- // Create generation request specifically for text2img
- GenerationRequest genRequest;
- genRequest.id = requestId;
- genRequest.modelName = loadedModelName; // Use the currently loaded model
- genRequest.prompt = requestJson["prompt"];
- genRequest.negativePrompt = requestJson.value("negative_prompt", "");
- genRequest.width = requestJson.value("width", 512);
- genRequest.height = requestJson.value("height", 512);
- genRequest.batchCount = requestJson.value("batch_count", 1);
- genRequest.steps = requestJson.value("steps", 20);
- genRequest.cfgScale = requestJson.value("cfg_scale", 7.5f);
- genRequest.seed = requestJson.value("seed", "random");
- // Parse optional parameters
- if (requestJson.contains("sampling_method")) {
- genRequest.samplingMethod = parseSamplingMethod(requestJson["sampling_method"]);
- }
- if (requestJson.contains("scheduler")) {
- genRequest.scheduler = parseScheduler(requestJson["scheduler"]);
- }
- // Set text2img specific defaults
- genRequest.strength = 1.0f; // Full strength for text2img
- // Optional VAE model
- if (requestJson.contains("vae_model") && requestJson["vae_model"].is_string()) {
- std::string vaeModelId = requestJson["vae_model"];
- if (!vaeModelId.empty()) {
- auto vaeInfo = m_modelManager->getModelInfo(vaeModelId);
- if (!vaeInfo.name.empty() && vaeInfo.type == ModelType::VAE) {
- genRequest.vaePath = vaeInfo.path;
- } else {
- sendErrorResponse(res, "VAE model not found or invalid: " + vaeModelId, 400, "INVALID_VAE_MODEL", requestId);
- return;
- }
- }
- }
- // Optional TAESD model
- if (requestJson.contains("taesd_model") && requestJson["taesd_model"].is_string()) {
- std::string taesdModelId = requestJson["taesd_model"];
- if (!taesdModelId.empty()) {
- auto taesdInfo = m_modelManager->getModelInfo(taesdModelId);
- if (!taesdInfo.name.empty() && taesdInfo.type == ModelType::TAESD) {
- genRequest.taesdPath = taesdInfo.path;
- } else {
- sendErrorResponse(res, "TAESD model not found or invalid: " + taesdModelId, 400, "INVALID_TAESD_MODEL", requestId);
- return;
- }
- }
- }
- // Enqueue request
- auto future = m_generationQueue->enqueueRequest(genRequest);
- json params = {
- {"prompt", genRequest.prompt},
- {"negative_prompt", genRequest.negativePrompt},
- {"model", genRequest.modelName},
- {"width", genRequest.width},
- {"height", genRequest.height},
- {"batch_count", genRequest.batchCount},
- {"steps", genRequest.steps},
- {"cfg_scale", genRequest.cfgScale},
- {"seed", genRequest.seed},
- {"sampling_method", samplingMethodToString(genRequest.samplingMethod)},
- {"scheduler", schedulerToString(genRequest.scheduler)}
- };
- // Add VAE/TAESD if specified
- if (!genRequest.vaePath.empty()) {
- params["vae_model"] = requestJson.value("vae_model", "");
- }
- if (!genRequest.taesdPath.empty()) {
- params["taesd_model"] = requestJson.value("taesd_model", "");
- }
- json response = {
- {"request_id", requestId},
- {"status", "queued"},
- {"message", "Text-to-image generation request queued successfully"},
- {"queue_position", m_generationQueue->getQueueSize()},
- {"estimated_time_seconds", estimateGenerationTime(genRequest) / 1000},
- {"estimated_memory_mb", estimateMemoryUsage(genRequest) / (1024 * 1024)},
- {"type", "text2img"},
- {"parameters", params}
- };
- sendJsonResponse(res, response, 202);
- } catch (const json::parse_error& e) {
- sendErrorResponse(res, std::string("Invalid JSON: ") + e.what(), 400, "JSON_PARSE_ERROR", requestId);
- } catch (const std::exception& e) {
- sendErrorResponse(res, std::string("Text-to-image request failed: ") + e.what(), 500, "INTERNAL_ERROR", requestId);
- }
- }
- void Server::handleImg2Img(const httplib::Request& req, httplib::Response& res) {
- std::string requestId = generateRequestId();
- try {
- if (!m_generationQueue) {
- sendErrorResponse(res, "Generation queue not available", 500, "QUEUE_UNAVAILABLE", requestId);
- return;
- }
- json requestJson = json::parse(req.body);
- // Validate required fields for img2img
- if (!requestJson.contains("prompt") || !requestJson["prompt"].is_string()) {
- sendErrorResponse(res, "Missing or invalid 'prompt' field", 400, "INVALID_PARAMETERS", requestId);
- return;
- }
- if (!requestJson.contains("init_image") || !requestJson["init_image"].is_string()) {
- sendErrorResponse(res, "Missing or invalid 'init_image' field", 400, "INVALID_PARAMETERS", requestId);
- return;
- }
- // Validate all parameters
- auto [isValid, errorMessage] = validateGenerationParameters(requestJson);
- if (!isValid) {
- sendErrorResponse(res, errorMessage, 400, "INVALID_PARAMETERS", requestId);
- return;
- }
- // Check if any model is loaded
- if (!m_modelManager) {
- sendErrorResponse(res, "Model manager not available", 500, "MODEL_MANAGER_UNAVAILABLE", requestId);
- return;
- }
- // Get currently loaded checkpoint model
- auto allModels = m_modelManager->getAllModels();
- std::string loadedModelName;
- for (const auto& [modelName, modelInfo] : allModels) {
- if (modelInfo.type == ModelType::CHECKPOINT && modelInfo.isLoaded) {
- loadedModelName = modelName;
- break;
- }
- }
- if (loadedModelName.empty()) {
- sendErrorResponse(res, "No checkpoint model loaded. Please load a checkpoint model first using POST /api/models/{hash}/load", 400, "NO_CHECKPOINT_LOADED", requestId);
- return;
- }
- // Load the init image
- std::string initImageInput = requestJson["init_image"];
- auto [imageData, imgWidth, imgHeight, imgChannels, success, loadError] = loadImageFromInput(initImageInput);
- if (!success) {
- sendErrorResponse(res, "Failed to load init image: " + loadError, 400, "IMAGE_LOAD_ERROR", requestId);
- return;
- }
- // Create generation request specifically for img2img
- GenerationRequest genRequest;
- genRequest.id = requestId;
- genRequest.requestType = GenerationRequest::RequestType::IMG2IMG;
- genRequest.modelName = loadedModelName; // Use the currently loaded model
- genRequest.prompt = requestJson["prompt"];
- genRequest.negativePrompt = requestJson.value("negative_prompt", "");
- genRequest.width = requestJson.value("width", imgWidth); // Default to input image dimensions
- genRequest.height = requestJson.value("height", imgHeight);
- genRequest.batchCount = requestJson.value("batch_count", 1);
- genRequest.steps = requestJson.value("steps", 20);
- genRequest.cfgScale = requestJson.value("cfg_scale", 7.5f);
- genRequest.seed = requestJson.value("seed", "random");
- genRequest.strength = requestJson.value("strength", 0.75f);
- // Set init image data
- genRequest.initImageData = imageData;
- genRequest.initImageWidth = imgWidth;
- genRequest.initImageHeight = imgHeight;
- genRequest.initImageChannels = imgChannels;
- // Parse optional parameters
- if (requestJson.contains("sampling_method")) {
- genRequest.samplingMethod = parseSamplingMethod(requestJson["sampling_method"]);
- }
- if (requestJson.contains("scheduler")) {
- genRequest.scheduler = parseScheduler(requestJson["scheduler"]);
- }
- // Optional VAE model
- if (requestJson.contains("vae_model") && requestJson["vae_model"].is_string()) {
- std::string vaeModelId = requestJson["vae_model"];
- if (!vaeModelId.empty()) {
- auto vaeInfo = m_modelManager->getModelInfo(vaeModelId);
- if (!vaeInfo.name.empty() && vaeInfo.type == ModelType::VAE) {
- genRequest.vaePath = vaeInfo.path;
- } else {
- sendErrorResponse(res, "VAE model not found or invalid: " + vaeModelId, 400, "INVALID_VAE_MODEL", requestId);
- return;
- }
- }
- }
- // Optional TAESD model
- if (requestJson.contains("taesd_model") && requestJson["taesd_model"].is_string()) {
- std::string taesdModelId = requestJson["taesd_model"];
- if (!taesdModelId.empty()) {
- auto taesdInfo = m_modelManager->getModelInfo(taesdModelId);
- if (!taesdInfo.name.empty() && taesdInfo.type == ModelType::TAESD) {
- genRequest.taesdPath = taesdInfo.path;
- } else {
- sendErrorResponse(res, "TAESD model not found or invalid: " + taesdModelId, 400, "INVALID_TAESD_MODEL", requestId);
- return;
- }
- }
- }
- // Enqueue request
- auto future = m_generationQueue->enqueueRequest(genRequest);
- json params = {
- {"prompt", genRequest.prompt},
- {"negative_prompt", genRequest.negativePrompt},
- {"init_image", requestJson["init_image"]},
- {"model", genRequest.modelName},
- {"width", genRequest.width},
- {"height", genRequest.height},
- {"batch_count", genRequest.batchCount},
- {"steps", genRequest.steps},
- {"cfg_scale", genRequest.cfgScale},
- {"seed", genRequest.seed},
- {"strength", genRequest.strength},
- {"sampling_method", samplingMethodToString(genRequest.samplingMethod)},
- {"scheduler", schedulerToString(genRequest.scheduler)}
- };
- // Add VAE/TAESD if specified
- if (!genRequest.vaePath.empty()) {
- params["vae_model"] = requestJson.value("vae_model", "");
- }
- if (!genRequest.taesdPath.empty()) {
- params["taesd_model"] = requestJson.value("taesd_model", "");
- }
- json response = {
- {"request_id", requestId},
- {"status", "queued"},
- {"message", "Image-to-image generation request queued successfully"},
- {"queue_position", m_generationQueue->getQueueSize()},
- {"estimated_time_seconds", estimateGenerationTime(genRequest) / 1000},
- {"estimated_memory_mb", estimateMemoryUsage(genRequest) / (1024 * 1024)},
- {"type", "img2img"},
- {"parameters", params}
- };
- sendJsonResponse(res, response, 202);
- } catch (const json::parse_error& e) {
- sendErrorResponse(res, std::string("Invalid JSON: ") + e.what(), 400, "JSON_PARSE_ERROR", requestId);
- } catch (const std::exception& e) {
- sendErrorResponse(res, std::string("Image-to-image request failed: ") + e.what(), 500, "INTERNAL_ERROR", requestId);
- }
- }
- void Server::handleControlNet(const httplib::Request& req, httplib::Response& res) {
- std::string requestId = generateRequestId();
- try {
- if (!m_generationQueue) {
- sendErrorResponse(res, "Generation queue not available", 500, "QUEUE_UNAVAILABLE", requestId);
- return;
- }
- json requestJson = json::parse(req.body);
- // Validate required fields for ControlNet
- if (!requestJson.contains("prompt") || !requestJson["prompt"].is_string()) {
- sendErrorResponse(res, "Missing or invalid 'prompt' field", 400, "INVALID_PARAMETERS", requestId);
- return;
- }
- if (!requestJson.contains("control_image") || !requestJson["control_image"].is_string()) {
- sendErrorResponse(res, "Missing or invalid 'control_image' field", 400, "INVALID_PARAMETERS", requestId);
- return;
- }
- // Validate all parameters
- auto [isValid, errorMessage] = validateGenerationParameters(requestJson);
- if (!isValid) {
- sendErrorResponse(res, errorMessage, 400, "INVALID_PARAMETERS", requestId);
- return;
- }
- // Check if any model is loaded
- if (!m_modelManager) {
- sendErrorResponse(res, "Model manager not available", 500, "MODEL_MANAGER_UNAVAILABLE", requestId);
- return;
- }
- // Get currently loaded checkpoint model
- auto allModels = m_modelManager->getAllModels();
- std::string loadedModelName;
- for (const auto& [modelName, modelInfo] : allModels) {
- if (modelInfo.type == ModelType::CHECKPOINT && modelInfo.isLoaded) {
- loadedModelName = modelName;
- break;
- }
- }
- if (loadedModelName.empty()) {
- sendErrorResponse(res, "No checkpoint model loaded. Please load a checkpoint model first using POST /api/models/{hash}/load", 400, "NO_CHECKPOINT_LOADED", requestId);
- return;
- }
- // Create generation request specifically for ControlNet
- GenerationRequest genRequest;
- genRequest.id = requestId;
- genRequest.modelName = loadedModelName; // Use the currently loaded model
- genRequest.prompt = requestJson["prompt"];
- genRequest.negativePrompt = requestJson.value("negative_prompt", "");
- genRequest.width = requestJson.value("width", 512);
- genRequest.height = requestJson.value("height", 512);
- genRequest.batchCount = requestJson.value("batch_count", 1);
- genRequest.steps = requestJson.value("steps", 20);
- genRequest.cfgScale = requestJson.value("cfg_scale", 7.5f);
- genRequest.seed = requestJson.value("seed", "random");
- genRequest.controlStrength = requestJson.value("control_strength", 0.9f);
- genRequest.controlNetPath = requestJson.value("control_net_model", "");
- // Parse optional parameters
- if (requestJson.contains("sampling_method")) {
- genRequest.samplingMethod = parseSamplingMethod(requestJson["sampling_method"]);
- }
- if (requestJson.contains("scheduler")) {
- genRequest.scheduler = parseScheduler(requestJson["scheduler"]);
- }
- // Optional VAE model
- if (requestJson.contains("vae_model") && requestJson["vae_model"].is_string()) {
- std::string vaeModelId = requestJson["vae_model"];
- if (!vaeModelId.empty()) {
- auto vaeInfo = m_modelManager->getModelInfo(vaeModelId);
- if (!vaeInfo.name.empty() && vaeInfo.type == ModelType::VAE) {
- genRequest.vaePath = vaeInfo.path;
- } else {
- sendErrorResponse(res, "VAE model not found or invalid: " + vaeModelId, 400, "INVALID_VAE_MODEL", requestId);
- return;
- }
- }
- }
- // Optional TAESD model
- if (requestJson.contains("taesd_model") && requestJson["taesd_model"].is_string()) {
- std::string taesdModelId = requestJson["taesd_model"];
- if (!taesdModelId.empty()) {
- auto taesdInfo = m_modelManager->getModelInfo(taesdModelId);
- if (!taesdInfo.name.empty() && taesdInfo.type == ModelType::TAESD) {
- genRequest.taesdPath = taesdInfo.path;
- } else {
- sendErrorResponse(res, "TAESD model not found or invalid: " + taesdModelId, 400, "INVALID_TAESD_MODEL", requestId);
- return;
- }
- }
- }
- // Store control image path (would be handled in actual implementation)
- genRequest.outputPath = requestJson.value("control_image", "");
- // Enqueue request
- auto future = m_generationQueue->enqueueRequest(genRequest);
- json params = {
- {"prompt", genRequest.prompt},
- {"negative_prompt", genRequest.negativePrompt},
- {"control_image", requestJson["control_image"]},
- {"control_net_model", genRequest.controlNetPath},
- {"model", genRequest.modelName},
- {"width", genRequest.width},
- {"height", genRequest.height},
- {"batch_count", genRequest.batchCount},
- {"steps", genRequest.steps},
- {"cfg_scale", genRequest.cfgScale},
- {"seed", genRequest.seed},
- {"control_strength", genRequest.controlStrength},
- {"sampling_method", samplingMethodToString(genRequest.samplingMethod)},
- {"scheduler", schedulerToString(genRequest.scheduler)}
- };
- // Add VAE/TAESD if specified
- if (!genRequest.vaePath.empty()) {
- params["vae_model"] = requestJson.value("vae_model", "");
- }
- if (!genRequest.taesdPath.empty()) {
- params["taesd_model"] = requestJson.value("taesd_model", "");
- }
- json response = {
- {"request_id", requestId},
- {"status", "queued"},
- {"message", "ControlNet generation request queued successfully"},
- {"queue_position", m_generationQueue->getQueueSize()},
- {"estimated_time_seconds", estimateGenerationTime(genRequest) / 1000},
- {"estimated_memory_mb", estimateMemoryUsage(genRequest) / (1024 * 1024)},
- {"type", "controlnet"},
- {"parameters", params}
- };
- sendJsonResponse(res, response, 202);
- } catch (const json::parse_error& e) {
- sendErrorResponse(res, std::string("Invalid JSON: ") + e.what(), 400, "JSON_PARSE_ERROR", requestId);
- } catch (const std::exception& e) {
- sendErrorResponse(res, std::string("ControlNet request failed: ") + e.what(), 500, "INTERNAL_ERROR", requestId);
- }
- }
- /**
-  * @brief Queues an ESRGAN/upscaler job for a single input image.
-  *
-  * Reads a JSON body with the following fields:
-  *   - "image"          (string, required)  forwarded to loadImageFromInput()
-  *   - "esrgan_model"   (string, required)  identifier resolved through the model manager
-  *   - "upscale_factor" (integer, optional) defaults to 4, must be at least 1
-  *   - "threads"        (integer, optional) defaults to -1
-  *   - "offload_to_cpu" (bool, optional)    defaults to false
-  *   - "direct"         (bool, optional)    defaults to false
-  *
-  * On success the request is enqueued and a 202 response describing it is sent.
-  * Malformed JSON, missing/ill-typed fields and an unloadable image yield 400,
-  * an unknown model yields 404, everything else yields 500.
-  *
-  * @param req Incoming HTTP request; only req.body is read.
-  * @param res HTTP response to populate.
-  */
- void Server::handleUpscale(const httplib::Request& req, httplib::Response& res) {
-     std::string requestId = generateRequestId();
-     try {
-         if (!m_generationQueue) {
-             sendErrorResponse(res, "Generation queue not available", 500, "QUEUE_UNAVAILABLE", requestId);
-             return;
-         }
-         json requestJson = json::parse(req.body);
-         // Validate required fields for upscaler
-         if (!requestJson.contains("image") || !requestJson["image"].is_string()) {
-             sendErrorResponse(res, "Missing or invalid 'image' field", 400, "INVALID_PARAMETERS", requestId);
-             return;
-         }
-         if (!requestJson.contains("esrgan_model") || !requestJson["esrgan_model"].is_string()) {
-             sendErrorResponse(res, "Missing or invalid 'esrgan_model' field (model hash or name)", 400, "INVALID_PARAMETERS", requestId);
-             return;
-         }
-         // Reject non-positive factors up front: they would otherwise be queued and
-         // reported back as zero or negative output dimensions.
-         const int upscaleFactor = requestJson.value("upscale_factor", 4);
-         if (upscaleFactor < 1) {
-             sendErrorResponse(res, "Invalid 'upscale_factor': must be a positive integer", 400, "INVALID_PARAMETERS", requestId);
-             return;
-         }
-         // Check if model manager is available
-         if (!m_modelManager) {
-             sendErrorResponse(res, "Model manager not available", 500, "MODEL_MANAGER_UNAVAILABLE", requestId);
-             return;
-         }
-         // Get the ESRGAN/upscaler model
-         std::string esrganModelId = requestJson["esrgan_model"];
-         auto modelInfo = m_modelManager->getModelInfo(esrganModelId);
-         if (modelInfo.name.empty()) {
-             sendErrorResponse(res, "ESRGAN model not found: " + esrganModelId, 404, "MODEL_NOT_FOUND", requestId);
-             return;
-         }
-         if (modelInfo.type != ModelType::ESRGAN && modelInfo.type != ModelType::UPSCALER) {
-             sendErrorResponse(res, "Model is not an ESRGAN/upscaler model", 400, "INVALID_MODEL_TYPE", requestId);
-             return;
-         }
-         // Load the input image
-         std::string imageInput = requestJson["image"];
-         auto [imageData, imgWidth, imgHeight, imgChannels, success, loadError] = loadImageFromInput(imageInput);
-         if (!success) {
-             sendErrorResponse(res, "Failed to load image: " + loadError, 400, "IMAGE_LOAD_ERROR", requestId);
-             return;
-         }
-         // Create upscaler request
-         GenerationRequest genRequest;
-         genRequest.id = requestId;
-         genRequest.requestType = GenerationRequest::RequestType::UPSCALER;
-         genRequest.esrganPath = modelInfo.path;
-         genRequest.upscaleFactor = upscaleFactor;
-         genRequest.nThreads = requestJson.value("threads", -1);
-         genRequest.offloadParamsToCpu = requestJson.value("offload_to_cpu", false);
-         genRequest.diffusionConvDirect = requestJson.value("direct", false);
-         // Set input image data
-         genRequest.initImageData = imageData;
-         genRequest.initImageWidth = imgWidth;
-         genRequest.initImageHeight = imgHeight;
-         genRequest.initImageChannels = imgChannels;
-         // Enqueue request
-         auto future = m_generationQueue->enqueueRequest(genRequest);
-         json response = {
-             {"request_id", requestId},
-             {"status", "queued"},
-             {"message", "Upscale request queued successfully"},
-             {"queue_position", m_generationQueue->getQueueSize()},
-             {"type", "upscale"},
-             {"parameters", {
-                 {"esrgan_model", esrganModelId},
-                 {"upscale_factor", genRequest.upscaleFactor},
-                 {"input_width", imgWidth},
-                 {"input_height", imgHeight},
-                 {"output_width", imgWidth * genRequest.upscaleFactor},
-                 {"output_height", imgHeight * genRequest.upscaleFactor}
-             }}
-         };
-         sendJsonResponse(res, response, 202);
-     } catch (const json::parse_error& e) {
-         sendErrorResponse(res, std::string("Invalid JSON: ") + e.what(), 400, "JSON_PARSE_ERROR", requestId);
-     } catch (const json::type_error& e) {
-         // A field was present but had the wrong JSON type (e.g. "threads": "many"):
-         // that is a client error, not an internal one.
-         sendErrorResponse(res, std::string("Invalid parameter type: ") + e.what(), 400, "INVALID_PARAMETERS", requestId);
-     } catch (const std::exception& e) {
-         sendErrorResponse(res, std::string("Upscale request failed: ") + e.what(), 500, "INTERNAL_ERROR", requestId);
-     }
- }
- /**
-  * @brief Queues an inpainting job against the currently loaded checkpoint model.
-  *
-  * Required JSON fields: "prompt", "source_image" and "mask_image" (all strings; the
-  * two image fields are forwarded to loadImageFromInput()). The source and mask
-  * images must have identical width and height. Optional fields: "negative_prompt",
-  * "width"/"height" (default to the source image size), "batch_count", "steps",
-  * "cfg_scale", "seed" (string or integer), "strength", "sampling_method",
-  * "scheduler", "vae_model" and "taesd_model".
-  *
-  * On success the request is enqueued and a 202 response echoing the effective
-  * parameters is sent. Client-side problems are reported with status 400,
-  * unexpected failures with status 500.
-  *
-  * @param req Incoming HTTP request; only req.body is read.
-  * @param res HTTP response to populate.
-  */
- void Server::handleInpainting(const httplib::Request& req, httplib::Response& res) {
-     std::string requestId = generateRequestId();
-     try {
-         if (!m_generationQueue) {
-             sendErrorResponse(res, "Generation queue not available", 500, "QUEUE_UNAVAILABLE", requestId);
-             return;
-         }
-         json requestJson = json::parse(req.body);
-         // Validate required fields for inpainting
-         if (!requestJson.contains("prompt") || !requestJson["prompt"].is_string()) {
-             sendErrorResponse(res, "Missing or invalid 'prompt' field", 400, "INVALID_PARAMETERS", requestId);
-             return;
-         }
-         if (!requestJson.contains("source_image") || !requestJson["source_image"].is_string()) {
-             sendErrorResponse(res, "Missing or invalid 'source_image' field", 400, "INVALID_PARAMETERS", requestId);
-             return;
-         }
-         if (!requestJson.contains("mask_image") || !requestJson["mask_image"].is_string()) {
-             sendErrorResponse(res, "Missing or invalid 'mask_image' field", 400, "INVALID_PARAMETERS", requestId);
-             return;
-         }
-         // Validate all parameters
-         auto [isValid, errorMessage] = validateGenerationParameters(requestJson);
-         if (!isValid) {
-             sendErrorResponse(res, errorMessage, 400, "INVALID_PARAMETERS", requestId);
-             return;
-         }
-         // Check if any model is loaded
-         if (!m_modelManager) {
-             sendErrorResponse(res, "Model manager not available", 500, "MODEL_MANAGER_UNAVAILABLE", requestId);
-             return;
-         }
-         // Get currently loaded checkpoint model (first loaded checkpoint wins)
-         auto allModels = m_modelManager->getAllModels();
-         std::string loadedModelName;
-         for (const auto& [modelName, modelInfo] : allModels) {
-             if (modelInfo.type == ModelType::CHECKPOINT && modelInfo.isLoaded) {
-                 loadedModelName = modelName;
-                 break;
-             }
-         }
-         if (loadedModelName.empty()) {
-             sendErrorResponse(res, "No checkpoint model loaded. Please load a checkpoint model first using POST /api/models/{hash}/load", 400, "NO_CHECKPOINT_LOADED", requestId);
-             return;
-         }
-         // Load the source image
-         std::string sourceImageInput = requestJson["source_image"];
-         auto [sourceImageData, sourceImgWidth, sourceImgHeight, sourceImgChannels, sourceSuccess, sourceLoadError] = loadImageFromInput(sourceImageInput);
-         if (!sourceSuccess) {
-             sendErrorResponse(res, "Failed to load source image: " + sourceLoadError, 400, "IMAGE_LOAD_ERROR", requestId);
-             return;
-         }
-         // Load the mask image
-         std::string maskImageInput = requestJson["mask_image"];
-         auto [maskImageData, maskImgWidth, maskImgHeight, maskImgChannels, maskSuccess, maskLoadError] = loadImageFromInput(maskImageInput);
-         if (!maskSuccess) {
-             sendErrorResponse(res, "Failed to load mask image: " + maskLoadError, 400, "MASK_LOAD_ERROR", requestId);
-             return;
-         }
-         // Validate that source and mask images have compatible dimensions
-         if (sourceImgWidth != maskImgWidth || sourceImgHeight != maskImgHeight) {
-             sendErrorResponse(res, "Source and mask images must have the same dimensions", 400, "DIMENSION_MISMATCH", requestId);
-             return;
-         }
-         // Create generation request specifically for inpainting
-         GenerationRequest genRequest;
-         genRequest.id = requestId;
-         genRequest.requestType = GenerationRequest::RequestType::INPAINTING;
-         genRequest.modelName = loadedModelName; // Use the currently loaded model
-         genRequest.prompt = requestJson["prompt"];
-         genRequest.negativePrompt = requestJson.value("negative_prompt", "");
-         genRequest.width = requestJson.value("width", sourceImgWidth); // Default to input image dimensions
-         genRequest.height = requestJson.value("height", sourceImgHeight);
-         genRequest.batchCount = requestJson.value("batch_count", 1);
-         genRequest.steps = requestJson.value("steps", 20);
-         genRequest.cfgScale = requestJson.value("cfg_scale", 7.5f);
-         // The parameters endpoint advertises "seed" as string|integer, but
-         // json::value() with a string default throws on a numeric seed. Accept
-         // integers explicitly and store them in their decimal string form.
-         genRequest.seed = "random";
-         if (requestJson.contains("seed")) {
-             const json& seedJson = requestJson["seed"];
-             if (seedJson.is_number_unsigned()) {
-                 genRequest.seed = std::to_string(seedJson.get<unsigned long long>());
-             } else if (seedJson.is_number_integer()) {
-                 genRequest.seed = std::to_string(seedJson.get<long long>());
-             } else {
-                 // Throws json::type_error for anything that is not a string.
-                 genRequest.seed = seedJson.get<std::string>();
-             }
-         }
-         genRequest.strength = requestJson.value("strength", 0.75f);
-         // Set source image data
-         genRequest.initImageData = sourceImageData;
-         genRequest.initImageWidth = sourceImgWidth;
-         genRequest.initImageHeight = sourceImgHeight;
-         genRequest.initImageChannels = sourceImgChannels;
-         // Set mask image data
-         genRequest.maskImageData = maskImageData;
-         genRequest.maskImageWidth = maskImgWidth;
-         genRequest.maskImageHeight = maskImgHeight;
-         genRequest.maskImageChannels = maskImgChannels;
-         // Parse optional parameters
-         if (requestJson.contains("sampling_method")) {
-             genRequest.samplingMethod = parseSamplingMethod(requestJson["sampling_method"]);
-         }
-         if (requestJson.contains("scheduler")) {
-             genRequest.scheduler = parseScheduler(requestJson["scheduler"]);
-         }
-         // Optional VAE model
-         if (requestJson.contains("vae_model") && requestJson["vae_model"].is_string()) {
-             std::string vaeModelId = requestJson["vae_model"];
-             if (!vaeModelId.empty()) {
-                 auto vaeInfo = m_modelManager->getModelInfo(vaeModelId);
-                 if (!vaeInfo.name.empty() && vaeInfo.type == ModelType::VAE) {
-                     genRequest.vaePath = vaeInfo.path;
-                 } else {
-                     sendErrorResponse(res, "VAE model not found or invalid: " + vaeModelId, 400, "INVALID_VAE_MODEL", requestId);
-                     return;
-                 }
-             }
-         }
-         // Optional TAESD model
-         if (requestJson.contains("taesd_model") && requestJson["taesd_model"].is_string()) {
-             std::string taesdModelId = requestJson["taesd_model"];
-             if (!taesdModelId.empty()) {
-                 auto taesdInfo = m_modelManager->getModelInfo(taesdModelId);
-                 if (!taesdInfo.name.empty() && taesdInfo.type == ModelType::TAESD) {
-                     genRequest.taesdPath = taesdInfo.path;
-                 } else {
-                     sendErrorResponse(res, "TAESD model not found or invalid: " + taesdModelId, 400, "INVALID_TAESD_MODEL", requestId);
-                     return;
-                 }
-             }
-         }
-         // Enqueue request
-         auto future = m_generationQueue->enqueueRequest(genRequest);
-         json params = {
-             {"prompt", genRequest.prompt},
-             {"negative_prompt", genRequest.negativePrompt},
-             {"source_image", requestJson["source_image"]},
-             {"mask_image", requestJson["mask_image"]},
-             {"model", genRequest.modelName},
-             {"width", genRequest.width},
-             {"height", genRequest.height},
-             {"batch_count", genRequest.batchCount},
-             {"steps", genRequest.steps},
-             {"cfg_scale", genRequest.cfgScale},
-             {"seed", genRequest.seed},
-             {"strength", genRequest.strength},
-             {"sampling_method", samplingMethodToString(genRequest.samplingMethod)},
-             {"scheduler", schedulerToString(genRequest.scheduler)}
-         };
-         // Add VAE/TAESD if specified
-         if (!genRequest.vaePath.empty()) {
-             params["vae_model"] = requestJson.value("vae_model", "");
-         }
-         if (!genRequest.taesdPath.empty()) {
-             params["taesd_model"] = requestJson.value("taesd_model", "");
-         }
-         json response = {
-             {"request_id", requestId},
-             {"status", "queued"},
-             {"message", "Inpainting generation request queued successfully"},
-             {"queue_position", m_generationQueue->getQueueSize()},
-             {"estimated_time_seconds", estimateGenerationTime(genRequest) / 1000},
-             {"estimated_memory_mb", estimateMemoryUsage(genRequest) / (1024 * 1024)},
-             {"type", "inpainting"},
-             {"parameters", params}
-         };
-         sendJsonResponse(res, response, 202);
-     } catch (const json::parse_error& e) {
-         sendErrorResponse(res, std::string("Invalid JSON: ") + e.what(), 400, "JSON_PARSE_ERROR", requestId);
-     } catch (const json::type_error& e) {
-         // A field was present but had the wrong JSON type: report it as a client error.
-         sendErrorResponse(res, std::string("Invalid parameter type: ") + e.what(), 400, "INVALID_PARAMETERS", requestId);
-     } catch (const std::exception& e) {
-         sendErrorResponse(res, std::string("Inpainting request failed: ") + e.what(), 500, "INTERNAL_ERROR", requestId);
-     }
- }
- // Utility endpoints
- /**
-  * @brief Responds with the list of sampler names the API advertises.
-  *
-  * The payload has the shape {"samplers": [{name, description, recommended_steps}, ...]}.
-  * Any exception raised while building or sending it is reported as a 500.
-  *
-  * @param req Incoming HTTP request (not inspected).
-  * @param res HTTP response to populate.
-  */
- void Server::handleSamplers(const httplib::Request& req, httplib::Response& res) {
-     // One row per advertised sampler, in the order they are returned.
-     struct SamplerRow {
-         const char* name;
-         const char* description;
-         int recommendedSteps;
-     };
-     static const SamplerRow kSamplerTable[] = {
-         {"euler", "Euler sampler - fast and simple", 20},
-         {"euler_a", "Euler ancestral sampler - adds randomness", 20},
-         {"heun", "Heun sampler - more accurate but slower", 20},
-         {"dpm2", "DPM2 sampler - second-order DPM", 20},
-         {"dpm++2s_a", "DPM++ 2s ancestral sampler", 20},
-         {"dpm++2m", "DPM++ 2m sampler - multistep", 20},
-         {"dpm++2mv2", "DPM++ 2m v2 sampler - improved multistep", 20},
-         {"ipndm", "IPNDM sampler - improved noise prediction", 20},
-         {"ipndm_v", "IPNDM v sampler - variant of IPNDM", 20},
-         {"lcm", "LCM sampler - Latent Consistency Model, very fast", 4},
-         {"ddim_trailing", "DDIM trailing sampler - deterministic", 20},
-         {"tcd", "TCD sampler - Trajectory Consistency Distillation", 8},
-         {"default", "Use model's default sampler", 20}
-     };
-     try {
-         json samplerList = json::array();
-         for (const SamplerRow& row : kSamplerTable) {
-             samplerList.push_back(json{
-                 {"name", row.name},
-                 {"description", row.description},
-                 {"recommended_steps", row.recommendedSteps}
-             });
-         }
-         json samplers = {
-             {"samplers", samplerList}
-         };
-         sendJsonResponse(res, samplers);
-     } catch (const std::exception& e) {
-         sendErrorResponse(res, std::string("Failed to get samplers: ") + e.what(), 500);
-     }
- }
- /**
-  * @brief Responds with the list of scheduler names the API advertises.
-  *
-  * The payload has the shape {"schedulers": [{name, description}, ...]}.
-  * Any exception raised while building or sending it is reported as a 500.
-  *
-  * @param req Incoming HTTP request (not inspected).
-  * @param res HTTP response to populate.
-  */
- void Server::handleSchedulers(const httplib::Request& req, httplib::Response& res) {
-     // One row per advertised scheduler, in the order they are returned.
-     struct SchedulerRow {
-         const char* name;
-         const char* description;
-     };
-     static const SchedulerRow kSchedulerTable[] = {
-         {"discrete", "Discrete scheduler - standard noise schedule"},
-         {"karras", "Karras scheduler - improved noise schedule"},
-         {"exponential", "Exponential scheduler - exponential noise decay"},
-         {"ays", "AYS scheduler - Adaptive Your Scheduler"},
-         {"gits", "GITS scheduler - Generalized Iterative Time Steps"},
-         {"smoothstep", "Smoothstep scheduler - smooth transition function"},
-         {"sgm_uniform", "SGM uniform scheduler - uniform noise schedule"},
-         {"simple", "Simple scheduler - basic linear schedule"},
-         {"default", "Use model's default scheduler"}
-     };
-     try {
-         json schedulerList = json::array();
-         for (const SchedulerRow& row : kSchedulerTable) {
-             schedulerList.push_back(json{
-                 {"name", row.name},
-                 {"description", row.description}
-             });
-         }
-         json schedulers = {
-             {"schedulers", schedulerList}
-         };
-         sendJsonResponse(res, schedulers);
-     } catch (const std::exception& e) {
-         sendErrorResponse(res, std::string("Failed to get schedulers: ") + e.what(), 500);
-     }
- }
- /**
-  * @brief Responds with a static description of the generation parameters.
-  *
-  * The payload contains two top-level keys: "parameters" (an array with one
-  * descriptor per request field) and "openapi" (a small OpenAPI-flavoured schema
-  * for GenerationRequest). Any exception is reported as a 500.
-  *
-  * @param req Incoming HTTP request (not inspected).
-  * @param res HTTP response to populate.
-  */
- void Server::handleParameters(const httplib::Request& req, httplib::Response& res) {
-     try {
-         // Per-field descriptors, in the order they are returned.
-         json parameterSpecs = json::array({
-             {
-                 {"name", "prompt"},
-                 {"type", "string"},
-                 {"required", true},
-                 {"description", "Text prompt for image generation"},
-                 {"min_length", 1},
-                 {"max_length", 10000},
-                 {"example", "a beautiful landscape with mountains"}
-             },
-             {
-                 {"name", "negative_prompt"},
-                 {"type", "string"},
-                 {"required", false},
-                 {"description", "Negative prompt to guide generation away from"},
-                 {"min_length", 0},
-                 {"max_length", 10000},
-                 {"example", "blurry, low quality, distorted"}
-             },
-             {
-                 {"name", "width"},
-                 {"type", "integer"},
-                 {"required", false},
-                 {"description", "Image width in pixels"},
-                 {"min", 64},
-                 {"max", 2048},
-                 {"multiple_of", 64},
-                 {"default", 512}
-             },
-             {
-                 {"name", "height"},
-                 {"type", "integer"},
-                 {"required", false},
-                 {"description", "Image height in pixels"},
-                 {"min", 64},
-                 {"max", 2048},
-                 {"multiple_of", 64},
-                 {"default", 512}
-             },
-             {
-                 {"name", "steps"},
-                 {"type", "integer"},
-                 {"required", false},
-                 {"description", "Number of diffusion steps"},
-                 {"min", 1},
-                 {"max", 150},
-                 {"default", 20}
-             },
-             {
-                 {"name", "cfg_scale"},
-                 {"type", "number"},
-                 {"required", false},
-                 {"description", "Classifier-Free Guidance scale"},
-                 {"min", 1.0},
-                 {"max", 30.0},
-                 {"default", 7.5}
-             },
-             {
-                 {"name", "seed"},
-                 {"type", "string|integer"},
-                 {"required", false},
-                 {"description", "Seed for generation (use 'random' for random seed)"},
-                 {"example", "42"}
-             },
-             {
-                 {"name", "sampling_method"},
-                 {"type", "string"},
-                 {"required", false},
-                 {"description", "Sampling method to use"},
-                 {"enum", {"euler", "euler_a", "heun", "dpm2", "dpm++2s_a", "dpm++2m", "dpm++2mv2", "ipndm", "ipndm_v", "lcm", "ddim_trailing", "tcd", "default"}},
-                 {"default", "default"}
-             },
-             {
-                 {"name", "scheduler"},
-                 {"type", "string"},
-                 {"required", false},
-                 {"description", "Scheduler to use"},
-                 {"enum", {"discrete", "karras", "exponential", "ays", "gits", "smoothstep", "sgm_uniform", "simple", "default"}},
-                 {"default", "default"}
-             },
-             {
-                 {"name", "batch_count"},
-                 {"type", "integer"},
-                 {"required", false},
-                 {"description", "Number of images to generate"},
-                 {"min", 1},
-                 {"max", 100},
-                 {"default", 1}
-             },
-             {
-                 {"name", "strength"},
-                 {"type", "number"},
-                 {"required", false},
-                 {"description", "Strength for img2img (0.0-1.0)"},
-                 {"min", 0.0},
-                 {"max", 1.0},
-                 {"default", 0.75}
-             },
-             {
-                 {"name", "control_strength"},
-                 {"type", "number"},
-                 {"required", false},
-                 {"description", "ControlNet strength (0.0-1.0)"},
-                 {"min", 0.0},
-                 {"max", 1.0},
-                 {"default", 0.9}
-             }
-         });
-
-         // Schema object published under openapi.components.schemas.GenerationRequest.
-         json generationRequestSchema = {
-             {"type", "object"},
-             {"required", {"prompt"}},
-             {"properties", {
-                 {"prompt", {{"type", "string"}, {"description", "Text prompt for generation"}}},
-                 {"negative_prompt", {{"type", "string"}, {"description", "Negative prompt"}}},
-                 {"width", {{"type", "integer"}, {"minimum", 64}, {"maximum", 2048}, {"default", 512}}},
-                 {"height", {{"type", "integer"}, {"minimum", 64}, {"maximum", 2048}, {"default", 512}}},
-                 {"steps", {{"type", "integer"}, {"minimum", 1}, {"maximum", 150}, {"default", 20}}},
-                 {"cfg_scale", {{"type", "number"}, {"minimum", 1.0}, {"maximum", 30.0}, {"default", 7.5}}}
-             }}
-         };
-
-         json openapiSection = {
-             {"version", "3.0.0"},
-             {"info", {
-                 {"title", "Stable Diffusion REST API"},
-                 {"version", "1.0.0"},
-                 {"description", "Comprehensive REST API for stable-diffusion.cpp functionality"}
-             }},
-             {"components", {
-                 {"schemas", {
-                     {"GenerationRequest", generationRequestSchema}
-                 }}
-             }}
-         };
-
-         json parameters = {
-             {"parameters", parameterSpecs},
-             {"openapi", openapiSection}
-         };
-         sendJsonResponse(res, parameters);
-     } catch (const std::exception& e) {
-         sendErrorResponse(res, std::string("Failed to get parameters: ") + e.what(), 500);
-     }
- }
- /**
-  * @brief Runs validateGenerationParameters() on the request body and reports the verdict.
-  *
-  * Responds 200 when the parameters are accepted and 400 when they are rejected;
-  * in both cases the body carries "request_id", "valid", "message" and "errors".
-  * Unparseable JSON yields a 400 error response, any other exception a 500.
-  *
-  * @param req Incoming HTTP request; only req.body is read.
-  * @param res HTTP response to populate.
-  */
- void Server::handleValidate(const httplib::Request& req, httplib::Response& res) {
-     std::string requestId = generateRequestId();
-     try {
-         json payload = json::parse(req.body);
-         auto [accepted, reason] = validateGenerationParameters(payload);
-
-         // "errors" is empty on success and holds the single failure reason otherwise.
-         json errorList;
-         int httpStatus;
-         if (accepted) {
-             errorList = json::array();
-             httpStatus = 200;
-         } else {
-             errorList = json::array({reason});
-             httpStatus = 400;
-         }
-
-         json response = {
-             {"request_id", requestId},
-             {"valid", accepted},
-             {"message", accepted ? "Parameters are valid" : reason},
-             {"errors", errorList}
-         };
-         sendJsonResponse(res, response, httpStatus);
-     } catch (const json::parse_error& e) {
-         sendErrorResponse(res, std::string("Invalid JSON: ") + e.what(), 400, "JSON_PARSE_ERROR", requestId);
-     } catch (const std::exception& e) {
-         sendErrorResponse(res, std::string("Validation failed: ") + e.what(), 500, "INTERNAL_ERROR", requestId);
-     }
- }
- /**
-  * @brief Reports time and memory estimates for a generation request without queueing it.
-  *
-  * The body is validated with validateGenerationParameters(); a rejected body
-  * yields a 400. Otherwise a scratch GenerationRequest is filled from the body and
-  * passed to estimateGenerationTime() / estimateMemoryUsage(); their results are
-  * divided by 1000 and by 1024*1024 respectively before being returned.
-  *
-  * @param req Incoming HTTP request; only req.body is read.
-  * @param res HTTP response to populate.
-  */
- void Server::handleEstimate(const httplib::Request& req, httplib::Response& res) {
-     std::string requestId = generateRequestId();
-     try {
-         json body = json::parse(req.body);
-
-         // Reject invalid parameter sets before estimating anything.
-         auto [accepted, reason] = validateGenerationParameters(body);
-         if (!accepted) {
-             sendErrorResponse(res, reason, 400, "INVALID_PARAMETERS", requestId);
-             return;
-         }
-
-         // Scratch request: only the fields read from the body below are populated.
-         GenerationRequest probe;
-         probe.prompt = body["prompt"];
-         probe.width = body.value("width", 512);
-         probe.height = body.value("height", 512);
-         probe.batchCount = body.value("batch_count", 1);
-         probe.steps = body.value("steps", 20);
-         probe.diffusionFlashAttn = body.value("diffusion_flash_attn", false);
-         probe.controlNetPath = body.value("control_net_path", "");
-         if (body.contains("sampling_method")) {
-             probe.samplingMethod = parseSamplingMethod(body["sampling_method"]);
-         }
-
-         uint64_t rawTime = estimateGenerationTime(probe);
-         size_t rawMemory = estimateMemoryUsage(probe);
-         const uint64_t timeSeconds = rawTime / 1000;
-         const size_t memoryMb = rawMemory / (1024 * 1024);
-         const std::string resolution = std::to_string(probe.width) + "x" + std::to_string(probe.height);
-
-         json response = {
-             {"request_id", requestId},
-             {"estimated_time_seconds", timeSeconds},
-             {"estimated_memory_mb", memoryMb},
-             {"parameters", {
-                 {"resolution", resolution},
-                 {"steps", probe.steps},
-                 {"batch_count", probe.batchCount},
-                 {"sampling_method", samplingMethodToString(probe.samplingMethod)}
-             }}
-         };
-         sendJsonResponse(res, response);
-     } catch (const json::parse_error& e) {
-         sendErrorResponse(res, std::string("Invalid JSON: ") + e.what(), 400, "JSON_PARSE_ERROR", requestId);
-     } catch (const std::exception& e) {
-         sendErrorResponse(res, std::string("Estimation failed: ") + e.what(), 500, "INTERNAL_ERROR", requestId);
-     }
- }
- /**
-  * @brief Responds with a snapshot of the server configuration.
-  *
-  * Only "host" and "port" are read from the server instance (m_host, m_port);
-  * every other value in the payload is a literal written in this function.
-  *
-  * @param req Incoming HTTP request (not inspected).
-  * @param res HTTP response to populate.
-  */
- void Server::handleConfig(const httplib::Request& req, httplib::Response& res) {
-     std::string requestId = generateRequestId();
-     try {
-         json serverSection = {
-             {"host", m_host},
-             {"port", m_port},
-             {"max_concurrent_generations", 1}
-         };
-         json generationSection = {
-             {"default_width", 512},
-             {"default_height", 512},
-             {"default_steps", 20},
-             {"default_cfg_scale", 7.5},
-             {"max_batch_count", 100},
-             {"max_steps", 150},
-             {"max_resolution", 2048}
-         };
-         json rateLimitSection = {
-             {"requests_per_minute", 60},
-             {"enabled", true}
-         };
-
-         json config = {
-             {"request_id", requestId},
-             {"config", {
-                 {"server", serverSection},
-                 {"generation", generationSection},
-                 {"rate_limiting", rateLimitSection}
-             }}
-         };
-         sendJsonResponse(res, config);
-     } catch (const std::exception& e) {
-         sendErrorResponse(res, std::string("Config operation failed: ") + e.what(), 500, "INTERNAL_ERROR", requestId);
-     }
- }
- /**
-  * @brief Responds with static system information plus two runtime values.
-  *
-  * The runtime values are "uptime" (seconds, taken from std::chrono::steady_clock)
-  * and "cpu_threads" (std::thread::hardware_concurrency()); everything else is a
-  * literal written in this function.
-  *
-  * @param req Incoming HTTP request (not inspected).
-  * @param res HTTP response to populate.
-  */
- void Server::handleSystem(const httplib::Request& req, httplib::Response& res) {
-     try {
-         // NOTE(review): this is the steady-clock reading since the clock's own epoch;
-         // no server start time is subtracted, so it is not time-since-server-start.
-         const auto uptimeSeconds = std::chrono::duration_cast<std::chrono::seconds>(
-             std::chrono::steady_clock::now().time_since_epoch()).count();
-         const auto cpuThreads = std::thread::hardware_concurrency();
-
-         json capabilities = {
-             {"text2img", true},
-             {"img2img", true},
-             {"controlnet", true},
-             {"batch_generation", true},
-             {"parameter_validation", true},
-             {"estimation", true}
-         };
-         json supportedFormats = {
-             {"input", {"png", "jpg", "jpeg", "webp"}},
-             {"output", {"png", "jpg", "jpeg", "webp"}}
-         };
-         json limits = {
-             {"max_resolution", 2048},
-             {"max_steps", 150},
-             {"max_batch_count", 100},
-             {"max_prompt_length", 10000}
-         };
-
-         json system = {
-             {"system", {
-                 {"version", "1.0.0"},
-                 {"build", "stable-diffusion.cpp-rest"},
-                 {"uptime", uptimeSeconds},
-                 {"capabilities", capabilities},
-                 {"supported_formats", supportedFormats},
-                 {"limits", limits}
-             }},
-             {"hardware", {
-                 {"cpu_threads", cpuThreads}
-             }}
-         };
-         sendJsonResponse(res, system);
-     } catch (const std::exception& e) {
-         sendErrorResponse(res, std::string("System info failed: ") + e.what(), 500);
-     }
- }
- /**
-  * @brief Acknowledges a restart request, then stops the server and exits the process.
-  *
-  * The acknowledgement is written first; a detached thread then sleeps for one
-  * second, calls stop() and terminates the process with exit code 42.
-  *
-  * @param req Incoming HTTP request (not inspected).
-  * @param res HTTP response to populate.
-  */
- void Server::handleSystemRestart(const httplib::Request& req, httplib::Response& res) {
-     try {
-         json acknowledgement = {
-             {"message", "Server restart initiated. The server will shut down gracefully and exit. Please use a process manager to automatically restart it."},
-             {"status", "restarting"}
-         };
-         sendJsonResponse(res, acknowledgement);
-
-         // Deferred shutdown: runs on its own thread so this handler can return
-         // and the acknowledgement can be delivered before the server stops.
-         auto delayedShutdown = [this]() {
-             std::this_thread::sleep_for(std::chrono::seconds(1));
-             this->stop();
-             // Exit code 42 signals restart intent to the process manager.
-             std::exit(42);
-         };
-         std::thread shutdownThread(delayedShutdown);
-         shutdownThread.detach();
-     } catch (const std::exception& e) {
-         sendErrorResponse(res, std::string("Restart failed: ") + e.what(), 500);
-     }
- }
- // Helper methods for model management
- /**
-  * @brief Returns the static capability descriptor for a model type.
-  *
-  * Each handled ModelType maps to a fixed JSON object; any type without a
-  * dedicated case receives the generic "nothing supported" descriptor.
-  *
-  * @param type Model type to describe.
-  * @return JSON object with the capability flags for that type.
-  */
- json Server::getModelCapabilities(ModelType type) {
-     switch (type) {
-         case ModelType::CHECKPOINT:
-             return {
-                 {"text2img", true},
-                 {"img2img", true},
-                 {"inpainting", true},
-                 {"outpainting", true},
-                 {"controlnet", true},
-                 {"lora", true},
-                 {"vae", true},
-                 {"sampling_methods", {"euler", "euler_a", "heun", "dpm2", "dpm++2s_a", "dpm++2m", "dpm++2mv2", "ipndm", "ipndm_v", "lcm", "ddim_trailing", "tcd"}},
-                 {"schedulers", {"discrete", "karras", "exponential", "ays", "gits", "smoothstep", "sgm_uniform", "simple"}},
-                 {"recommended_resolution", "512x512"},
-                 {"max_resolution", "2048x2048"},
-                 {"supports_batch", true}
-             };
-         case ModelType::LORA:
-             return {
-                 {"text2img", true},
-                 {"img2img", true},
-                 {"inpainting", true},
-                 {"controlnet", false},
-                 {"lora", true},
-                 {"vae", false},
-                 {"requires_checkpoint", true},
-                 {"strength_range", {0.0, 2.0}},
-                 {"recommended_strength", 1.0}
-             };
-         case ModelType::CONTROLNET:
-             return {
-                 {"text2img", false},
-                 {"img2img", true},
-                 {"inpainting", true},
-                 {"controlnet", true},
-                 {"requires_checkpoint", true},
-                 {"control_modes", {"canny", "depth", "pose", "scribble", "hed", "mlsd", "normal", "seg"}},
-                 {"strength_range", {0.0, 1.0}},
-                 {"recommended_strength", 0.9}
-             };
-         case ModelType::VAE:
-             return {
-                 {"text2img", false},
-                 {"img2img", false},
-                 {"inpainting", false},
-                 {"vae", true},
-                 {"requires_checkpoint", true},
-                 {"encoding", true},
-                 {"decoding", true},
-                 {"precision", {"fp16", "fp32"}}
-             };
-         case ModelType::EMBEDDING:
-             return {
-                 {"text2img", true},
-                 {"img2img", true},
-                 {"inpainting", true},
-                 {"embedding", true},
-                 {"requires_checkpoint", true},
-                 {"token_count", 1},
-                 {"compatible_with", {"checkpoint", "lora"}}
-             };
-         case ModelType::TAESD:
-             return {
-                 {"text2img", false},
-                 {"img2img", false},
-                 {"inpainting", false},
-                 {"vae", true},
-                 {"requires_checkpoint", true},
-                 {"fast_decoding", true},
-                 {"real_time", true},
-                 {"precision", {"fp16", "fp32"}}
-             };
-         case ModelType::ESRGAN:
-             return {
-                 {"text2img", false},
-                 {"img2img", false},
-                 {"inpainting", false},
-                 {"upscaling", true},
-                 {"scale_factors", {2, 4}},
-                 {"models", {"ESRGAN", "RealESRGAN", "SwinIR"}},
-                 {"supports_alpha", false}
-             };
-         default:
-             // Types without a dedicated descriptor advertise no generation support.
-             return {
-                 {"text2img", false},
-                 {"img2img", false},
-                 {"inpainting", false},
-                 {"capabilities", {}}
-             };
-     }
- }
- /**
-  * @brief Aggregates per-model-type counts and file sizes over all known models.
-  *
-  * For every model type that has at least one model, the result holds an entry
-  * keyed by ModelManager::modelTypeToString(type) with the total count, the
-  * number of loaded models, and the summed file size in bytes and in MiB
-  * (total and per-model average).
-  *
-  * @return JSON object keyed by type name; empty when no model manager is set.
-  */
- json Server::getModelTypeStatistics() {
-     json stats = json::object();
-     if (!m_modelManager) {
-         return stats;
-     }
-
-     // Running totals for one model type.
-     struct TypeTally {
-         int total = 0;
-         int loaded = 0;
-         size_t bytes = 0;
-     };
-     std::map<ModelType, TypeTally> tallies;
-
-     auto models = m_modelManager->getAllModels();
-     for (const auto& entry : models) {
-         const auto& info = entry.second;
-         TypeTally& tally = tallies[info.type];
-         tally.total++;
-         if (info.isLoaded) {
-             tally.loaded++;
-         }
-         tally.bytes += info.fileSize;
-     }
-
-     for (const auto& [type, tally] : tallies) {
-         const std::string typeName = ModelManager::modelTypeToString(type);
-         const double totalMb = tally.bytes / (1024.0 * 1024.0);
-         stats[typeName] = {
-             {"total_count", tally.total},
-             {"loaded_count", tally.loaded},
-             {"total_size_bytes", tally.bytes},
-             {"total_size_mb", totalMb},
-             {"average_size_mb", tally.total > 0 ? totalMb / tally.total : 0.0}
-         };
-     }
-     return stats;
- }
- // Additional helper methods for model management
- /**
-  * @brief Builds the compatibility descriptor for a model.
-  *
-  * The descriptor is a fixed template; only "requirements.required_dependencies"
-  * varies, and it is overwritten for LoRA, ControlNet and VAE models.
-  *
-  * @param modelInfo Model whose type selects the dependency list.
-  * @return JSON compatibility descriptor.
-  */
- json Server::getModelCompatibility(const ModelManager::ModelInfo& modelInfo) {
-     json compatibility = {
-         {"is_compatible", true},
-         {"compatibility_score", 100},
-         {"issues", json::array()},
-         {"warnings", json::array()},
-         {"requirements", {
-             {"min_memory_mb", 1024},
-             {"recommended_memory_mb", 2048},
-             {"supported_formats", {"safetensors", "ckpt", "gguf"}},
-             {"required_dependencies", {}}
-         }}
-     };
-
-     // These add-on model types all list a checkpoint as their dependency.
-     switch (modelInfo.type) {
-         case ModelType::LORA:
-         case ModelType::CONTROLNET:
-         case ModelType::VAE:
-             compatibility["requirements"]["required_dependencies"] = {"checkpoint"};
-             break;
-         default:
-             break;
-     }
-     return compatibility;
- }
- /**
-  * @brief Returns the resource/format requirement descriptor for a model type.
-  *
-  * Starts from a generic baseline and overrides the memory, disk, format and
-  * dependency entries for each handled ModelType. Types without a dedicated
-  * case keep the baseline unchanged.
-  *
-  * @param type Model type to describe.
-  * @return JSON requirement descriptor.
-  */
- json Server::getModelRequirements(ModelType type) {
-     json requirements = {
-         {"min_memory_mb", 1024},
-         {"recommended_memory_mb", 2048},
-         {"min_disk_space_mb", 1024},
-         {"supported_formats", {"safetensors", "ckpt", "gguf"}},
-         {"required_dependencies", json::array()},
-         {"optional_dependencies", json::array()},
-         {"system_requirements", {
-             {"cpu_cores", 4},
-             {"cpu_architecture", "x86_64"},
-             {"os", "Linux/Windows/macOS"},
-             {"gpu_memory_mb", 2048},
-             {"gpu_compute_capability", "3.5+"}
-         }}
-     };
-
-     // Overrides the three numeric budget entries in one go.
-     auto setBudget = [&requirements](int minMemoryMb, int recommendedMemoryMb, int minDiskSpaceMb) {
-         requirements["min_memory_mb"] = minMemoryMb;
-         requirements["recommended_memory_mb"] = recommendedMemoryMb;
-         requirements["min_disk_space_mb"] = minDiskSpaceMb;
-     };
-
-     switch (type) {
-         case ModelType::CHECKPOINT:
-             setBudget(2048, 4096, 2048);
-             requirements["supported_formats"] = {"safetensors", "ckpt", "gguf"};
-             break;
-         case ModelType::LORA:
-             setBudget(512, 1024, 100);
-             requirements["supported_formats"] = {"safetensors", "ckpt"};
-             requirements["required_dependencies"] = {"checkpoint"};
-             break;
-         case ModelType::CONTROLNET:
-             setBudget(1024, 2048, 500);
-             requirements["supported_formats"] = {"safetensors", "pth"};
-             requirements["required_dependencies"] = {"checkpoint"};
-             break;
-         case ModelType::VAE:
-             setBudget(512, 1024, 200);
-             requirements["supported_formats"] = {"safetensors", "pt", "ckpt", "gguf"};
-             requirements["required_dependencies"] = {"checkpoint"};
-             break;
-         case ModelType::EMBEDDING:
-             setBudget(64, 256, 10);
-             requirements["supported_formats"] = {"safetensors", "pt"};
-             requirements["required_dependencies"] = {"checkpoint"};
-             break;
-         case ModelType::TAESD:
-             setBudget(256, 512, 100);
-             requirements["supported_formats"] = {"safetensors", "pth", "gguf"};
-             requirements["required_dependencies"] = {"checkpoint"};
-             break;
-         case ModelType::ESRGAN:
-             setBudget(1024, 2048, 500);
-             requirements["supported_formats"] = {"pth", "pt"};
-             requirements["optional_dependencies"] = {"checkpoint"};
-             break;
-         default:
-             break;
-     }
-     return requirements;
- }
- /**
-  * @brief Returns the recommended-usage descriptor for a model type.
-  *
-  * Each handled ModelType maps to a fixed JSON object. Types without a dedicated
-  * case receive the generic baseline in which every feature flag is false.
-  *
-  * @param type Model type to describe.
-  * @return JSON usage descriptor.
-  */
- json Server::getRecommendedUsage(ModelType type) {
-     switch (type) {
-         case ModelType::CHECKPOINT:
-             return {
-                 {"text2img", true},
-                 {"img2img", true},
-                 {"inpainting", true},
-                 {"controlnet", true},
-                 {"lora", true},
-                 {"vae", true},
-                 {"recommended_resolution", "512x512"},
-                 {"recommended_steps", 20},
-                 {"recommended_cfg_scale", 7.5},
-                 {"recommended_batch_size", 1}
-             };
-         case ModelType::LORA:
-             return {
-                 {"text2img", true},
-                 {"img2img", true},
-                 {"inpainting", true},
-                 {"controlnet", false},
-                 {"lora", true},
-                 {"vae", false},
-                 {"recommended_strength", 1.0},
-                 {"recommended_usage", "Style transfer, character customization"}
-             };
-         case ModelType::CONTROLNET:
-             return {
-                 {"text2img", false},
-                 {"img2img", true},
-                 {"inpainting", true},
-                 {"controlnet", true},
-                 {"lora", false},
-                 {"vae", false},
-                 {"recommended_strength", 0.9},
-                 {"recommended_usage", "Precise control over output"}
-             };
-         case ModelType::VAE:
-             return {
-                 {"text2img", false},
-                 {"img2img", false},
-                 {"inpainting", false},
-                 {"controlnet", false},
-                 {"lora", false},
-                 {"vae", true},
-                 {"recommended_usage", "Improved encoding/decoding quality"}
-             };
-         case ModelType::EMBEDDING:
-             return {
-                 {"text2img", true},
-                 {"img2img", true},
-                 {"inpainting", true},
-                 {"controlnet", false},
-                 {"lora", false},
-                 {"vae", false},
-                 {"embedding", true},
-                 {"recommended_usage", "Concept control, style words"}
-             };
-         case ModelType::TAESD:
-             return {
-                 {"text2img", false},
-                 {"img2img", false},
-                 {"inpainting", false},
-                 {"controlnet", false},
-                 {"lora", false},
-                 {"vae", true},
-                 {"recommended_usage", "Real-time decoding"}
-             };
-         case ModelType::ESRGAN:
-             return {
-                 {"text2img", false},
-                 {"img2img", false},
-                 {"inpainting", false},
-                 {"controlnet", false},
-                 {"lora", false},
-                 {"vae", false},
-                 {"upscaling", true},
-                 {"recommended_usage", "Image upscaling and quality enhancement"}
-             };
-         default:
-             break;
-     }
-
-     // Baseline for every type that has no dedicated descriptor above.
-     return {
-         {"text2img", false},
-         {"img2img", false},
-         {"inpainting", false},
-         {"controlnet", false},
-         {"lora", false},
-         {"vae", false},
-         {"recommended_resolution", "512x512"},
-         {"recommended_steps", 20},
-         {"recommended_cfg_scale", 7.5},
-         {"recommended_batch_size", 1}
-     };
- }
- std::string Server::getModelTypeFromDirectoryName(const std::string& dirName) {
- if (dirName == "stable-diffusion" || dirName == "checkpoints") {
- return "checkpoint";
- } else if (dirName == "lora") {
- return "lora";
- } else if (dirName == "controlnet") {
- return "controlnet";
- } else if (dirName == "vae") {
- return "vae";
- } else if (dirName == "taesd") {
- return "taesd";
- } else if (dirName == "esrgan" || dirName == "upscaler") {
- return "esrgan";
- } else if (dirName == "embeddings" || dirName == "textual-inversion") {
- return "embedding";
- } else {
- return "unknown";
- }
- }
- std::string Server::getDirectoryDescription(const std::string& dirName) {
- if (dirName == "stable-diffusion" || dirName == "checkpoints") {
- return "Main stable diffusion model files";
- } else if (dirName == "lora") {
- return "LoRA adapter models for style transfer";
- } else if (dirName == "controlnet") {
- return "ControlNet models for precise control";
- } else if (dirName == "vae") {
- return "VAE models for improved encoding/decoding";
- } else if (dirName == "taesd") {
- return "TAESD models for real-time decoding";
- } else if (dirName == "esrgan" || dirName == "upscaler") {
- return "ESRGAN models for image upscaling";
- } else if (dirName == "embeddings" || dirName == "textual-inversion") {
- return "Text embeddings for concept control";
- } else {
- return "Unknown model directory";
- }
- }
- json Server::getDirectoryContents(const std::string& dirPath) {
- json contents = json::array();
- try {
- if (std::filesystem::exists(dirPath) && std::filesystem::is_directory(dirPath)) {
- for (const auto& entry : std::filesystem::directory_iterator(dirPath)) {
- if (entry.is_regular_file()) {
- json file = {
- {"name", entry.path().filename().string()},
- {"path", entry.path().string()},
- {"size", std::filesystem::file_size(entry.path())},
- {"size_mb", std::filesystem::file_size(entry.path()) / (1024.0 * 1024.0)},
- {"last_modified", std::chrono::duration_cast<std::chrono::seconds>(
- std::filesystem::last_write_time(entry.path()).time_since_epoch()).count()}
- };
- contents.push_back(file);
- }
- }
- }
- } catch (const std::exception& e) {
- // Return empty array if directory access fails
- }
- return contents;
- }
- json Server::getLargestModel(const std::map<std::string, ModelManager::ModelInfo>& allModels) {
- json largest = json::object();
- size_t maxSize = 0;
- std::string largestName;
- for (const auto& pair : allModels) {
- if (pair.second.fileSize > maxSize) {
- maxSize = pair.second.fileSize;
- largestName = pair.second.name;
- }
- }
- if (!largestName.empty()) {
- largest = {
- {"name", largestName},
- {"size", maxSize},
- {"size_mb", maxSize / (1024.0 * 1024.0)},
- {"type", ModelManager::modelTypeToString(allModels.at(largestName).type)}
- };
- }
- return largest;
- }
- json Server::getSmallestModel(const std::map<std::string, ModelManager::ModelInfo>& allModels) {
- json smallest = json::object();
- size_t minSize = SIZE_MAX;
- std::string smallestName;
- for (const auto& pair : allModels) {
- if (pair.second.fileSize < minSize) {
- minSize = pair.second.fileSize;
- smallestName = pair.second.name;
- }
- }
- if (!smallestName.empty()) {
- smallest = {
- {"name", smallestName},
- {"size", minSize},
- {"size_mb", minSize / (1024.0 * 1024.0)},
- {"type", ModelManager::modelTypeToString(allModels.at(smallestName).type)}
- };
- }
- return smallest;
- }
- json Server::validateModelFile(const std::string& modelPath, const std::string& modelType) {
- json validation = {
- {"is_valid", false},
- {"errors", json::array()},
- {"warnings", json::array()},
- {"file_info", json::object()},
- {"compatibility", json::object()},
- {"recommendations", json::array()}
- };
- try {
- if (!std::filesystem::exists(modelPath)) {
- validation["errors"].push_back("File does not exist");
- return validation;
- }
- if (!std::filesystem::is_regular_file(modelPath)) {
- validation["errors"].push_back("Path is not a regular file");
- return validation;
- }
- // Check file extension
- std::string extension = std::filesystem::path(modelPath).extension().string();
- if (extension.empty()) {
- validation["errors"].push_back("Missing file extension");
- return validation;
- }
- // Remove dot and convert to lowercase
- if (extension[0] == '.') {
- extension = extension.substr(1);
- }
- std::transform(extension.begin(), extension.end(), extension.begin(), ::tolower);
- // Validate extension based on model type
- ModelType type = ModelManager::stringToModelType(modelType);
- bool validExtension = false;
- switch (type) {
- case ModelType::CHECKPOINT:
- validExtension = (extension == "safetensors" || extension == "ckpt" || extension == "gguf");
- break;
- case ModelType::LORA:
- validExtension = (extension == "safetensors" || extension == "ckpt");
- break;
- case ModelType::CONTROLNET:
- validExtension = (extension == "safetensors" || extension == "pth");
- break;
- case ModelType::VAE:
- validExtension = (extension == "safetensors" || extension == "pt" || extension == "ckpt" || extension == "gguf");
- break;
- case ModelType::EMBEDDING:
- validExtension = (extension == "safetensors" || extension == "pt");
- break;
- case ModelType::TAESD:
- validExtension = (extension == "safetensors" || extension == "pth" || extension == "gguf");
- break;
- case ModelType::ESRGAN:
- validExtension = (extension == "pth" || extension == "pt");
- break;
- default:
- break;
- }
- if (!validExtension) {
- validation["errors"].push_back("Invalid file extension for model type: " + extension);
- }
- // Check file size
- size_t fileSize = std::filesystem::file_size(modelPath);
- if (fileSize == 0) {
- validation["errors"].push_back("File is empty");
- } else if (fileSize > 8ULL * 1024 * 1024 * 1024) { // 8GB
- validation["warnings"].push_back("Very large file may cause performance issues");
- }
- // Build file info
- validation["file_info"] = {
- {"path", modelPath},
- {"size", fileSize},
- {"size_mb", fileSize / (1024.0 * 1024.0)},
- {"extension", extension},
- {"last_modified", std::chrono::duration_cast<std::chrono::seconds>(
- std::filesystem::last_write_time(modelPath).time_since_epoch()).count()}
- };
- // Check compatibility
- validation["compatibility"] = {
- {"extension_valid", validExtension},
- {"size_appropriate", fileSize <= 4ULL * 1024 * 1024 * 1024}, // 4GB
- {"recommended_format", "safetensors"}
- };
- // Add recommendations
- if (!validExtension) {
- validation["recommendations"].push_back("Convert to SafeTensors format for better security and performance");
- }
- if (fileSize > 2ULL * 1024 * 1024 * 1024) { // 2GB
- validation["recommendations"].push_back("Consider using a smaller model for better performance");
- }
- // If no errors found, mark as valid
- if (validation["errors"].empty()) {
- validation["is_valid"] = true;
- }
- } catch (const std::exception& e) {
- validation["errors"].push_back("Validation failed: " + std::string(e.what()));
- }
- return validation;
- }
- json Server::checkModelCompatibility(const ModelManager::ModelInfo& modelInfo, const std::string& systemInfo) {
- json compatibility = {
- {"is_compatible", true},
- {"compatibility_score", 100},
- {"issues", json::array()},
- {"warnings", json::array()},
- {"requirements", json::object()},
- {"recommendations", json::array()},
- {"system_info", json::object()}
- };
- // Check system compatibility
- if (systemInfo == "auto") {
- compatibility["system_info"] = {
- {"cpu_cores", std::thread::hardware_concurrency()}
- };
- }
- // Check model-specific compatibility issues
- if (modelInfo.type == ModelType::CHECKPOINT) {
- if (modelInfo.fileSize > 4ULL * 1024 * 1024 * 1024) { // 4GB
- compatibility["warnings"].push_back("Large checkpoint model may require significant memory");
- compatibility["compatibility_score"] = 80;
- }
- if (modelInfo.fileSize < 500 * 1024 * 1024) { // 500MB
- compatibility["warnings"].push_back("Small checkpoint model may have limited capabilities");
- compatibility["compatibility_score"] = 85;
- }
- } else if (modelInfo.type == ModelType::LORA) {
- if (modelInfo.fileSize > 500 * 1024 * 1024) { // 500MB
- compatibility["warnings"].push_back("Large LoRA may impact performance");
- compatibility["compatibility_score"] = 75;
- }
- }
- return compatibility;
- }
- json Server::calculateSpecificRequirements(const std::string& modelType, const std::string& resolution, const std::string& batchSize) {
- json specific = {
- {"memory_requirements", json::object()},
- {"performance_impact", json::object()},
- {"quality_expectations", json::object()}
- };
- // Parse resolution
- int width = 512, height = 512;
- try {
- size_t xPos = resolution.find('x');
- if (xPos != std::string::npos) {
- width = std::stoi(resolution.substr(0, xPos));
- height = std::stoi(resolution.substr(xPos + 1));
- }
- } catch (...) {
- // Use defaults if parsing fails
- }
- // Parse batch size
- int batch = 1;
- try {
- batch = std::stoi(batchSize);
- } catch (...) {
- // Use default if parsing fails
- }
- // Calculate memory requirements based on resolution and batch
- size_t pixels = width * height;
- size_t baseMemory = 1024 * 1024 * 1024; // 1GB base
- size_t resolutionMemory = (pixels * 4) / (512 * 512); // Scale based on 512x512
- size_t batchMemory = (batch - 1) * baseMemory * 0.5; // Additional memory for batch
- specific["memory_requirements"] = {
- {"base_memory_mb", baseMemory / (1024 * 1024)},
- {"resolution_memory_mb", resolutionMemory / (1024 * 1024)},
- {"batch_memory_mb", batchMemory / (1024 * 1024)},
- {"total_memory_mb", (baseMemory + resolutionMemory + batchMemory) / (1024 * 1024)}
- };
- // Calculate performance impact
- double performanceFactor = 1.0;
- if (pixels > 512 * 512) {
- performanceFactor = 1.5;
- }
- if (batch > 1) {
- performanceFactor *= 1.2;
- }
- specific["performance_impact"] = {
- {"resolution_factor", pixels > 512 * 512 ? 1.5 : 1.0},
- {"batch_factor", batch > 1 ? 1.2 : 1.0},
- {"overall_factor", performanceFactor}
- };
- return specific;
- }
- // Enhanced model management endpoint implementations
- void Server::handleModelInfo(const httplib::Request& req, httplib::Response& res) {
- std::string requestId = generateRequestId();
- try {
- if (!m_modelManager) {
- sendErrorResponse(res, "Model manager not available", 500, "MODEL_MANAGER_UNAVAILABLE", requestId);
- return;
- }
- // Extract model ID from URL path
- std::string modelId = req.matches[1].str();
- if (modelId.empty()) {
- sendErrorResponse(res, "Missing model ID", 400, "MISSING_MODEL_ID", requestId);
- return;
- }
- // Get model information
- auto modelInfo = m_modelManager->getModelInfo(modelId);
- if (modelInfo.name.empty()) {
- sendErrorResponse(res, "Model not found", 404, "MODEL_NOT_FOUND", requestId);
- return;
- }
- // Build comprehensive model information
- json response = {
- {"model", {
- {"name", modelInfo.name},
- {"path", modelInfo.path},
- {"type", ModelManager::modelTypeToString(modelInfo.type)},
- {"is_loaded", modelInfo.isLoaded},
- {"file_size", modelInfo.fileSize},
- {"file_size_mb", modelInfo.fileSize / (1024.0 * 1024.0)},
- {"description", modelInfo.description},
- {"metadata", modelInfo.metadata},
- {"capabilities", getModelCapabilities(modelInfo.type)},
- {"compatibility", getModelCompatibility(modelInfo)},
- {"requirements", getModelRequirements(modelInfo.type)},
- {"recommended_usage", getRecommendedUsage(modelInfo.type)},
- {"last_modified", std::chrono::duration_cast<std::chrono::seconds>(
- modelInfo.modifiedAt.time_since_epoch()).count()}
- }},
- {"request_id", requestId}
- };
- sendJsonResponse(res, response);
- } catch (const std::exception& e) {
- sendErrorResponse(res, std::string("Failed to get model info: ") + e.what(), 500, "MODEL_INFO_ERROR", requestId);
- }
- }
- void Server::handleLoadModelById(const httplib::Request& req, httplib::Response& res) {
- std::string requestId = generateRequestId();
- try {
- if (!m_modelManager) {
- sendErrorResponse(res, "Model manager not available", 500, "MODEL_MANAGER_UNAVAILABLE", requestId);
- return;
- }
- // Extract model ID from URL path (could be hash or name)
- std::string modelIdentifier = req.matches[1].str();
- if (modelIdentifier.empty()) {
- sendErrorResponse(res, "Missing model identifier", 400, "MISSING_MODEL_ID", requestId);
- return;
- }
- // Try to find by hash first (if it looks like a hash - 10+ hex chars)
- std::string modelId = modelIdentifier;
- if (modelIdentifier.length() >= 10 &&
- std::all_of(modelIdentifier.begin(), modelIdentifier.end(),
- [](char c) { return std::isxdigit(c); })) {
- std::string foundName = m_modelManager->findModelByHash(modelIdentifier);
- if (!foundName.empty()) {
- modelId = foundName;
- std::cout << "Resolved hash " << modelIdentifier << " to model: " << modelId << std::endl;
- }
- }
- // Parse optional parameters from request body
- json requestJson;
- if (!req.body.empty()) {
- try {
- requestJson = json::parse(req.body);
- } catch (const json::parse_error& e) {
- sendErrorResponse(res, std::string("Invalid JSON: ") + e.what(), 400, "JSON_PARSE_ERROR", requestId);
- return;
- }
- }
- // Unload previous model if one is loaded
- std::string previousModel;
- {
- std::lock_guard<std::mutex> lock(m_currentModelMutex);
- previousModel = m_currentlyLoadedModel;
- }
- if (!previousModel.empty() && previousModel != modelId) {
- std::cout << "Unloading previous model: " << previousModel << std::endl;
- m_modelManager->unloadModel(previousModel);
- }
- // Load model
- bool success = m_modelManager->loadModel(modelId);
- if (success) {
- // Update currently loaded model
- {
- std::lock_guard<std::mutex> lock(m_currentModelMutex);
- m_currentlyLoadedModel = modelId;
- }
- auto modelInfo = m_modelManager->getModelInfo(modelId);
- json response = {
- {"status", "success"},
- {"model", {
- {"name", modelInfo.name},
- {"path", modelInfo.path},
- {"type", ModelManager::modelTypeToString(modelInfo.type)},
- {"is_loaded", modelInfo.isLoaded}
- }},
- {"request_id", requestId}
- };
- sendJsonResponse(res, response);
- } else {
- sendErrorResponse(res, "Failed to load model", 400, "MODEL_LOAD_FAILED", requestId);
- }
- } catch (const std::exception& e) {
- sendErrorResponse(res, std::string("Model load failed: ") + e.what(), 500, "MODEL_LOAD_ERROR", requestId);
- }
- }
- void Server::handleUnloadModelById(const httplib::Request& req, httplib::Response& res) {
- std::string requestId = generateRequestId();
- try {
- if (!m_modelManager) {
- sendErrorResponse(res, "Model manager not available", 500, "MODEL_MANAGER_UNAVAILABLE", requestId);
- return;
- }
- // Extract model ID from URL path
- std::string modelId = req.matches[1].str();
- if (modelId.empty()) {
- sendErrorResponse(res, "Missing model ID", 400, "MISSING_MODEL_ID", requestId);
- return;
- }
- // Unload model
- bool success = m_modelManager->unloadModel(modelId);
- if (success) {
- // Clear currently loaded model if it matches
- {
- std::lock_guard<std::mutex> lock(m_currentModelMutex);
- if (m_currentlyLoadedModel == modelId) {
- m_currentlyLoadedModel = "";
- }
- }
- json response = {
- {"status", "success"},
- {"model", {
- {"name", modelId},
- {"is_loaded", false}
- }},
- {"request_id", requestId}
- };
- sendJsonResponse(res, response);
- } else {
- sendErrorResponse(res, "Failed to unload model or model not found", 404, "MODEL_UNLOAD_FAILED", requestId);
- }
- } catch (const std::exception& e) {
- sendErrorResponse(res, std::string("Model unload failed: ") + e.what(), 500, "MODEL_UNLOAD_ERROR", requestId);
- }
- }
- void Server::handleModelTypes(const httplib::Request& req, httplib::Response& res) {
- std::string requestId = generateRequestId();
- try {
- json types = {
- {"model_types", {
- {
- {"type", "checkpoint"},
- {"description", "Main stable diffusion model files for text-to-image, image-to-image, and inpainting"},
- {"extensions", {"safetensors", "ckpt", "gguf"}},
- {"capabilities", {"text2img", "img2img", "inpainting", "controlnet", "lora", "vae"}},
- {"recommended_for", "General purpose image generation"}
- },
- {
- {"type", "lora"},
- {"description", "LoRA adapter models for style transfer and character customization"},
- {"extensions", {"safetensors", "ckpt"}},
- {"capabilities", {"style_transfer", "character_customization"}},
- {"requires", {"checkpoint"}},
- {"recommended_for", "Style modification and character-specific generation"}
- },
- {
- {"type", "controlnet"},
- {"description", "ControlNet models for precise control over output composition"},
- {"extensions", {"safetensors", "pth"}},
- {"capabilities", {"precise_control", "composition_control"}},
- {"requires", {"checkpoint"}},
- {"recommended_for", "Precise control over image generation"}
- },
- {
- {"type", "vae"},
- {"description", "VAE models for improved encoding and decoding quality"},
- {"extensions", {"safetensors", "pt", "ckpt", "gguf"}},
- {"capabilities", {"encoding", "decoding", "quality_improvement"}},
- {"requires", {"checkpoint"}},
- {"recommended_for", "Improved image quality and encoding"}
- },
- {
- {"type", "embedding"},
- {"description", "Text embeddings for concept control and style words"},
- {"extensions", {"safetensors", "pt"}},
- {"capabilities", {"concept_control", "style_words"}},
- {"requires", {"checkpoint"}},
- {"recommended_for", "Concept control and specific styles"}
- },
- {
- {"type", "taesd"},
- {"description", "TAESD models for real-time decoding"},
- {"extensions", {"safetensors", "pth", "gguf"}},
- {"capabilities", {"real_time_decoding", "fast_preview"}},
- {"requires", {"checkpoint"}},
- {"recommended_for", "Real-time applications and fast previews"}
- },
- {
- {"type", "esrgan"},
- {"description", "ESRGAN models for image upscaling and enhancement"},
- {"extensions", {"pth", "pt"}},
- {"capabilities", {"upscaling", "enhancement", "quality_improvement"}},
- {"recommended_for", "Image upscaling and quality enhancement"}
- }
- }},
- {"request_id", requestId}
- };
- sendJsonResponse(res, types);
- } catch (const std::exception& e) {
- sendErrorResponse(res, std::string("Failed to get model types: ") + e.what(), 500, "MODEL_TYPES_ERROR", requestId);
- }
- }
- void Server::handleModelDirectories(const httplib::Request& req, httplib::Response& res) {
- std::string requestId = generateRequestId();
- try {
- if (!m_modelManager) {
- sendErrorResponse(res, "Model manager not available", 500, "MODEL_MANAGER_UNAVAILABLE", requestId);
- return;
- }
- std::string modelsDir = m_modelManager->getModelsDirectory();
- json directories = json::array();
- // Define expected model directories
- std::vector<std::string> modelDirs = {
- "stable-diffusion", "checkpoints", "lora", "controlnet",
- "vae", "taesd", "esrgan", "embeddings"
- };
- for (const auto& dirName : modelDirs) {
- std::string dirPath = modelsDir + "/" + dirName;
- std::string type = getModelTypeFromDirectoryName(dirName);
- std::string description = getDirectoryDescription(dirName);
- json dirInfo = {
- {"name", dirName},
- {"path", dirPath},
- {"type", type},
- {"description", description},
- {"exists", std::filesystem::exists(dirPath) && std::filesystem::is_directory(dirPath)},
- {"contents", getDirectoryContents(dirPath)}
- };
- directories.push_back(dirInfo);
- }
- json response = {
- {"models_directory", modelsDir},
- {"directories", directories},
- {"request_id", requestId}
- };
- sendJsonResponse(res, response);
- } catch (const std::exception& e) {
- sendErrorResponse(res, std::string("Failed to get model directories: ") + e.what(), 500, "MODEL_DIRECTORIES_ERROR", requestId);
- }
- }
- void Server::handleRefreshModels(const httplib::Request& req, httplib::Response& res) {
- std::string requestId = generateRequestId();
- try {
- if (!m_modelManager) {
- sendErrorResponse(res, "Model manager not available", 500, "MODEL_MANAGER_UNAVAILABLE", requestId);
- return;
- }
- // Force refresh of model cache
- bool success = m_modelManager->scanModelsDirectory();
- if (success) {
- json response = {
- {"status", "success"},
- {"message", "Model cache refreshed successfully"},
- {"models_found", m_modelManager->getAvailableModelsCount()},
- {"models_loaded", m_modelManager->getLoadedModelsCount()},
- {"models_directory", m_modelManager->getModelsDirectory()},
- {"request_id", requestId}
- };
- sendJsonResponse(res, response);
- } else {
- sendErrorResponse(res, "Failed to refresh model cache", 500, "MODEL_REFRESH_FAILED", requestId);
- }
- } catch (const std::exception& e) {
- sendErrorResponse(res, std::string("Model refresh failed: ") + e.what(), 500, "MODEL_REFRESH_ERROR", requestId);
- }
- }
- void Server::handleHashModels(const httplib::Request& req, httplib::Response& res) {
- std::string requestId = generateRequestId();
- try {
- if (!m_generationQueue || !m_modelManager) {
- sendErrorResponse(res, "Services not available", 500, "SERVICE_UNAVAILABLE", requestId);
- return;
- }
- // Parse request body
- json requestJson;
- if (!req.body.empty()) {
- requestJson = json::parse(req.body);
- }
- HashRequest hashReq;
- hashReq.id = requestId;
- hashReq.forceRehash = requestJson.value("force_rehash", false);
- if (requestJson.contains("models") && requestJson["models"].is_array()) {
- for (const auto& model : requestJson["models"]) {
- hashReq.modelNames.push_back(model.get<std::string>());
- }
- }
- // Enqueue hash request
- auto future = m_generationQueue->enqueueHashRequest(hashReq);
- json response = {
- {"request_id", requestId},
- {"status", "queued"},
- {"message", "Hash job queued successfully"},
- {"models_to_hash", hashReq.modelNames.empty() ? "all_unhashed" : std::to_string(hashReq.modelNames.size())}
- };
- sendJsonResponse(res, response, 202);
- } catch (const json::parse_error& e) {
- sendErrorResponse(res, std::string("Invalid JSON: ") + e.what(), 400, "JSON_PARSE_ERROR", requestId);
- } catch (const std::exception& e) {
- sendErrorResponse(res, std::string("Hash request failed: ") + e.what(), 500, "INTERNAL_ERROR", requestId);
- }
- }
- void Server::handleConvertModel(const httplib::Request& req, httplib::Response& res) {
- std::string requestId = generateRequestId();
- try {
- if (!m_generationQueue || !m_modelManager) {
- sendErrorResponse(res, "Services not available", 500, "SERVICE_UNAVAILABLE", requestId);
- return;
- }
- // Parse request body
- json requestJson;
- try {
- requestJson = json::parse(req.body);
- } catch (const json::parse_error& e) {
- sendErrorResponse(res, std::string("Invalid JSON: ") + e.what(), 400, "JSON_PARSE_ERROR", requestId);
- return;
- }
- // Validate required fields
- if (!requestJson.contains("model_name")) {
- sendErrorResponse(res, "Missing required field: model_name", 400, "MISSING_FIELD", requestId);
- return;
- }
- if (!requestJson.contains("quantization_type")) {
- sendErrorResponse(res, "Missing required field: quantization_type", 400, "MISSING_FIELD", requestId);
- return;
- }
- std::string modelName = requestJson["model_name"].get<std::string>();
- std::string quantizationType = requestJson["quantization_type"].get<std::string>();
- // Validate quantization type
- const std::vector<std::string> validTypes = {"f32", "f16", "q4_0", "q4_1", "q5_0", "q5_1", "q8_0", "q2_K", "q3_K", "q4_K"};
- if (std::find(validTypes.begin(), validTypes.end(), quantizationType) == validTypes.end()) {
- sendErrorResponse(res, "Invalid quantization_type. Valid types: f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0, q2_K, q3_K, q4_K",
- 400, "INVALID_QUANTIZATION_TYPE", requestId);
- return;
- }
- // Get model info to find the full path
- auto modelInfo = m_modelManager->getModelInfo(modelName);
- if (modelInfo.name.empty()) {
- sendErrorResponse(res, "Model not found: " + modelName, 404, "MODEL_NOT_FOUND", requestId);
- return;
- }
- // Check if model is already GGUF
- if (modelInfo.fullPath.find(".gguf") != std::string::npos) {
- sendErrorResponse(res, "Model is already in GGUF format. Cannot convert GGUF to GGUF.",
- 400, "ALREADY_GGUF", requestId);
- return;
- }
- // Build output path
- std::string outputPath = requestJson.value("output_path", "");
- if (outputPath.empty()) {
- // Generate default output path: model_name_quantization.gguf
- namespace fs = std::filesystem;
- fs::path inputPath(modelInfo.fullPath);
- std::string baseName = inputPath.stem().string();
- std::string outputDir = inputPath.parent_path().string();
- outputPath = outputDir + "/" + baseName + "_" + quantizationType + ".gguf";
- }
- // Create conversion request
- ConversionRequest convReq;
- convReq.id = requestId;
- convReq.modelName = modelName;
- convReq.modelPath = modelInfo.fullPath;
- convReq.outputPath = outputPath;
- convReq.quantizationType = quantizationType;
- // Enqueue conversion request
- auto future = m_generationQueue->enqueueConversionRequest(convReq);
- json response = {
- {"request_id", requestId},
- {"status", "queued"},
- {"message", "Model conversion queued successfully"},
- {"model_name", modelName},
- {"input_path", modelInfo.fullPath},
- {"output_path", outputPath},
- {"quantization_type", quantizationType}
- };
- sendJsonResponse(res, response, 202);
- } catch (const std::exception& e) {
- sendErrorResponse(res, std::string("Conversion request failed: ") + e.what(), 500, "INTERNAL_ERROR", requestId);
- }
- }
- void Server::handleModelStats(const httplib::Request& req, httplib::Response& res) {
- std::string requestId = generateRequestId();
- try {
- if (!m_modelManager) {
- sendErrorResponse(res, "Model manager not available", 500, "MODEL_MANAGER_UNAVAILABLE", requestId);
- return;
- }
- auto allModels = m_modelManager->getAllModels();
- json response = {
- {"statistics", {
- {"total_models", allModels.size()},
- {"loaded_models", m_modelManager->getLoadedModelsCount()},
- {"available_models", m_modelManager->getAvailableModelsCount()},
- {"model_types", getModelTypeStatistics()},
- {"largest_model", getLargestModel(allModels)},
- {"smallest_model", getSmallestModel(allModels)}
- }},
- {"request_id", requestId}
- };
- sendJsonResponse(res, response);
- } catch (const std::exception& e) {
- sendErrorResponse(res, std::string("Failed to get model stats: ") + e.what(), 500, "MODEL_STATS_ERROR", requestId);
- }
- }
- void Server::handleBatchModels(const httplib::Request& req, httplib::Response& res) {
- std::string requestId = generateRequestId();
- try {
- if (!m_modelManager) {
- sendErrorResponse(res, "Model manager not available", 500, "MODEL_MANAGER_UNAVAILABLE", requestId);
- return;
- }
- // Parse JSON request body
- json requestJson = json::parse(req.body);
- if (!requestJson.contains("operation") || !requestJson["operation"].is_string()) {
- sendErrorResponse(res, "Missing or invalid 'operation' field", 400, "INVALID_OPERATION", requestId);
- return;
- }
- if (!requestJson.contains("models") || !requestJson["models"].is_array()) {
- sendErrorResponse(res, "Missing or invalid 'models' field", 400, "INVALID_MODELS", requestId);
- return;
- }
- std::string operation = requestJson["operation"];
- json models = requestJson["models"];
- json results = json::array();
- for (const auto& model : models) {
- if (!model.is_string()) {
- results.push_back({
- {"model", model},
- {"success", false},
- {"error", "Invalid model name"}
- });
- continue;
- }
- std::string modelName = model;
- bool success = false;
- std::string error = "";
- if (operation == "load") {
- success = m_modelManager->loadModel(modelName);
- if (!success) error = "Failed to load model";
- } else if (operation == "unload") {
- success = m_modelManager->unloadModel(modelName);
- if (!success) error = "Failed to unload model";
- } else {
- error = "Unsupported operation";
- }
- results.push_back({
- {"model", modelName},
- {"success", success},
- {"error", error.empty() ? json(nullptr) : json(error)}
- });
- }
- json response = {
- {"operation", operation},
- {"results", results},
- {"successful_count", std::count_if(results.begin(), results.end(),
- [](const json& result) { return result["success"].get<bool>(); })},
- {"failed_count", std::count_if(results.begin(), results.end(),
- [](const json& result) { return !result["success"].get<bool>(); })},
- {"request_id", requestId}
- };
- sendJsonResponse(res, response);
- } catch (const json::parse_error& e) {
- sendErrorResponse(res, std::string("Invalid JSON: ") + e.what(), 400, "JSON_PARSE_ERROR", requestId);
- } catch (const std::exception& e) {
- sendErrorResponse(res, std::string("Batch operation failed: ") + e.what(), 500, "BATCH_OPERATION_ERROR", requestId);
- }
- }
- void Server::handleValidateModel(const httplib::Request& req, httplib::Response& res) {
- std::string requestId = generateRequestId();
- try {
- // Parse JSON request body
- json requestJson = json::parse(req.body);
- if (!requestJson.contains("model_path") || !requestJson["model_path"].is_string()) {
- sendErrorResponse(res, "Missing or invalid 'model_path' field", 400, "INVALID_MODEL_PATH", requestId);
- return;
- }
- std::string modelPath = requestJson["model_path"];
- std::string modelType = requestJson.value("model_type", "checkpoint");
- // Validate model file
- json validation = validateModelFile(modelPath, modelType);
- json response = {
- {"validation", validation},
- {"request_id", requestId}
- };
- sendJsonResponse(res, response);
- } catch (const json::parse_error& e) {
- sendErrorResponse(res, std::string("Invalid JSON: ") + e.what(), 400, "JSON_PARSE_ERROR", requestId);
- } catch (const std::exception& e) {
- sendErrorResponse(res, std::string("Model validation failed: ") + e.what(), 500, "MODEL_VALIDATION_ERROR", requestId);
- }
- }
- void Server::handleCheckCompatibility(const httplib::Request& req, httplib::Response& res) {
- std::string requestId = generateRequestId();
- try {
- if (!m_modelManager) {
- sendErrorResponse(res, "Model manager not available", 500, "MODEL_MANAGER_UNAVAILABLE", requestId);
- return;
- }
- // Parse JSON request body
- json requestJson = json::parse(req.body);
- if (!requestJson.contains("model_name") || !requestJson["model_name"].is_string()) {
- sendErrorResponse(res, "Missing or invalid 'model_name' field", 400, "INVALID_MODEL_NAME", requestId);
- return;
- }
- std::string modelName = requestJson["model_name"];
- std::string systemInfo = requestJson.value("system_info", "auto");
- // Get model information
- auto modelInfo = m_modelManager->getModelInfo(modelName);
- if (modelInfo.name.empty()) {
- sendErrorResponse(res, "Model not found", 404, "MODEL_NOT_FOUND", requestId);
- return;
- }
- // Check compatibility
- json compatibility = checkModelCompatibility(modelInfo, systemInfo);
- json response = {
- {"model", modelName},
- {"compatibility", compatibility},
- {"request_id", requestId}
- };
- sendJsonResponse(res, response);
- } catch (const json::parse_error& e) {
- sendErrorResponse(res, std::string("Invalid JSON: ") + e.what(), 400, "JSON_PARSE_ERROR", requestId);
- } catch (const std::exception& e) {
- sendErrorResponse(res, std::string("Compatibility check failed: ") + e.what(), 500, "COMPATIBILITY_CHECK_ERROR", requestId);
- }
- }
- void Server::handleModelRequirements(const httplib::Request& req, httplib::Response& res) {
- std::string requestId = generateRequestId();
- try {
- // Parse JSON request body
- json requestJson = json::parse(req.body);
- std::string modelType = requestJson.value("model_type", "checkpoint");
- std::string resolution = requestJson.value("resolution", "512x512");
- std::string batchSize = requestJson.value("batch_size", "1");
- // Calculate specific requirements
- json requirements = calculateSpecificRequirements(modelType, resolution, batchSize);
- // Get general requirements for model type
- ModelType type = ModelManager::stringToModelType(modelType);
- json generalRequirements = getModelRequirements(type);
- json response = {
- {"model_type", modelType},
- {"configuration", {
- {"resolution", resolution},
- {"batch_size", batchSize}
- }},
- {"specific_requirements", requirements},
- {"general_requirements", generalRequirements},
- {"request_id", requestId}
- };
- sendJsonResponse(res, response);
- } catch (const json::parse_error& e) {
- sendErrorResponse(res, std::string("Invalid JSON: ") + e.what(), 400, "JSON_PARSE_ERROR", requestId);
- } catch (const std::exception& e) {
- sendErrorResponse(res, std::string("Requirements calculation failed: ") + e.what(), 500, "REQUIREMENTS_ERROR", requestId);
- }
- }
- void Server::serverThreadFunction(const std::string& host, int port) {
- try {
- std::cout << "Server thread starting, attempting to bind to " << host << ":" << port << std::endl;
- // Check if port is available before attempting to bind
- std::cout << "Checking if port " << port << " is available..." << std::endl;
- // Try to create a test socket to check if port is in use
- int test_socket = socket(AF_INET, SOCK_STREAM, 0);
- if (test_socket >= 0) {
- // Set SO_REUSEADDR to avoid TIME_WAIT issues
- int opt = 1;
- setsockopt(test_socket, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt));
- struct sockaddr_in addr;
- addr.sin_family = AF_INET;
- addr.sin_port = htons(port);
- addr.sin_addr.s_addr = INADDR_ANY;
- // Try to bind to the port
- if (bind(test_socket, (struct sockaddr*)&addr, sizeof(addr)) < 0) {
- close(test_socket);
- std::cerr << "ERROR: Port " << port << " is already in use! Cannot start server." << std::endl;
- std::cerr << "Please stop the existing instance or use a different port." << std::endl;
- m_isRunning.store(false);
- m_startupFailed.store(true);
- return;
- }
- close(test_socket);
- }
- std::cout << "Port " << port << " is available, proceeding with server startup..." << std::endl;
- std::cout << "Calling listen()..." << std::endl;
- // Set up a flag to track if listen started successfully
- std::atomic<bool> listenStarted{false};
- // We need to set m_isRunning after successful bind but before blocking
- // cpp-httplib doesn't provide a callback, so we set it optimistically
- // and clear it if listen() returns false
- m_isRunning.store(true);
- bool listenResult = m_httpServer->listen(host.c_str(), port);
- std::cout << "listen() returned: " << (listenResult ? "true" : "false") << std::endl;
- // If we reach here, server has stopped (either normally or due to error)
- m_isRunning.store(false);
- if (!listenResult) {
- std::cerr << "Server listen failed! This usually means port is in use or permission denied." << std::endl;
- }
- } catch (const std::exception& e) {
- std::cerr << "Exception in server thread: " << e.what() << std::endl;
- m_isRunning.store(false);
- }
- }
|