| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
27712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642164316441645164616471648164916501651165216531654165516561657165816591660166116621663166416651666166716681669167016711672167316741675167616771678167916801681168216831684168516861687168816891690169116921693169416951696169716981699170017011702170317041705170617071708170917101711171217131714171517161717171817191720172117221723172417251726172717281729173017311732173317341735173617371738173917401741174217431744174517461747174817491750175117521753175417551756175717581759176017611762176317641765176617671768176917701771177217731774177517761
77717781779178017811782178317841785178617871788178917901791179217931794179517961797179817991800180118021803180418051806180718081809181018111812181318141815181618171818181918201821182218231824182518261827182818291830183118321833183418351836183718381839184018411842184318441845184618471848184918501851185218531854185518561857185818591860186118621863186418651866186718681869187018711872187318741875187618771878187918801881188218831884188518861887188818891890189118921893189418951896189718981899190019011902190319041905190619071908190919101911191219131914191519161917191819191920192119221923192419251926192719281929193019311932193319341935193619371938193919401941194219431944194519461947194819491950195119521953195419551956195719581959196019611962196319641965196619671968196919701971197219731974197519761977197819791980198119821983198419851986198719881989199019911992199319941995199619971998199920002001200220032004200520062007200820092010201120122013201420152016201720182019202020212022202320242025202620272028202920302031203220332034203520362037203820392040204120422043204420452046204720482049205020512052205320542055205620572058205920602061206220632064206520662067206820692070207120722073207420752076207720782079208020812082208320842085208620872088208920902091209220932094209520962097209820992100210121022103210421052106210721082109211021112112211321142115211621172118211921202121212221232124212521262127212821292130213121322133213421352136213721382139214021412142214321442145214621472148214921502151215221532154215521562157215821592160216121622163216421652166216721682169217021712172217321742175217621772178217921802181218221832184218521862187218821892190219121922193219421952196219721982199220022012202220322042205220622072208220922102211221222132214221522162217221822192220222122222223222422252226222722282229223022312232223322342235223622372238223922402241224222432244224522462247224822492250225122522253225422552256225722582259226022612262226322642265226622672268226922702271227222732274227522762
27722782279228022812282228322842285228622872288228922902291229222932294229522962297229822992300230123022303230423052306230723082309231023112312231323142315231623172318231923202321232223232324232523262327232823292330233123322333233423352336233723382339234023412342234323442345234623472348234923502351235223532354235523562357235823592360236123622363236423652366236723682369237023712372237323742375237623772378237923802381238223832384238523862387238823892390239123922393239423952396239723982399240024012402240324042405240624072408240924102411241224132414241524162417241824192420242124222423242424252426242724282429243024312432243324342435243624372438243924402441244224432444244524462447244824492450245124522453245424552456245724582459246024612462246324642465246624672468246924702471247224732474247524762477247824792480248124822483248424852486248724882489249024912492249324942495249624972498249925002501250225032504250525062507250825092510251125122513251425152516251725182519252025212522252325242525252625272528252925302531253225332534253525362537253825392540254125422543254425452546254725482549255025512552255325542555255625572558255925602561256225632564256525662567256825692570257125722573257425752576257725782579258025812582258325842585258625872588258925902591259225932594259525962597259825992600260126022603260426052606260726082609261026112612261326142615261626172618261926202621262226232624262526262627262826292630263126322633263426352636263726382639264026412642264326442645264626472648264926502651265226532654265526562657265826592660266126622663266426652666266726682669267026712672267326742675267626772678267926802681268226832684268526862687268826892690269126922693269426952696269726982699270027012702270327042705270627072708270927102711271227132714271527162717271827192720272127222723272427252726272727282729273027312732273327342735273627372738273927402741274227432744274527462747274827492750275127522753275427552756275727582759276027612762276327642765276627672768276927702771277227732774277527762
77727782779278027812782278327842785278627872788278927902791279227932794279527962797279827992800280128022803280428052806280728082809281028112812281328142815281628172818281928202821282228232824282528262827282828292830283128322833283428352836283728382839284028412842284328442845284628472848284928502851285228532854285528562857285828592860286128622863286428652866286728682869287028712872287328742875287628772878287928802881288228832884288528862887288828892890289128922893289428952896289728982899290029012902290329042905290629072908290929102911291229132914291529162917291829192920292129222923292429252926292729282929293029312932293329342935293629372938293929402941294229432944294529462947294829492950295129522953295429552956295729582959296029612962296329642965296629672968296929702971297229732974297529762977297829792980298129822983298429852986298729882989299029912992299329942995299629972998299930003001300230033004300530063007300830093010301130123013301430153016301730183019302030213022302330243025302630273028302930303031303230333034303530363037303830393040304130423043304430453046304730483049305030513052305330543055305630573058305930603061306230633064306530663067306830693070307130723073307430753076307730783079308030813082308330843085308630873088308930903091309230933094309530963097309830993100310131023103310431053106310731083109311031113112311331143115311631173118311931203121312231233124312531263127312831293130313131323133313431353136313731383139314031413142314331443145314631473148314931503151315231533154315531563157315831593160316131623163316431653166316731683169317031713172317331743175317631773178317931803181318231833184318531863187318831893190319131923193319431953196319731983199320032013202320332043205320632073208320932103211321232133214321532163217321832193220322132223223322432253226322732283229323032313232323332343235323632373238323932403241324232433244324532463247324832493250325132523253325432553256325732583259326032613262326332643265326632673268326932703271327232733274327532763
27732783279328032813282328332843285328632873288328932903291329232933294329532963297329832993300330133023303330433053306330733083309331033113312331333143315331633173318331933203321332233233324332533263327332833293330333133323333333433353336333733383339334033413342334333443345334633473348334933503351335233533354335533563357335833593360336133623363336433653366336733683369337033713372337333743375337633773378337933803381338233833384338533863387338833893390339133923393339433953396339733983399340034013402340334043405340634073408340934103411341234133414341534163417341834193420342134223423342434253426342734283429343034313432343334343435343634373438343934403441344234433444344534463447344834493450345134523453345434553456345734583459346034613462346334643465346634673468346934703471347234733474347534763477347834793480348134823483348434853486348734883489349034913492349334943495349634973498349935003501350235033504350535063507350835093510351135123513351435153516351735183519352035213522352335243525352635273528352935303531353235333534353535363537353835393540354135423543354435453546354735483549355035513552355335543555355635573558355935603561356235633564356535663567356835693570357135723573357435753576357735783579358035813582358335843585358635873588358935903591359235933594359535963597359835993600360136023603360436053606360736083609361036113612361336143615361636173618361936203621362236233624362536263627362836293630363136323633363436353636363736383639364036413642364336443645364636473648364936503651365236533654365536563657365836593660366136623663366436653666366736683669367036713672367336743675367636773678367936803681368236833684368536863687368836893690369136923693369436953696369736983699370037013702370337043705370637073708370937103711371237133714371537163717371837193720372137223723372437253726372737283729373037313732373337343735373637373738373937403741374237433744374537463747374837493750375137523753375437553756375737583759376037613762376337643765376637673768376937703771377237733774377537763
77737783779378037813782378337843785378637873788378937903791379237933794379537963797379837993800380138023803380438053806380738083809381038113812381338143815381638173818381938203821382238233824382538263827382838293830383138323833383438353836383738383839384038413842384338443845384638473848384938503851385238533854385538563857385838593860386138623863386438653866386738683869387038713872387338743875387638773878387938803881388238833884388538863887388838893890389138923893389438953896389738983899390039013902390339043905390639073908390939103911391239133914391539163917391839193920392139223923392439253926392739283929393039313932393339343935393639373938393939403941394239433944394539463947394839493950395139523953395439553956395739583959396039613962396339643965396639673968396939703971397239733974397539763977397839793980398139823983398439853986398739883989399039913992399339943995399639973998399940004001400240034004400540064007400840094010401140124013401440154016401740184019402040214022402340244025402640274028402940304031403240334034403540364037403840394040404140424043404440454046404740484049405040514052405340544055405640574058405940604061406240634064406540664067406840694070407140724073407440754076407740784079408040814082408340844085408640874088408940904091409240934094409540964097409840994100410141024103410441054106410741084109411041114112411341144115411641174118411941204121412241234124412541264127412841294130413141324133413441354136413741384139414041414142414341444145414641474148414941504151415241534154415541564157415841594160416141624163416441654166416741684169417041714172417341744175417641774178417941804181418241834184418541864187418841894190419141924193419441954196419741984199420042014202420342044205420642074208420942104211421242134214421542164217421842194220422142224223422442254226422742284229423042314232423342344235423642374238423942404241424242434244424542464247424842494250425142524253425442554256425742584259426042614262426342644265426642674268426942704271427242734274427542764
27742784279428042814282428342844285428642874288428942904291429242934294429542964297429842994300430143024303430443054306430743084309431043114312431343144315431643174318431943204321432243234324432543264327432843294330433143324333433443354336433743384339434043414342434343444345434643474348434943504351435243534354435543564357435843594360436143624363436443654366436743684369437043714372437343744375437643774378437943804381438243834384438543864387438843894390439143924393439443954396439743984399440044014402440344044405440644074408440944104411441244134414441544164417441844194420442144224423442444254426442744284429443044314432443344344435443644374438443944404441444244434444444544464447444844494450445144524453445444554456445744584459446044614462446344644465446644674468446944704471447244734474447544764477447844794480448144824483448444854486448744884489449044914492449344944495449644974498449945004501450245034504450545064507450845094510451145124513451445154516451745184519452045214522452345244525452645274528452945304531453245334534453545364537453845394540454145424543454445454546454745484549455045514552455345544555455645574558455945604561456245634564456545664567456845694570457145724573457445754576457745784579458045814582458345844585458645874588458945904591459245934594459545964597459845994600460146024603460446054606460746084609461046114612461346144615461646174618461946204621462246234624462546264627462846294630463146324633463446354636463746384639464046414642464346444645464646474648464946504651465246534654465546564657465846594660466146624663466446654666466746684669467046714672467346744675467646774678467946804681468246834684468546864687468846894690469146924693469446954696469746984699470047014702470347044705470647074708470947104711471247134714471547164717471847194720472147224723472447254726472747284729473047314732473347344735473647374738473947404741474247434744474547464747474847494750475147524753475447554756475747584759476047614762476347644765476647674768476947704771477247734774477547764
77747784779478047814782478347844785478647874788478947904791479247934794479547964797479847994800480148024803480448054806480748084809481048114812481348144815481648174818481948204821482248234824482548264827482848294830483148324833483448354836483748384839484048414842484348444845484648474848484948504851485248534854485548564857485848594860486148624863486448654866486748684869487048714872487348744875487648774878487948804881488248834884488548864887488848894890489148924893489448954896489748984899490049014902490349044905490649074908490949104911491249134914491549164917491849194920492149224923492449254926492749284929493049314932493349344935493649374938493949404941494249434944494549464947494849494950495149524953495449554956495749584959496049614962496349644965496649674968496949704971497249734974497549764977497849794980498149824983498449854986498749884989499049914992499349944995499649974998499950005001500250035004500550065007500850095010501150125013501450155016501750185019502050215022502350245025502650275028502950305031503250335034503550365037503850395040504150425043504450455046504750485049505050515052505350545055505650575058505950605061506250635064506550665067506850695070507150725073507450755076507750785079508050815082508350845085508650875088508950905091509250935094509550965097509850995100510151025103510451055106510751085109511051115112511351145115511651175118511951205121512251235124512551265127512851295130513151325133513451355136513751385139514051415142514351445145514651475148514951505151515251535154515551565157515851595160516151625163516451655166516751685169517051715172517351745175517651775178517951805181518251835184518551865187518851895190519151925193519451955196519751985199520052015202520352045205520652075208520952105211521252135214521552165217521852195220522152225223522452255226522752285229523052315232523352345235523652375238523952405241524252435244524552465247524852495250525152525253525452555256525752585259526052615262526352645265526652675268526952705271527252735274527552765
2775278527952805281528252835284528552865287 |
- #include "server.h"
- #include "model_manager.h"
- #include "generation_queue.h"
- #include "utils.h"
- #include "auth_middleware.h"
- #include "user_manager.h"
- #include "version.h"
- #include <httplib.h>
- #include <nlohmann/json.hpp>
- #include <iostream>
- #include <sstream>
- #include <fstream>
- #include <chrono>
- #include <random>
- #include <algorithm>
- #include <thread>
- #include <filesystem>
- // Include stb_image for loading images (implementation is in generation_queue.cpp)
- #include "../stable-diffusion.cpp-src/thirdparty/stb_image.h"
- #include <sys/socket.h>
- #include <netinet/in.h>
- #include <unistd.h>
- #include <arpa/inet.h>
- // Constructs a server around the given model manager and generation queue.
- // The constructor only records its collaborators and settings and allocates
- // the httplib::Server instance; nothing is registered or started here.
- //
- // modelManager / generationQueue: stored as raw pointers, not owned here.
- // outputDir / uiDir: copied into the matching members.
- // config: copied into m_config; config.port seeds m_port.
- Server::Server(ModelManager* modelManager,
-                GenerationQueue* generationQueue,
-                const std::string& outputDir,
-                const std::string& uiDir,
-                const ServerConfig& config) :
-     m_modelManager(modelManager),
-     m_generationQueue(generationQueue),
-     m_isRunning(false),
-     m_startupFailed(false),
-     m_port(config.port),
-     m_outputDir(outputDir),
-     m_uiDir(uiDir),
-     m_userManager(nullptr),
-     m_authMiddleware(nullptr),
-     m_config(config)
- {
-     // Create the HTTP backend up front so later calls can use it directly.
-     m_httpServer = std::make_unique<httplib::Server>();
- }
- // Shuts the server down when the object is destroyed by delegating to stop().
- Server::~Server()
- {
-     this->stop();
- }
- // Starts the HTTP server on a background thread and blocks until the outcome
- // of the startup is known or roughly ten seconds have passed.
- //
- // host: address to bind to; must not be empty.
- // port: TCP port to bind to; must lie in [1, 65535].
- //
- // Returns true when m_isRunning is seen set both before and after a short
- // settle delay, or is set once the polling window has expired. Returns false
- // when the server is already running, the arguments are invalid,
- // m_startupFailed is raised, or the server never reports running.
- bool Server::start(const std::string& host, int port) {
-     if (m_isRunning.load()) {
-         return false;
-     }
-     // Validate before touching any member so a rejected call leaves the
-     // previously stored host/port untouched.
-     const bool portInRange = (port >= 1 && port <= 65535);
-     if (host.empty() || !portInRange) {
-         return false;
-     }
-     m_host = host;
-     m_port = port;
-     // Set up CORS headers, then register the API endpoints.
-     setupCORS();
-     registerEndpoints();
-     // Reset the failure flag left over from any earlier attempt.
-     m_startupFailed.store(false);
-     // Joins the server thread (if there is one to join) and reports failure.
-     const auto joinAndFail = [this]() {
-         if (m_serverThread.joinable()) {
-             m_serverThread.join();
-         }
-         return false;
-     };
-     // Assigning to a joinable std::thread calls std::terminate, so reap any
-     // thread left behind by a previous run before launching a new one.
-     if (m_serverThread.joinable()) {
-         m_serverThread.join();
-     }
-     m_serverThread = std::thread(&Server::serverThreadFunction, this, host, port);
-     // Poll the two flags; presumably they are set by serverThreadFunction,
-     // which is not visible in this part of the file.
-     constexpr int kMaxPolls = 100;  // 100 polls * 100 ms = up to 10 seconds
-     constexpr std::chrono::milliseconds kPollInterval(100);
-     constexpr std::chrono::milliseconds kSettleDelay(500);
-     for (int attempt = 0; attempt < kMaxPolls; ++attempt) {
-         std::this_thread::sleep_for(kPollInterval);
-         // Bail out as soon as the startup is reported as failed.
-         if (m_startupFailed.load()) {
-             return joinAndFail();
-         }
-         if (m_isRunning.load()) {
-             // Wait a little longer and re-check so that a server that drops
-             // out right after coming up is not reported as started.
-             std::this_thread::sleep_for(kSettleDelay);
-             if (m_isRunning.load()) {
-                 return true;
-             }
-         }
-     }
-     // Polling window exhausted: accept a late start, otherwise give up.
-     if (m_isRunning.load()) {
-         return true;
-     }
-     return joinAndFail();
- }
- // Stops the HTTP server and joins the server thread.
- //
- // Only the call that flips m_isRunning from true to false does any work;
- // every other call returns immediately.
- // NOTE(review): on that early return a still-joinable m_serverThread is left
- // as is - confirm the thread is always joined elsewhere (e.g. waitForStop()).
- void Server::stop() {
-     // exchange() reads and clears the flag in one atomic step.
-     const bool wasRunning = m_isRunning.exchange(false);
-     if (!wasRunning) {
-         return;  // Already stopped
-     }
-     if (m_httpServer) {
-         m_httpServer->stop();
-         // Give the server a moment to leave its blocking listen call.
-         std::this_thread::sleep_for(std::chrono::milliseconds(100));
-         // If the thread has not been joined yet, poke the port with a quick
-         // local request to try to unblock the listen call.
-         if (m_serverThread.joinable()) {
-             // httplib timeouts take (seconds, microseconds). Split the
-             // millisecond settings instead of passing everything as
-             // microseconds, which keeps the microsecond part below one
-             // second and avoids overflow for large values.
-             const auto wholeSeconds = [](long long ms) { return ms / 1000; };
-             const auto leftoverMicros = [](long long ms) { return (ms % 1000) * 1000; };
-             try {
-                 httplib::Client client("127.0.0.1", m_port);
-                 client.set_connection_timeout(wholeSeconds(m_config.connectionTimeoutMs),
-                                               leftoverMicros(m_config.connectionTimeoutMs));
-                 client.set_read_timeout(wholeSeconds(m_config.readTimeoutMs),
-                                         leftoverMicros(m_config.readTimeoutMs));
-                 client.set_write_timeout(wholeSeconds(m_config.writeTimeoutMs),
-                                          leftoverMicros(m_config.writeTimeoutMs));
-                 // The response is irrelevant; the request only exists to
-                 // wake the server up.
-                 static_cast<void>(client.Get("/api/health"));
-             } catch (...) {
-                 // Deliberately best-effort: a failed connection just means
-                 // there was nothing left to unblock.
-             }
-         }
-     }
-     if (m_serverThread.joinable()) {
-         m_serverThread.join();
-     }
- }
- // Returns the current value of the atomic running flag.
- bool Server::isRunning() const
- {
-     const bool running = m_isRunning.load();
-     return running;
- }
- // Blocks the caller until the server thread has finished. Returns right away
- // when there is no joinable server thread.
- void Server::waitForStop()
- {
-     if (!m_serverThread.joinable()) {
-         return;
-     }
-     m_serverThread.join();
- }
- void Server::registerEndpoints() {
- // Register authentication endpoints first (before applying middleware)
- registerAuthEndpoints();
- // Health check endpoint (public)
- m_httpServer->Get("/api/health", [this](const httplib::Request& req, httplib::Response& res) {
- handleHealthCheck(req, res);
- });
- // API status endpoint (public)
- m_httpServer->Get("/api/status", [this](const httplib::Request& req, httplib::Response& res) {
- handleApiStatus(req, res);
- });
- // Version information endpoint (public)
- m_httpServer->Get("/api/version", [this](const httplib::Request& req, httplib::Response& res) {
- handleVersion(req, res);
- });
- // Apply authentication middleware to protected endpoints
- auto withAuth = [this](std::function<void(const httplib::Request&, httplib::Response&)> handler) {
- return [this, handler](const httplib::Request& req, httplib::Response& res) {
- if (m_authMiddleware) {
- AuthContext authContext = m_authMiddleware->authenticate(req, res);
- if (!authContext.authenticated) {
- m_authMiddleware->sendAuthError(res, authContext.errorMessage, authContext.errorCode);
- return;
- }
- }
- handler(req, res);
- };
- };
- // Specialized generation endpoints (protected)
- m_httpServer->Post("/api/generate/text2img", withAuth([this](const httplib::Request& req, httplib::Response& res) {
- handleText2Img(req, res);
- }));
- m_httpServer->Post("/api/generate/img2img", withAuth([this](const httplib::Request& req, httplib::Response& res) {
- handleImg2Img(req, res);
- }));
- m_httpServer->Post("/api/generate/controlnet", withAuth([this](const httplib::Request& req, httplib::Response& res) {
- handleControlNet(req, res);
- }));
- m_httpServer->Post("/api/generate/upscale", withAuth([this](const httplib::Request& req, httplib::Response& res) {
- handleUpscale(req, res);
- }));
- m_httpServer->Post("/api/generate/inpainting", withAuth([this](const httplib::Request& req, httplib::Response& res) {
- handleInpainting(req, res);
- }));
- // Utility endpoints (now protected - require authentication)
- m_httpServer->Get("/api/samplers", withAuth([this](const httplib::Request& req, httplib::Response& res) {
- handleSamplers(req, res);
- }));
- m_httpServer->Get("/api/schedulers", withAuth([this](const httplib::Request& req, httplib::Response& res) {
- handleSchedulers(req, res);
- }));
- m_httpServer->Get("/api/parameters", withAuth([this](const httplib::Request& req, httplib::Response& res) {
- handleParameters(req, res);
- }));
- m_httpServer->Post("/api/validate", [this](const httplib::Request& req, httplib::Response& res) {
- handleValidate(req, res);
- });
- m_httpServer->Post("/api/estimate", withAuth([this](const httplib::Request& req, httplib::Response& res) {
- handleEstimate(req, res);
- }));
- m_httpServer->Get("/api/config", withAuth([this](const httplib::Request& req, httplib::Response& res) {
- handleConfig(req, res);
- }));
- m_httpServer->Get("/api/system", withAuth([this](const httplib::Request& req, httplib::Response& res) {
- handleSystem(req, res);
- }));
- m_httpServer->Post("/api/system/restart", withAuth([this](const httplib::Request& req, httplib::Response& res) {
- handleSystemRestart(req, res);
- }));
- // Models list endpoint (now protected - require authentication)
- m_httpServer->Get("/api/models", withAuth([this](const httplib::Request& req, httplib::Response& res) {
- handleModelsList(req, res);
- }));
- // Model-specific endpoints
- m_httpServer->Get("/api/models/(.*)", [this](const httplib::Request& req, httplib::Response& res) {
- handleModelInfo(req, res);
- });
- m_httpServer->Post("/api/models/(.*)/load", withAuth([this](const httplib::Request& req, httplib::Response& res) {
- handleLoadModelById(req, res);
- }));
- m_httpServer->Post("/api/models/(.*)/unload", withAuth([this](const httplib::Request& req, httplib::Response& res) {
- handleUnloadModelById(req, res);
- }));
- // Model management endpoints (now protected - require authentication)
- m_httpServer->Get("/api/models/types", withAuth([this](const httplib::Request& req, httplib::Response& res) {
- handleModelTypes(req, res);
- }));
- m_httpServer->Get("/api/models/directories", withAuth([this](const httplib::Request& req, httplib::Response& res) {
- handleModelDirectories(req, res);
- }));
- m_httpServer->Post("/api/models/refresh", withAuth([this](const httplib::Request& req, httplib::Response& res) {
- handleRefreshModels(req, res);
- }));
- m_httpServer->Post("/api/models/hash", withAuth([this](const httplib::Request& req, httplib::Response& res) {
- handleHashModels(req, res);
- }));
- m_httpServer->Post("/api/models/convert", withAuth([this](const httplib::Request& req, httplib::Response& res) {
- handleConvertModel(req, res);
- }));
- m_httpServer->Get("/api/models/stats", withAuth([this](const httplib::Request& req, httplib::Response& res) {
- handleModelStats(req, res);
- }));
- m_httpServer->Post("/api/models/batch", withAuth([this](const httplib::Request& req, httplib::Response& res) {
- handleBatchModels(req, res);
- }));
- // Model validation endpoints (already protected with withAuth)
- m_httpServer->Post("/api/models/validate", withAuth([this](const httplib::Request& req, httplib::Response& res) {
- handleValidateModel(req, res);
- }));
- m_httpServer->Post("/api/models/compatible", withAuth([this](const httplib::Request& req, httplib::Response& res) {
- handleCheckCompatibility(req, res);
- }));
- m_httpServer->Post("/api/models/requirements", withAuth([this](const httplib::Request& req, httplib::Response& res) {
- handleModelRequirements(req, res);
- }));
- // Queue status endpoint (now protected - require authentication)
- m_httpServer->Get("/api/queue/status", withAuth([this](const httplib::Request& req, httplib::Response& res) {
- handleQueueStatus(req, res);
- }));
- // Download job output file endpoint (must be before job status endpoint to match more specific pattern first)
- // Note: This endpoint is public to allow frontend to display generated images without authentication
- m_httpServer->Get("/api/queue/job/(.*)/output/(.*)", [this](const httplib::Request& req, httplib::Response& res) {
- handleDownloadOutput(req, res);
- });
- // Get job output by job ID endpoint (public to allow frontend to display generated images without authentication)
- m_httpServer->Get("/api/v1/jobs/(.*)/output", [this](const httplib::Request& req, httplib::Response& res) {
- handleJobOutput(req, res);
- });
- // Download image from URL endpoint (public for CORS-free image handling)
- m_httpServer->Get("/api/image/download", [this](const httplib::Request& req, httplib::Response& res) {
- handleDownloadImageFromUrl(req, res);
- });
- // Image resize endpoint (protected)
- m_httpServer->Post("/api/image/resize", withAuth([this](const httplib::Request& req, httplib::Response& res) {
- handleImageResize(req, res);
- }));
- // Image crop endpoint (protected)
- m_httpServer->Post("/api/image/crop", withAuth([this](const httplib::Request& req, httplib::Response& res) {
- handleImageCrop(req, res);
- }));
- // Job status endpoint (now protected - require authentication)
- m_httpServer->Get("/api/queue/job/(.*)", withAuth([this](const httplib::Request& req, httplib::Response& res) {
- handleJobStatus(req, res);
- }));
- // Cancel job endpoint (protected)
- m_httpServer->Post("/api/queue/cancel", withAuth([this](const httplib::Request& req, httplib::Response& res) {
- handleCancelJob(req, res);
- }));
- // Clear queue endpoint (protected)
- m_httpServer->Post("/api/queue/clear", withAuth([this](const httplib::Request& req, httplib::Response& res) {
- handleClearQueue(req, res);
- }));
- // Serve static web UI files if uiDir is configured
- if (!m_uiDir.empty() && std::filesystem::exists(m_uiDir)) {
- std::cout << "Serving static UI files from: " << m_uiDir << " at /ui" << std::endl;
- // Read UI version from version.nlohmann::json if available
- std::string uiVersion = "unknown";
- std::string versionFilePath = m_uiDir + "/version.nlohmann::json";
- if (std::filesystem::exists(versionFilePath)) {
- try {
- std::ifstream versionFile(versionFilePath);
- if (versionFile.is_open()) {
- nlohmann::json versionData = nlohmann::json::parse(versionFile);
- if (versionData.contains("version")) {
- uiVersion = versionData["version"].get<std::string>();
- }
- versionFile.close();
- }
- } catch (const std::exception& e) {
- std::cerr << "Failed to read UI version: " << e.what() << std::endl;
- }
- }
- std::cout << "UI version: " << uiVersion << std::endl;
- // Serve dynamic config.js that provides runtime configuration to the web UI
- m_httpServer->Get("/ui/config.js", [this, uiVersion](const httplib::Request& /*req*/, httplib::Response& res) {
- // Generate JavaScript configuration with current server settings
- std::ostringstream configJs;
- configJs << "// Auto-generated configuration\n"
- << "window.__SERVER_CONFIG__ = {\n"
- << " apiUrl: 'http://" << m_host << ":" << m_port << "',\n"
- << " apiBasePath: '/api',\n"
- << " host: '" << m_host << "',\n"
- << " port: " << m_port << ",\n"
- << " uiVersion: '" << uiVersion << "',\n";
- // Add authentication method information
- if (m_authMiddleware) {
- auto authConfig = m_authMiddleware->getConfig();
- std::string authMethod = "none";
- switch (authConfig.authMethod) {
- case AuthMethod::UNIX:
- authMethod = "unix";
- break;
- case AuthMethod::JWT:
- authMethod = "jwt";
- break;
- default:
- authMethod = "none";
- break;
- }
- configJs << " authMethod: '" << authMethod << "',\n"
- << " authEnabled: " << (authConfig.authMethod != AuthMethod::NONE ? "true" : "false") << "\n";
- } else {
- configJs << " authMethod: 'none',\n"
- << " authEnabled: false\n";
- }
- configJs << "};\n";
- // No cache for config.js - always fetch fresh
- res.set_header("Cache-Control", "no-cache, no-store, must-revalidate");
- res.set_header("Pragma", "no-cache");
- res.set_header("Expires", "0");
- res.set_content(configJs.str(), "application/javascript");
- });
- // Set up file request handler for caching static assets
- m_httpServer->set_file_request_handler([uiVersion](const httplib::Request& req, httplib::Response& res) {
- // Add cache headers based on file type and version
- std::string path = req.path;
- // For versioned static assets (.js, .css, images), use long cache
- if (path.find("/_next/") != std::string::npos ||
- path.find(".js") != std::string::npos ||
- path.find(".css") != std::string::npos ||
- path.find(".png") != std::string::npos ||
- path.find(".jpg") != std::string::npos ||
- path.find(".svg") != std::string::npos ||
- path.find(".ico") != std::string::npos ||
- path.find(".woff") != std::string::npos ||
- path.find(".woff2") != std::string::npos ||
- path.find(".ttf") != std::string::npos) {
- // Long cache (1 year) for static assets
- res.set_header("Cache-Control", "public, max-age=31536000, immutable");
- // Add ETag based on UI version for cache validation
- res.set_header("ETag", "\"" + uiVersion + "\"");
- // Check If-None-Match for conditional requests
- if (req.has_header("If-None-Match")) {
- std::string clientETag = req.get_header_value("If-None-Match");
- if (clientETag == "\"" + uiVersion + "\"") {
- res.status = 304; // Not Modified
- return;
- }
- }
- } else if (path.find(".html") != std::string::npos || path == "/ui/" || path == "/ui") {
- // HTML files should revalidate but can be cached briefly
- res.set_header("Cache-Control", "public, max-age=0, must-revalidate");
- res.set_header("ETag", "\"" + uiVersion + "\"");
- }
- });
- // Create a handler for UI routes with authentication check
- auto uiHandler = [this](const httplib::Request& req, httplib::Response& res) {
- // Check if authentication is enabled
- if (m_authMiddleware) {
- auto authConfig = m_authMiddleware->getConfig();
- if (authConfig.authMethod != AuthMethod::NONE) {
- // Authentication is enabled, check if user is authenticated
- AuthContext authContext = m_authMiddleware->authenticate(req, res);
- // For Unix auth, we need to check if the user is authenticated
- // The authenticateUnix function will return a guest context for UI requests
- // when no Authorization header is present, but we still need to show the login page
- if (!authContext.authenticated) {
- // Check if this is a request for a static asset (JS, CSS, images)
- // These should be served even without authentication to allow the login page to work
- bool isStaticAsset = false;
- std::string path = req.path;
- if (path.find(".js") != std::string::npos ||
- path.find(".css") != std::string::npos ||
- path.find(".png") != std::string::npos ||
- path.find(".jpg") != std::string::npos ||
- path.find(".jpeg") != std::string::npos ||
- path.find(".svg") != std::string::npos ||
- path.find(".ico") != std::string::npos ||
- path.find("/_next/") != std::string::npos) {
- isStaticAsset = true;
- }
- // For static assets, allow them to be served without authentication
- if (isStaticAsset) {
- // Continue to serve the file
- } else {
- // For HTML requests, redirect to login page
- if (req.path.find(".html") != std::string::npos ||
- req.path == "/ui/" || req.path == "/ui") {
- // Serve the login page instead of the requested page
- std::string loginPagePath = m_uiDir + "/login.html";
- if (std::filesystem::exists(loginPagePath)) {
- std::ifstream loginFile(loginPagePath);
- if (loginFile.is_open()) {
- std::string content((std::istreambuf_iterator<char>(loginFile)),
- std::istreambuf_iterator<char>());
- res.set_content(content, "text/html");
- return;
- }
- }
- // If login.html doesn't exist, serve a simple login page
- std::string simpleLoginPage = R"(
- <!DOCTYPE html>
- <html>
- <head>
- <title>Login Required</title>
- <style>
- body { font-family: Arial, sans-serif; max-width: 500px; margin: 100px auto; padding: 20px; }
- .form-group { margin-bottom: 15px; }
- label { display: block; margin-bottom: 5px; }
- input { width: 100%; padding: 8px; box-sizing: border-box; }
- button { background-color: #007bff; color: white; padding: 10px 15px; border: none; cursor: pointer; }
- .error { color: red; margin-top: 10px; }
- </style>
- </head>
- <body>
- <h1>Login Required</h1>
- <p>Please enter your username to continue.</p>
- <form id="loginForm">
- <div class="form-group">
- <label for="username">Username:</label>
- <input type="text" id="username" name="username" required>
- </div>
- <button type="submit">Login</button>
- </form>
- <div id="error" class="error"></div>
- <script>
- document.getElementById('loginForm').addEventListener('submit', async (e) => {
- e.preventDefault();
- const username = document.getElementById('username').value;
- const errorDiv = document.getElementById('error');
- try {
- const response = await fetch('/api/auth/login', {
- method: 'POST',
- headers: { 'Content-Type': 'application/nlohmann::json' },
- body: JSON.stringify({ username })
- });
- if (response.ok) {
- const data = await response.nlohmann::json();
- localStorage.setItem('auth_token', data.token);
- localStorage.setItem('unix_user', username);
- window.location.reload();
- } else {
- const error = await response.nlohmann::json();
- errorDiv.textContent = error.message || 'Login failed';
- }
- } catch (err) {
- errorDiv.textContent = 'Login failed: ' + err.message;
- }
- });
- </script>
- </body>
- </html>
- )";
- res.set_content(simpleLoginPage, "text/html");
- return;
- } else {
- // For non-HTML files, return unauthorized
- m_authMiddleware->sendAuthError(res, "Authentication required", "AUTH_REQUIRED");
- return;
- }
- }
- }
- }
- }
- // If we get here, either auth is disabled or user is authenticated
- // Serve the requested file
- std::string filePath = req.path.substr(3); // Remove "/ui" prefix
- if (filePath.empty() || filePath == "/") {
- filePath = "/index.html";
- }
- std::string fullPath = m_uiDir + filePath;
- if (std::filesystem::exists(fullPath) && std::filesystem::is_regular_file(fullPath)) {
- std::ifstream file(fullPath, std::ios::binary);
- if (file.is_open()) {
- std::string content((std::istreambuf_iterator<char>(file)),
- std::istreambuf_iterator<char>());
- // Determine content type based on file extension
- std::string contentType = "text/plain";
- if (filePath.find(".html") != std::string::npos) {
- contentType = "text/html";
- } else if (filePath.find(".js") != std::string::npos) {
- contentType = "application/javascript";
- } else if (filePath.find(".css") != std::string::npos) {
- contentType = "text/css";
- } else if (filePath.find(".png") != std::string::npos) {
- contentType = "image/png";
- } else if (filePath.find(".jpg") != std::string::npos || filePath.find(".jpeg") != std::string::npos) {
- contentType = "image/jpeg";
- } else if (filePath.find(".svg") != std::string::npos) {
- contentType = "image/svg+xml";
- }
- res.set_content(content, contentType);
- } else {
- res.status = 404;
- res.set_content("File not found", "text/plain");
- }
- } else {
- // For SPA routing, if the file doesn't exist, serve index.html
- // This allows Next.js to handle client-side routing
- std::string indexPath = m_uiDir + "/index.html";
- if (std::filesystem::exists(indexPath)) {
- std::ifstream indexFile(indexPath, std::ios::binary);
- if (indexFile.is_open()) {
- std::string content((std::istreambuf_iterator<char>(indexFile)),
- std::istreambuf_iterator<char>());
- res.set_content(content, "text/html");
- } else {
- res.status = 404;
- res.set_content("File not found", "text/plain");
- }
- } else {
- res.status = 404;
- res.set_content("File not found", "text/plain");
- }
- }
- };
- // Set up UI routes with authentication
- m_httpServer->Get("/ui/.*", uiHandler);
- // Redirect /ui to /ui/ to ensure proper routing
- m_httpServer->Get("/ui", [](const httplib::Request& /*req*/, httplib::Response& res) {
- res.set_redirect("/ui/");
- });
- }
- }
- // Installs the user manager and auth middleware that the auth handlers and the
- // UI route guard consult. Both handles are stored as shared owners.
- void Server::setAuthComponents(std::shared_ptr<UserManager> userManager, std::shared_ptr<AuthMiddleware> authMiddleware) {
-     this->m_userManager = userManager;
-     this->m_authMiddleware = authMiddleware;
- }
- // Registers the /api/auth/* routes. Each route simply forwards to the matching
- // member handler; none of them is wrapped in withAuth, so the handlers perform
- // whatever checks they need themselves.
- void Server::registerAuthEndpoints() {
-     // POST /api/auth/login
-     m_httpServer->Post("/api/auth/login", [this](const httplib::Request& request, httplib::Response& response) {
-         handleLogin(request, response);
-     });
-     // POST /api/auth/logout
-     m_httpServer->Post("/api/auth/logout", [this](const httplib::Request& request, httplib::Response& response) {
-         handleLogout(request, response);
-     });
-     // GET /api/auth/validate
-     m_httpServer->Get("/api/auth/validate", [this](const httplib::Request& request, httplib::Response& response) {
-         handleValidateToken(request, response);
-     });
-     // POST /api/auth/refresh
-     m_httpServer->Post("/api/auth/refresh", [this](const httplib::Request& request, httplib::Response& response) {
-         handleRefreshToken(request, response);
-     });
-     // GET /api/auth/me
-     m_httpServer->Get("/api/auth/me", [this](const httplib::Request& request, httplib::Response& response) {
-         handleGetCurrentUser(request, response);
-     });
- }
- // Handles POST /api/auth/login.
- //
- // The body must be a JSON object. Under Unix authentication "username" is
- // required and "password" is required only when PAM is enabled; for every other
- // auth method both "username" and "password" must be present.
- // On success the reply is {token, user{id, username, role, permissions}, message}.
- // Failures are reported through sendErrorResponse:
- //   500 AUTH_UNAVAILABLE    - user manager or auth middleware not installed
- //   400 JSON_PARSE_ERROR    - body is not valid JSON
- //   400 INVALID_REQUEST     - body is not an object, or a credential is not a string
- //   400 MISSING_USERNAME / MISSING_PASSWORD / MISSING_CREDENTIALS
- //   401 UNIX_AUTH_FAILED / INVALID_CREDENTIALS
- //   500 LOGIN_ERROR         - any other exception
- void Server::handleLogin(const httplib::Request& req, httplib::Response& res) {
-     std::string requestId = generateRequestId();
-     try {
-         if (!m_userManager || !m_authMiddleware) {
-             sendErrorResponse(res, "Authentication system not available", 500, "AUTH_UNAVAILABLE", requestId);
-             return;
-         }
-
-         // Parse request body
-         nlohmann::json requestJson;
-         try {
-             requestJson = nlohmann::json::parse(req.body);
-         } catch (const nlohmann::json::parse_error& e) {
-             sendErrorResponse(res, std::string("Invalid JSON: ") + e.what(), 400, "JSON_PARSE_ERROR", requestId);
-             return;
-         }
-
-         // A body that parses but has the wrong shape is a client error. Without
-         // these checks the accessors below throw json::type_error, which the
-         // outer handler would misreport as a 500.
-         if (!requestJson.is_object()) {
-             sendErrorResponse(res, "Request body must be a JSON object", 400, "INVALID_REQUEST", requestId);
-             return;
-         }
-         for (const char* field : {"username", "password"}) {
-             const auto it = requestJson.find(field);
-             if (it != requestJson.end() && !it->is_string()) {
-                 sendErrorResponse(res, std::string("Field '") + field + "' must be a string", 400, "INVALID_REQUEST", requestId);
-                 return;
-             }
-         }
-
-         // Placeholder token: <prefix><unix seconds>_<username>. Not a real JWT.
-         const auto issueToken = [](const std::string& prefix, const std::string& user) {
-             const auto issuedAt = std::chrono::duration_cast<std::chrono::seconds>(
-                 std::chrono::system_clock::now().time_since_epoch()).count();
-             return prefix + std::to_string(issuedAt) + "_" + user;
-         };
-
-         // Shared success payload for both authentication paths.
-         const auto buildResponse = [](const std::string& token, const auto& result, const char* message) {
-             nlohmann::json response = {
-                 {"token", token},
-                 {"user", {
-                     {"id", result.userId},
-                     {"username", result.username},
-                     {"role", result.role},
-                     {"permissions", result.permissions}
-                 }},
-                 {"message", message}
-             };
-             return response;
-         };
-
-         // Unix authentication path
-         if (m_authMiddleware->getConfig().authMethod == AuthMethod::UNIX) {
-             std::string username = requestJson.value("username", "");
-             std::string password = requestJson.value("password", "");
-             if (username.empty()) {
-                 sendErrorResponse(res, "Missing username", 400, "MISSING_USERNAME", requestId);
-                 return;
-             }
-             // With PAM enabled a password is mandatory
-             if (m_userManager->isPamAuthEnabled() && password.empty()) {
-                 sendErrorResponse(res, "Password is required for Unix authentication", 400, "MISSING_PASSWORD", requestId);
-                 return;
-             }
-             // Authenticate Unix user (with or without password depending on PAM)
-             auto result = m_userManager->authenticateUnix(username, password);
-             if (!result.success) {
-                 sendErrorResponse(res, result.errorMessage, 401, "UNIX_AUTH_FAILED", requestId);
-                 return;
-             }
-             nlohmann::json response = buildResponse(issueToken("unix_token_", username), result,
-                                                     "Unix authentication successful");
-             sendJsonResponse(res, response);
-             return;
-         }
-
-         // For non-Unix auth both credentials are mandatory
-         if (!requestJson.contains("username") || !requestJson.contains("password")) {
-             sendErrorResponse(res, "Missing username or password", 400, "MISSING_CREDENTIALS", requestId);
-             return;
-         }
-         std::string username = requestJson["username"].get<std::string>();
-         std::string password = requestJson["password"].get<std::string>();
-
-         // Authenticate user
-         auto result = m_userManager->authenticateUser(username, password);
-         if (!result.success) {
-             sendErrorResponse(res, result.errorMessage, 401, "INVALID_CREDENTIALS", requestId);
-             return;
-         }
-
-         // A token is only issued when JWT auth is configured; otherwise it stays empty.
-         std::string token;
-         if (m_authMiddleware->getConfig().authMethod == AuthMethod::JWT) {
-             token = issueToken("token_", username);
-         }
-         nlohmann::json response = buildResponse(token, result, "Login successful");
-         sendJsonResponse(res, response);
-     } catch (const std::exception& e) {
-         sendErrorResponse(res, std::string("Login failed: ") + e.what(), 500, "LOGIN_ERROR", requestId);
-     }
- }
- // Handles POST /api/auth/logout. No token is invalidated yet: the handler only
- // replies with a success message (or LOGOUT_ERROR if building the reply throws).
- void Server::handleLogout(const httplib::Request& /*req*/, httplib::Response& res) {
-     std::string requestId = generateRequestId();
-     try {
-         nlohmann::json payload = nlohmann::json::object();
-         payload["message"] = "Logout successful";
-         sendJsonResponse(res, payload);
-     } catch (const std::exception& e) {
-         sendErrorResponse(res, std::string("Logout failed: ") + e.what(), 500, "LOGOUT_ERROR", requestId);
-     }
- }
- // Handles GET /api/auth/validate.
- //
- // Expects "Authorization: Bearer token_<timestamp>_<username>" (the placeholder
- // format issued by handleLogin for JWT auth) and replies with the user's info
- // plus {"valid": true}. Failures are reported through sendErrorResponse:
- //   500 AUTH_UNAVAILABLE - user manager or auth middleware not installed
- //   401 MISSING_TOKEN    - no Authorization header
- //   401 INVALID_HEADER   - header does not start with "Bearer "
- //   401 INVALID_TOKEN    - token lacks the "token_" prefix or the timestamp separator
- //   401 USER_NOT_FOUND   - no user with the embedded username
- //   500 VALIDATION_ERROR - any other exception
- //
- // NOTE(review): tokens issued by the Unix login path start with "unix_token_"
- // and are therefore rejected here - confirm whether that is intended.
- void Server::handleValidateToken(const httplib::Request& req, httplib::Response& res) {
-     std::string requestId = generateRequestId();
-     try {
-         if (!m_userManager || !m_authMiddleware) {
-             sendErrorResponse(res, "Authentication system not available", 500, "AUTH_UNAVAILABLE", requestId);
-             return;
-         }
-
-         // Extract token from header
-         const std::string authHeader = req.get_header_value("Authorization");
-         if (authHeader.empty()) {
-             sendErrorResponse(res, "Missing authorization token", 401, "MISSING_TOKEN", requestId);
-             return;
-         }
-
-         const std::string bearerPrefix = "Bearer ";
-         const std::string tokenPrefix = "token_";
-
-         // Prefix checks use compare() so only the leading characters are inspected.
-         if (authHeader.compare(0, bearerPrefix.size(), bearerPrefix) != 0) {
-             sendErrorResponse(res, "Invalid authorization header format", 401, "INVALID_HEADER", requestId);
-             return;
-         }
-         const std::string token = authHeader.substr(bearerPrefix.size());
-
-         // Simple token validation (in a real implementation, validate JWT)
-         if (token.compare(0, tokenPrefix.size(), tokenPrefix) != 0) {
-             sendErrorResponse(res, "Invalid token", 401, "INVALID_TOKEN", requestId);
-             return;
-         }
-
-         // The username is everything after the underscore that terminates the
-         // timestamp. Splitting at the LAST underscore would truncate usernames
-         // that themselves contain '_' (e.g. "john_doe" -> "doe").
-         const size_t separator = token.find('_', tokenPrefix.size());
-         if (separator == std::string::npos) {
-             sendErrorResponse(res, "Invalid token format", 401, "INVALID_TOKEN", requestId);
-             return;
-         }
-         const std::string username = token.substr(separator + 1);
-
-         // Get user info
-         auto userInfo = m_userManager->getUserInfoByUsername(username);
-         if (userInfo.id.empty()) {
-             sendErrorResponse(res, "User not found", 401, "USER_NOT_FOUND", requestId);
-             return;
-         }
-
-         nlohmann::json response = {
-             {"user", {
-                 {"id", userInfo.id},
-                 {"username", userInfo.username},
-                 {"role", userInfo.role},
-                 {"permissions", userInfo.permissions}
-             }},
-             {"valid", true}
-         };
-         sendJsonResponse(res, response);
-     } catch (const std::exception& e) {
-         sendErrorResponse(res, std::string("Token validation failed: ") + e.what(), 500, "VALIDATION_ERROR", requestId);
-     }
- }
- // Handles POST /api/auth/refresh. Placeholder implementation: the request is not
- // inspected and the reply carries "new_token_<unix seconds>" plus a message.
- void Server::handleRefreshToken(const httplib::Request& /*req*/, httplib::Response& res) {
-     std::string requestId = generateRequestId();
-     try {
-         // Seconds since the system_clock epoch, used as the token suffix.
-         const auto issuedAt = std::chrono::duration_cast<std::chrono::seconds>(
-             std::chrono::system_clock::now().time_since_epoch()).count();
-
-         nlohmann::json payload = nlohmann::json::object();
-         payload["token"] = "new_token_" + std::to_string(issuedAt);
-         payload["message"] = "Token refreshed successfully";
-         sendJsonResponse(res, payload);
-     } catch (const std::exception& e) {
-         sendErrorResponse(res, std::string("Token refresh failed: ") + e.what(), 500, "REFRESH_ERROR", requestId);
-     }
- }
- // Handles GET /api/auth/me. Authenticates the request through the auth
- // middleware and echoes the resulting identity as {"user": {...}}.
- // Replies 500 AUTH_UNAVAILABLE when the auth components are missing, 401
- // AUTH_REQUIRED when the caller is not authenticated, and 500 USER_ERROR on
- // any exception.
- void Server::handleGetCurrentUser(const httplib::Request& req, httplib::Response& res) {
-     std::string requestId = generateRequestId();
-     try {
-         const bool authReady = m_userManager && m_authMiddleware;
-         if (!authReady) {
-             sendErrorResponse(res, "Authentication system not available", 500, "AUTH_UNAVAILABLE", requestId);
-             return;
-         }
-
-         // Let the middleware decide who the caller is.
-         AuthContext caller = m_authMiddleware->authenticate(req, res);
-         if (!caller.authenticated) {
-             sendErrorResponse(res, "Authentication required", 401, "AUTH_REQUIRED", requestId);
-             return;
-         }
-
-         nlohmann::json user = nlohmann::json::object();
-         user["id"] = caller.userId;
-         user["username"] = caller.username;
-         user["role"] = caller.role;
-         user["permissions"] = caller.permissions;
-
-         nlohmann::json payload = nlohmann::json::object();
-         payload["user"] = user;
-         sendJsonResponse(res, payload);
-     } catch (const std::exception& e) {
-         sendErrorResponse(res, std::string("Get current user failed: ") + e.what(), 500, "USER_ERROR", requestId);
-     }
- }
- // Installs CORS handling on the HTTP server:
- //  * a post-routing handler that adds each CORS header only when the response
- //    does not already carry it, so handlers that set their own are not duplicated;
- //  * an OPTIONS handler for /api/.* that answers preflight requests with 200.
- void Server::setupCORS() {
-     // Header name/value pairs shared by both handlers below.
-     static constexpr const char* kCorsHeaders[3][2] = {
-         {"Access-Control-Allow-Origin", "*"},
-         {"Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS"},
-         {"Access-Control-Allow-Headers", "Content-Type, Authorization"},
-     };
-
-     // Runs after the route handler has produced the response.
-     m_httpServer->set_post_routing_handler([](const httplib::Request& /*request*/, httplib::Response& response) {
-         for (const auto& header : kCorsHeaders) {
-             if (!response.has_header(header[0])) {
-                 response.set_header(header[0], header[1]);
-             }
-         }
-     });
-
-     // CORS preflight (API endpoints only)
-     m_httpServer->Options("/api/.*", [](const httplib::Request&, httplib::Response& response) {
-         for (const auto& header : kCorsHeaders) {
-             response.set_header(header[0], header[1]);
-         }
-         response.status = 200;
-     });
- }
- // Health probe: replies with status "healthy", the current time in seconds since
- // the system_clock epoch, and the full version string. Any exception becomes a 500.
- void Server::handleHealthCheck(const httplib::Request& /*req*/, httplib::Response& res) {
-     try {
-         const auto nowSeconds = std::chrono::duration_cast<std::chrono::seconds>(
-             std::chrono::system_clock::now().time_since_epoch()).count();
-
-         nlohmann::json payload = nlohmann::json::object();
-         payload["status"] = "healthy";
-         payload["timestamp"] = nowSeconds;
-         payload["version"] = sd_rest::VERSION_INFO.version_full;
-         sendJsonResponse(res, payload);
-     } catch (const std::exception& e) {
-         sendErrorResponse(res, std::string("Health check failed: ") + e.what(), 500);
-     }
- }
- // Status endpoint: reports the server's running flag, host and port, the
- // generation queue state, and model counts. When the queue or the model manager
- // is not set, the corresponding values fall back to false / 0.
- void Server::handleApiStatus(const httplib::Request& /*req*/, httplib::Response& res) {
-     try {
-         nlohmann::json serverSection = nlohmann::json::object();
-         serverSection["running"] = m_isRunning.load();
-         serverSection["host"] = m_host;
-         serverSection["port"] = m_port;
-
-         nlohmann::json queueSection = nlohmann::json::object();
-         queueSection["running"] = m_generationQueue ? m_generationQueue->isRunning() : false;
-         queueSection["queue_size"] = m_generationQueue ? m_generationQueue->getQueueSize() : 0;
-         queueSection["active_generations"] = m_generationQueue ? m_generationQueue->getActiveGenerations() : 0;
-
-         nlohmann::json modelsSection = nlohmann::json::object();
-         modelsSection["loaded_count"] = m_modelManager ? m_modelManager->getLoadedModelsCount() : 0;
-         modelsSection["available_count"] = m_modelManager ? m_modelManager->getAvailableModelsCount() : 0;
-
-         nlohmann::json payload = nlohmann::json::object();
-         payload["server"] = serverSection;
-         payload["generation_queue"] = queueSection;
-         payload["models"] = modelsSection;
-         sendJsonResponse(res, payload);
-     } catch (const std::exception& e) {
-         sendErrorResponse(res, std::string("Status check failed: ") + e.what(), 500);
-     }
- }
- // Version endpoint: serialises the fields of sd_rest::VERSION_INFO (version,
- // type, short/full commit, branch, clean flag, build time). Any exception
- // becomes a 500.
- void Server::handleVersion(const httplib::Request& /*req*/, httplib::Response& res) {
-     try {
-         const auto& info = sd_rest::VERSION_INFO;
-
-         nlohmann::json commit = nlohmann::json::object();
-         commit["short"] = info.commit_short;
-         commit["full"] = info.commit_full;
-
-         nlohmann::json payload = nlohmann::json::object();
-         payload["version"] = info.version_full;
-         payload["type"] = info.version_type;
-         payload["commit"] = commit;
-         payload["branch"] = info.branch;
-         payload["clean"] = info.is_clean;
-         payload["build_time"] = info.build_time;
-         sendJsonResponse(res, payload);
-     } catch (const std::exception& e) {
-         sendErrorResponse(res, std::string("Version check failed: ") + e.what(), 500);
-     }
- }
- // Serialises a list of ModelDetails into a JSON array.
- // Every element carries name, exists, type and file_size. "path" and "sha256"
- // hold the real values only when the model exists (otherwise null and "").
- // "is_required" / "is_recommended" are emitted only when true.
- nlohmann::json Server::modelDetailsToJson(const std::vector<ModelManager::ModelDetails>& modelDetails) {
-     nlohmann::json result = nlohmann::json::array();
-
-     for (const auto& entry : modelDetails) {
-         nlohmann::json item = nlohmann::json::object();
-         item["name"] = entry.name;
-         item["exists"] = entry.exists;
-         item["type"] = entry.type;
-         item["file_size"] = entry.file_size;
-
-         // path/sha256 are assigned separately because their JSON type depends on existence
-         if (!entry.exists) {
-             item["path"] = nullptr;
-             item["sha256"] = "";
-         } else {
-             item["path"] = entry.path;
-             item["sha256"] = entry.sha256;
-         }
-
-         // Flags are omitted entirely when false
-         if (entry.is_required) {
-             item["is_required"] = true;
-         }
-         if (entry.is_recommended) {
-             item["is_recommended"] = true;
-         }
-
-         result.push_back(item);
-     }
-
-     return result;
- }
- // Maps an architecture description to the set of "recommended_*" fields that
- // apply to it. The returned map always contains all six field names; a value of
- // true means the field is applicable. Matching is by substring and the first
- // matching rule wins; an unrecognised architecture leaves every field false.
- std::map<std::string, bool> Server::getRecommendedModelFields(const std::string& architecture) {
-     // Start with every field disabled
-     std::map<std::string, bool> fields = {
-         {"recommended_vae", false},
-         {"recommended_clip_l", false},
-         {"recommended_clip_g", false},
-         {"recommended_t5xxl", false},
-         {"recommended_clip_vision", false},
-         {"recommended_qwen2vl", false},
-     };
-
-     const auto mentions = [&architecture](const char* needle) {
-         return architecture.find(needle) != std::string::npos;
-     };
-     const auto enable = [&fields](std::initializer_list<const char*> names) {
-         for (const char* name : names) {
-             fields[name] = true;
-         }
-     };
-
-     if (mentions("Stable Diffusion 1.5") || mentions("Stable Diffusion XL")) {
-         // SD 1.x and SDXL: VAE only
-         enable({"recommended_vae"});
-     } else if (mentions("Modern Architecture") || mentions("Flux Dev") || mentions("Flux Chroma")) {
-         // FLUX / modern architectures: VAE, CLIP-L, T5XXL
-         enable({"recommended_vae", "recommended_clip_l", "recommended_t5xxl"});
-     } else if (mentions("SD 3")) {
-         // SD3: VAE, CLIP-L, CLIP-G, T5XXL
-         enable({"recommended_vae", "recommended_clip_l", "recommended_clip_g", "recommended_t5xxl"});
-     } else if (mentions("Wan")) {
-         // Wan: VAE, T5XXL, CLIP-Vision
-         enable({"recommended_vae", "recommended_t5xxl", "recommended_clip_vision"});
-     } else if (mentions("Qwen")) {
-         // Qwen: VAE, Qwen2VL
-         enable({"recommended_vae", "recommended_qwen2vl"});
-     }
-
-     return fields;
- }
- // Fills the "recommended_*" entries of `response` for one model.
- // Does nothing when the model lists no required models. Otherwise the required
- // models are checked for existence, grouped by type, and each group is written
- // to the field for that type when the architecture marks the field applicable.
- // Every field that is not applicable, or that is absent from `response` after
- // that step, is set to null.
- void Server::populateRecommendedModels(nlohmann::json& response, const ModelManager::ModelInfo& modelInfo) {
-     if (modelInfo.requiredModels.empty()) {
-         return;
-     }
-
-     // Existence info for the required models, then the per-architecture field flags
-     auto details = m_modelManager->checkRequiredModelsExistence(modelInfo.requiredModels);
-     auto fieldFlags = getRecommendedModelFields(modelInfo.architecture);
-
-     // Bucket the details by model type
-     std::map<std::string, std::vector<ModelManager::ModelDetails>> grouped;
-     for (const auto& detail : details) {
-         grouped[detail.type].push_back(detail);
-     }
-
-     // Model type -> response field it feeds
-     static const std::map<std::string, std::string> kFieldForType = {
-         {"VAE", "recommended_vae"},
-         {"CLIP-L", "recommended_clip_l"},
-         {"CLIP-G", "recommended_clip_g"},
-         {"T5XXL", "recommended_t5xxl"},
-         {"CLIP-Vision", "recommended_clip_vision"},
-         {"Qwen2VL", "recommended_qwen2vl"},
-     };
-
-     for (const auto& [type, models] : grouped) {
-         const auto match = kFieldForType.find(type);
-         if (match == kFieldForType.end()) {
-             continue;  // type has no recommended_* field
-         }
-         const std::string& field = match->second;
-         if (fieldFlags[field]) {
-             response[field] = modelDetailsToJson(models);
-         }
-     }
-
-     // Anything not applicable or still missing becomes an explicit null
-     for (const auto& [fieldName, wanted] : fieldFlags) {
-         const bool populated = wanted && response.contains(fieldName);
-         if (!populated) {
-             response[fieldName] = nlohmann::json(nullptr);
-         }
-     }
- }
- void Server::handleModelsList(const httplib::Request& req, httplib::Response& res) {
- std::string requestId = generateRequestId();
- try {
- if (!m_modelManager) {
- sendErrorResponse(res, "Model manager not available", 500, "MODEL_MANAGER_UNAVAILABLE", requestId);
- return;
- }
- // Parse query parameters for enhanced filtering
- std::string typeFilter = req.get_param_value("type");
- std::string searchQuery = req.get_param_value("search");
- std::string sortBy = req.get_param_value("sort_by");
- std::string sortOrder = req.get_param_value("sort_order");
- std::string dateFilter = req.get_param_value("date");
- std::string sizeFilter = req.get_param_value("size");
- // Pagination parameters - only apply if limit is explicitly provided
- int page = 1;
- int limit = 50;
- bool usePagination = false;
- try {
- if (!req.get_param_value("limit").empty()) {
- limit = std::stoi(req.get_param_value("limit"));
- // Special case: limit<=0 means return all models (no pagination)
- if (limit <= 0) {
- usePagination = false;
- limit = INT_MAX; // Set to very large number to effectively disable pagination
- } else {
- usePagination = true;
- if (!req.get_param_value("page").empty()) {
- page = std::stoi(req.get_param_value("page"));
- if (page < 1) page = 1;
- }
- }
- }
- } catch (const std::exception& e) {
- sendErrorResponse(res, "Invalid pagination parameters", 400, "INVALID_PAGINATION", requestId);
- return;
- }
- // Filter parameters
- bool includeLoaded = req.get_param_value("loaded") == "true";
- bool includeUnloaded = req.get_param_value("unloaded") == "true";
- (void)req.get_param_value("include_metadata"); // unused but kept for API compatibility
- (void)req.get_param_value("include_thumbnails"); // unused but kept for API compatibility
- // Get all models
- auto allModels = m_modelManager->getAllModels();
- nlohmann::json models = nlohmann::json::array();
- // Apply filters and build response
- for (const auto& pair : allModels) {
- const auto& modelInfo = pair.second;
- // Apply type filter
- if (!typeFilter.empty()) {
- ModelType filterType = ModelManager::stringToModelType(typeFilter);
- if (modelInfo.type != filterType) continue;
- }
- // Apply loaded/unloaded filters
- if (includeLoaded && !modelInfo.isLoaded) continue;
- if (includeUnloaded && modelInfo.isLoaded) continue;
- // Apply search filter (case-insensitive search in name and description)
- if (!searchQuery.empty()) {
- std::string searchLower = searchQuery;
- std::transform(searchLower.begin(), searchLower.end(), searchLower.begin(), ::tolower);
- std::string nameLower = modelInfo.name;
- std::transform(nameLower.begin(), nameLower.end(), nameLower.begin(), ::tolower);
- std::string descLower = modelInfo.description;
- std::transform(descLower.begin(), descLower.end(), descLower.begin(), ::tolower);
- if (nameLower.find(searchLower) == std::string::npos &&
- descLower.find(searchLower) == std::string::npos) {
- continue;
- }
- }
- // Apply date filter (simplified - expects "recent", "old", or YYYY-MM-DD)
- if (!dateFilter.empty()) {
- auto now = std::filesystem::file_time_type::clock::now();
- auto modelTime = modelInfo.modifiedAt;
- auto duration = std::chrono::duration_cast<std::chrono::hours>(now - modelTime).count();
- if (dateFilter == "recent" && duration > 24 * 7) continue; // Older than 1 week
- if (dateFilter == "old" && duration < 24 * 30) continue; // Newer than 1 month
- }
- // Apply size filter (expects "small", "medium", "large", or size in MB)
- if (!sizeFilter.empty()) {
- double sizeMB = modelInfo.fileSize / (1024.0 * 1024.0);
- if (sizeFilter == "small" && sizeMB > 1024) continue; // > 1GB
- if (sizeFilter == "medium" && (sizeMB < 1024 || sizeMB > 4096)) continue; // < 1GB or > 4GB
- if (sizeFilter == "large" && sizeMB < 4096) continue; // < 4GB
- // Try to parse as specific size in MB
- try {
- double maxSizeMB = std::stod(sizeFilter);
- if (sizeMB > maxSizeMB) continue;
- } catch (...) {
- // Ignore if parsing fails
- }
- }
- // Build model JSON with enhanced structure
- nlohmann::json modelJson = {
- {"name", modelInfo.name},
- {"type", ModelManager::modelTypeToString(modelInfo.type)},
- {"file_size", modelInfo.fileSize},
- {"file_size_mb", modelInfo.fileSize / (1024.0 * 1024.0)},
- {"sha256", modelInfo.sha256.empty() ? nullptr : nlohmann::json(modelInfo.sha256)},
- {"sha256_short", (modelInfo.sha256.empty() || modelInfo.sha256.length() < 10) ? nullptr : nlohmann::json(modelInfo.sha256.substr(0, 10))}
- };
- // Add architecture information if available (checkpoints only)
- if (!modelInfo.architecture.empty()) {
- modelJson["architecture"] = modelInfo.architecture;
- modelJson["recommended_vae"] = modelInfo.recommendedVAE.empty() ? nullptr : nlohmann::json(modelInfo.recommendedVAE);
- if (modelInfo.recommendedWidth > 0) {
- modelJson["recommended_width"] = modelInfo.recommendedWidth;
- }
- if (modelInfo.recommendedHeight > 0) {
- modelJson["recommended_height"] = modelInfo.recommendedHeight;
- }
- if (modelInfo.recommendedSteps > 0) {
- modelJson["recommended_steps"] = modelInfo.recommendedSteps;
- }
- if (!modelInfo.recommendedSampler.empty()) {
- modelJson["recommended_sampler"] = modelInfo.recommendedSampler;
- }
-
- // Enhanced model information with existence checking
- if (!modelInfo.requiredModels.empty()) {
- auto requiredModelsDetails = m_modelManager->checkRequiredModelsExistence(modelInfo.requiredModels);
- modelJson["required_models"] = modelDetailsToJson(requiredModelsDetails);
-
- // Populate recommended models based on architecture
- populateRecommendedModels(modelJson, modelInfo);
- }
-
- // Backward compatibility - keep existing fields
- if (!modelInfo.missingModels.empty()) {
- modelJson["missing_models"] = modelInfo.missingModels;
- modelJson["has_missing_dependencies"] = true;
- } else {
- modelJson["has_missing_dependencies"] = false;
- }
- }
- models.push_back(modelJson);
- }
- // Apply sorting
- if (!sortBy.empty()) {
- std::sort(models.begin(), models.end(), [&sortBy, &sortOrder](const nlohmann::json& a, const nlohmann::json& b) {
- bool ascending = sortOrder != "desc";
- if (sortBy == "name") {
- return ascending ? a["name"] < b["name"] : a["name"] > b["name"];
- } else if (sortBy == "size") {
- return ascending ? a["file_size"] < b["file_size"] : a["file_size"] > b["file_size"];
- } else if (sortBy == "date") {
- return ascending ? a["last_modified"] < b["last_modified"] : a["last_modified"] > b["last_modified"];
- } else if (sortBy == "type") {
- return ascending ? a["type"] < b["type"] : a["type"] > b["type"];
- } else if (sortBy == "loaded") {
- return ascending ? a["is_loaded"] < b["is_loaded"] : a["is_loaded"] > b["is_loaded"];
- }
- return false;
- });
- }
- // Apply pagination only if limit parameter was provided
- int totalCount = models.size();
- nlohmann::json paginatedModels = nlohmann::json::array();
- nlohmann::json paginationInfo = nlohmann::json::object();
- if (usePagination) {
- // Apply pagination
- int totalPages = (totalCount + limit - 1) / limit;
- int startIndex = (page - 1) * limit;
- int endIndex = std::min(startIndex + limit, totalCount);
- for (int i = startIndex; i < endIndex; ++i) {
- paginatedModels.push_back(models[i]);
- }
- paginationInfo = {
- {"page", page},
- {"limit", limit},
- {"total_count", totalCount},
- {"total_pages", totalPages},
- {"has_next", page < totalPages},
- {"has_prev", page > 1}
- };
- } else {
- // Return all models without pagination
- paginatedModels = models;
- paginationInfo = {
- {"page", 1},
- {"limit", totalCount},
- {"total_count", totalCount},
- {"total_pages", 1},
- {"has_next", false},
- {"has_prev", false}
- };
- }
- // Build comprehensive response
- nlohmann::json response = {
- {"models", paginatedModels},
- {"pagination", paginationInfo},
- {"filters_applied", {
- {"type", typeFilter.empty() ? nlohmann::json(nullptr) : nlohmann::json(typeFilter)},
- {"search", searchQuery.empty() ? nlohmann::json(nullptr) : nlohmann::json(searchQuery)},
- {"date", dateFilter.empty() ? nlohmann::json(nullptr) : nlohmann::json(dateFilter)},
- {"size", sizeFilter.empty() ? nlohmann::json(nullptr) : nlohmann::json(sizeFilter)},
- {"loaded", includeLoaded ? nlohmann::json(true) : nlohmann::json(nullptr)},
- {"unloaded", includeUnloaded ? nlohmann::json(true) : nlohmann::json(nullptr)}
- }},
- {"sorting", {
- {"sort_by", sortBy.empty() ? "name" : nlohmann::json(sortBy)},
- {"sort_order", sortOrder.empty() ? "asc" : nlohmann::json(sortOrder)}
- }},
- {"statistics", {
- {"loaded_count", m_modelManager->getLoadedModelsCount()},
- {"available_count", m_modelManager->getAvailableModelsCount()}
- }},
- {"request_id", requestId}
- };
- sendJsonResponse(res, response);
- } catch (const std::exception& e) {
- sendErrorResponse(res, std::string("Failed to list models: ") + e.what(), 500, "MODEL_LIST_ERROR", requestId);
- }
- }
- void Server::handleQueueStatus(const httplib::Request& /*req*/, httplib::Response& res) {
- try {
- if (!m_generationQueue) {
- sendErrorResponse(res, "Generation queue not available", 500);
- return;
- }
- // Get detailed queue status
- auto jobs = m_generationQueue->getQueueStatus();
- // Convert jobs to JSON
- nlohmann::json jobsJson = nlohmann::json::array();
- for (const auto& job : jobs) {
- std::string statusStr;
- switch (job.status) {
- case GenerationStatus::QUEUED: statusStr = "queued"; break;
- case GenerationStatus::PROCESSING: statusStr = "processing"; break;
- case GenerationStatus::COMPLETED: statusStr = "completed"; break;
- case GenerationStatus::FAILED: statusStr = "failed"; break;
- }
- // Convert time points to timestamps
- auto queuedTime = std::chrono::duration_cast<std::chrono::milliseconds>(
- job.queuedTime.time_since_epoch()).count();
- auto startTime = std::chrono::duration_cast<std::chrono::milliseconds>(
- job.startTime.time_since_epoch()).count();
- auto endTime = std::chrono::duration_cast<std::chrono::milliseconds>(
- job.endTime.time_since_epoch()).count();
- jobsJson.push_back({
- {"id", job.id},
- {"status", statusStr},
- {"prompt", job.prompt},
- {"queued_time", queuedTime},
- {"start_time", startTime > 0 ? nlohmann::json(startTime) : nlohmann::json(nullptr)},
- {"end_time", endTime > 0 ? nlohmann::json(endTime) : nlohmann::json(nullptr)},
- {"position", job.position},
- {"progress", job.progress}
- });
- }
- nlohmann::json response = {
- {"queue", {
- {"size", m_generationQueue->getQueueSize()},
- {"active_generations", m_generationQueue->getActiveGenerations()},
- {"running", m_generationQueue->isRunning()},
- {"jobs", jobsJson}
- }}
- };
- sendJsonResponse(res, response);
- } catch (const std::exception& e) {
- sendErrorResponse(res, std::string("Queue status check failed: ") + e.what(), 500);
- }
- }
- void Server::handleJobStatus(const httplib::Request& req, httplib::Response& res) {
- try {
- if (!m_generationQueue) {
- sendErrorResponse(res, "Generation queue not available", 500);
- return;
- }
- // Extract job ID from URL path
- std::string jobId = req.matches[1].str();
- if (jobId.empty()) {
- sendErrorResponse(res, "Missing job ID", 400);
- return;
- }
- // Get job information
- auto jobInfo = m_generationQueue->getJobInfo(jobId);
- if (jobInfo.id.empty()) {
- sendErrorResponse(res, "Job not found", 404);
- return;
- }
- // Convert status to string
- std::string statusStr;
- switch (jobInfo.status) {
- case GenerationStatus::QUEUED: statusStr = "queued"; break;
- case GenerationStatus::PROCESSING: statusStr = "processing"; break;
- case GenerationStatus::COMPLETED: statusStr = "completed"; break;
- case GenerationStatus::FAILED: statusStr = "failed"; break;
- }
- // Convert time points to timestamps
- auto queuedTime = std::chrono::duration_cast<std::chrono::milliseconds>(
- jobInfo.queuedTime.time_since_epoch()).count();
- auto startTime = std::chrono::duration_cast<std::chrono::milliseconds>(
- jobInfo.startTime.time_since_epoch()).count();
- auto endTime = std::chrono::duration_cast<std::chrono::milliseconds>(
- jobInfo.endTime.time_since_epoch()).count();
- // Create download URLs for output files
- nlohmann::json outputUrls = nlohmann::json::array();
- for (const auto& filePath : jobInfo.outputFiles) {
- // Extract filename from full path
- std::filesystem::path p(filePath);
- std::string filename = p.filename().string();
- // Create download URL
- std::string url = "/api/queue/job/" + jobInfo.id + "/output/" + filename;
- nlohmann::json fileInfo = {
- {"filename", filename},
- {"url", url},
- {"path", filePath}
- };
- outputUrls.push_back(fileInfo);
- }
- nlohmann::json response = {
- {"job", {
- {"id", jobInfo.id},
- {"status", statusStr},
- {"prompt", jobInfo.prompt},
- {"queued_time", queuedTime},
- {"start_time", startTime > 0 ? nlohmann::json(startTime) : nlohmann::json(nullptr)},
- {"end_time", endTime > 0 ? nlohmann::json(endTime) : nlohmann::json(nullptr)},
- {"position", jobInfo.position},
- {"outputs", outputUrls},
- {"error_message", jobInfo.errorMessage},
- {"progress", jobInfo.progress}
- }}
- };
- sendJsonResponse(res, response);
- } catch (const std::exception& e) {
- sendErrorResponse(res, std::string("Job status check failed: ") + e.what(), 500);
- }
- }
- void Server::handleCancelJob(const httplib::Request& req, httplib::Response& res) {
- try {
- if (!m_generationQueue) {
- sendErrorResponse(res, "Generation queue not available", 500);
- return;
- }
- // Parse JSON request body
- nlohmann::json requestJson = nlohmann::json::parse(req.body);
- // Validate required fields
- if (!requestJson.contains("job_id") || !requestJson["job_id"].is_string()) {
- sendErrorResponse(res, "Missing or invalid 'job_id' field", 400);
- return;
- }
- std::string jobId = requestJson["job_id"];
- // Try to cancel the job
- bool cancelled = m_generationQueue->cancelJob(jobId);
- if (cancelled) {
- nlohmann::json response = {
- {"status", "success"},
- {"message", "Job cancelled successfully"},
- {"job_id", jobId}
- };
- sendJsonResponse(res, response);
- } else {
- nlohmann::json response = {
- {"status", "error"},
- {"message", "Job not found or already processing"},
- {"job_id", jobId}
- };
- sendJsonResponse(res, response, 404);
- }
- } catch (const nlohmann::json::parse_error& e) {
- sendErrorResponse(res, std::string("Invalid JSON: ") + e.what(), 400);
- } catch (const std::exception& e) {
- sendErrorResponse(res, std::string("Job cancellation failed: ") + e.what(), 500);
- }
- }
- void Server::handleClearQueue(const httplib::Request& /*req*/, httplib::Response& res) {
- try {
- if (!m_generationQueue) {
- sendErrorResponse(res, "Generation queue not available", 500);
- return;
- }
- // Clear the queue
- m_generationQueue->clearQueue();
- nlohmann::json response = {
- {"status", "success"},
- {"message", "Queue cleared successfully"}
- };
- sendJsonResponse(res, response);
- } catch (const std::exception& e) {
- sendErrorResponse(res, std::string("Queue clear failed: ") + e.what(), 500);
- }
- }
- void Server::handleDownloadOutput(const httplib::Request& req, httplib::Response& res) {
- try {
- // Extract job ID and filename from URL path
- if (req.matches.size() < 3) {
- sendErrorResponse(res, "Invalid request: job ID and filename required", 400, "INVALID_REQUEST", "");
- return;
- }
- std::string jobId = req.matches[1];
- std::string filename = req.matches[2];
- // Validate inputs
- if (jobId.empty() || filename.empty()) {
- sendErrorResponse(res, "Job ID and filename cannot be empty", 400, "INVALID_PARAMETERS", "");
- return;
- }
- // Construct absolute file path using the same logic as when saving:
- // {outputDir}/{jobId}/{filename}
- std::string fullPath = std::filesystem::absolute(m_outputDir + "/" + jobId + "/" + filename).string();
- // Log the request for debugging
- std::cout << "Image download request: jobId=" << jobId << ", filename=" << filename
- << ", fullPath=" << fullPath << std::endl;
- // Check if file exists
- if (!std::filesystem::exists(fullPath)) {
- std::cerr << "Output file not found: " << fullPath << std::endl;
- sendErrorResponse(res, "Output file not found: " + filename, 404, "FILE_NOT_FOUND", "");
- return;
- }
- // Check file size to detect zero-byte files
- auto fileSize = std::filesystem::file_size(fullPath);
- if (fileSize == 0) {
- std::cerr << "Output file is zero bytes: " << fullPath << std::endl;
- sendErrorResponse(res, "Output file is empty (corrupted generation)", 500, "EMPTY_FILE", "");
- return;
- }
- // Check if file is accessible
- std::ifstream file(fullPath, std::ios::binary);
- if (!file.is_open()) {
- std::cerr << "Failed to open output file: " << fullPath << std::endl;
- sendErrorResponse(res, "Output file not accessible", 500, "FILE_ACCESS_ERROR", "");
- return;
- }
- // Read file contents
- std::string fileContent;
- try {
- fileContent = std::string(
- std::istreambuf_iterator<char>(file),
- std::istreambuf_iterator<char>()
- );
- file.close();
- } catch (const std::exception& e) {
- std::cerr << "Failed to read file content: " << e.what() << std::endl;
- sendErrorResponse(res, "Failed to read file content", 500, "FILE_READ_ERROR", "");
- return;
- }
- // Verify we actually read data
- if (fileContent.empty()) {
- std::cerr << "File content is empty after read: " << fullPath << std::endl;
- sendErrorResponse(res, "File content is empty after read", 500, "EMPTY_CONTENT", "");
- return;
- }
- // Determine content type based on file extension
- std::string contentType = "application/octet-stream";
- if (Utils::endsWith(filename, ".png")) {
- contentType = "image/png";
- } else if (Utils::endsWith(filename, ".jpg") || Utils::endsWith(filename, ".jpeg")) {
- contentType = "image/jpeg";
- } else if (Utils::endsWith(filename, ".mp4")) {
- contentType = "video/mp4";
- } else if (Utils::endsWith(filename, ".gif")) {
- contentType = "image/gif";
- } else if (Utils::endsWith(filename, ".webp")) {
- contentType = "image/webp";
- }
- // Set response headers for proper browser handling
- res.set_header("Content-Type", contentType);
- res.set_header("Content-Length", std::to_string(fileContent.length()));
- res.set_header("Cache-Control", "public, max-age=3600"); // Cache for 1 hour
- res.set_header("Access-Control-Allow-Origin", "*"); // CORS for image access
- // Uncomment if you want to force download instead of inline display:
- // res.set_header("Content-Disposition", "attachment; filename=\"" + filename + "\"");
- // Set the content
- res.set_content(fileContent, contentType);
- res.status = 200;
- std::cout << "Successfully served image: " << filename << " (" << fileContent.length() << " bytes)" << std::endl;
- } catch (const std::exception& e) {
- std::cerr << "Exception in handleDownloadOutput: " << e.what() << std::endl;
- sendErrorResponse(res, std::string("Failed to download file: ") + e.what(), 500, "DOWNLOAD_ERROR", "");
- }
- }
- void Server::handleJobOutput(const httplib::Request& req, httplib::Response& res) {
- std::string requestId = generateRequestId();
-
- try {
- // Extract job ID from URL path
- if (req.matches.size() < 2) {
- sendErrorResponse(res, "Invalid request: job ID required", 400, "INVALID_REQUEST", requestId);
- return;
- }
- std::string jobId = req.matches[1].str();
- // Validate job ID
- if (jobId.empty()) {
- sendErrorResponse(res, "Job ID cannot be empty", 400, "INVALID_PARAMETERS", requestId);
- return;
- }
- // Log the request for debugging
- std::cout << "Job output request: jobId=" << jobId << std::endl;
- // Get job information to check if it exists and is completed
- if (!m_generationQueue) {
- sendErrorResponse(res, "Generation queue not available", 500, "QUEUE_UNAVAILABLE", requestId);
- return;
- }
- auto jobInfo = m_generationQueue->getJobInfo(jobId);
- if (jobInfo.id.empty()) {
- sendErrorResponse(res, "Job not found", 404, "JOB_NOT_FOUND", requestId);
- return;
- }
- // Check if job is completed
- if (jobInfo.status != GenerationStatus::COMPLETED) {
- std::string statusStr;
- switch (jobInfo.status) {
- case GenerationStatus::QUEUED: statusStr = "queued"; break;
- case GenerationStatus::PROCESSING: statusStr = "processing"; break;
- case GenerationStatus::FAILED: statusStr = "failed"; break;
- default: statusStr = "unknown"; break;
- }
-
- nlohmann::json response = {
- {"error", {
- {"message", "Job not completed yet"},
- {"status_code", 400},
- {"error_code", "JOB_NOT_COMPLETED"},
- {"request_id", requestId},
- {"timestamp", std::chrono::duration_cast<std::chrono::seconds>(
- std::chrono::system_clock::now().time_since_epoch()).count()},
- {"job_status", statusStr}
- }}
- };
- sendJsonResponse(res, response, 400);
- return;
- }
- // Check if job has output files
- if (jobInfo.outputFiles.empty()) {
- sendErrorResponse(res, "No output files found for completed job", 404, "NO_OUTPUT_FILES", requestId);
- return;
- }
- // For simplicity, return the first output file
- // In a more complex implementation, we could return all files or allow file selection
- std::string firstOutputFile = jobInfo.outputFiles[0];
-
- // Extract filename from full path
- std::filesystem::path filePath(firstOutputFile);
- std::string filename = filePath.filename().string();
-
- // Construct absolute file path
- std::string fullPath = std::filesystem::absolute(firstOutputFile).string();
- // Check if file exists
- if (!std::filesystem::exists(fullPath)) {
- std::cerr << "Output file not found: " << fullPath << std::endl;
- sendErrorResponse(res, "Output file not found: " + filename, 404, "FILE_NOT_FOUND", requestId);
- return;
- }
- // Check file size to detect zero-byte files
- auto fileSize = std::filesystem::file_size(fullPath);
- if (fileSize == 0) {
- std::cerr << "Output file is zero bytes: " << fullPath << std::endl;
- sendErrorResponse(res, "Output file is empty (corrupted generation)", 500, "EMPTY_FILE", requestId);
- return;
- }
- // Check if file is accessible
- std::ifstream file(fullPath, std::ios::binary);
- if (!file.is_open()) {
- std::cerr << "Failed to open output file: " << fullPath << std::endl;
- sendErrorResponse(res, "Output file not accessible", 500, "FILE_ACCESS_ERROR", requestId);
- return;
- }
- // Read file contents
- std::string fileContent;
- try {
- fileContent = std::string(
- std::istreambuf_iterator<char>(file),
- std::istreambuf_iterator<char>()
- );
- file.close();
- } catch (const std::exception& e) {
- std::cerr << "Failed to read file content: " << e.what() << std::endl;
- sendErrorResponse(res, "Failed to read file content", 500, "FILE_READ_ERROR", requestId);
- return;
- }
- // Verify we actually read data
- if (fileContent.empty()) {
- std::cerr << "File content is empty after read: " << fullPath << std::endl;
- sendErrorResponse(res, "File content is empty after read", 500, "EMPTY_CONTENT", requestId);
- return;
- }
- // Determine content type based on file extension
- std::string contentType = "application/octet-stream";
- if (Utils::endsWith(filename, ".png")) {
- contentType = "image/png";
- } else if (Utils::endsWith(filename, ".jpg") || Utils::endsWith(filename, ".jpeg")) {
- contentType = "image/jpeg";
- } else if (Utils::endsWith(filename, ".mp4")) {
- contentType = "video/mp4";
- } else if (Utils::endsWith(filename, ".gif")) {
- contentType = "image/gif";
- } else if (Utils::endsWith(filename, ".webp")) {
- contentType = "image/webp";
- }
- // Set response headers for proper browser handling
- res.set_header("Content-Type", contentType);
- res.set_header("Content-Length", std::to_string(fileContent.length()));
- res.set_header("Cache-Control", "public, max-age=3600"); // Cache for 1 hour
- res.set_header("Access-Control-Allow-Origin", "*"); // CORS for image access
-
- // Set additional metadata headers
- res.set_header("X-Job-ID", jobId);
- res.set_header("X-Filename", filename);
- res.set_header("X-File-Size", std::to_string(fileSize));
-
- // If there are multiple files, indicate this
- if (jobInfo.outputFiles.size() > 1) {
- res.set_header("X-Total-Files", std::to_string(jobInfo.outputFiles.size()));
- res.set_header("X-File-Index", "1");
- }
- // Set the content
- res.set_content(fileContent, contentType);
- res.status = 200;
- std::cout << "Successfully served job output: jobId=" << jobId
- << ", filename=" << filename
- << " (" << fileContent.length() << " bytes)" << std::endl;
- } catch (const std::exception& e) {
- std::cerr << "Exception in handleJobOutput: " << e.what() << std::endl;
- sendErrorResponse(res, std::string("Failed to get job output: ") + e.what(), 500, "OUTPUT_ERROR", requestId);
- }
- }
- void Server::handleImageResize(const httplib::Request& req, httplib::Response& res) {
- std::string requestId = generateRequestId();
- try {
- // Parse JSON request body
- nlohmann::json requestJson = nlohmann::json::parse(req.body);
- // Validate required fields
- if (!requestJson.contains("image") || !requestJson["image"].is_string()) {
- sendErrorResponse(res, "Missing or invalid 'image' field", 400, "INVALID_PARAMETERS", requestId);
- return;
- }
- if (!requestJson.contains("width") || !requestJson["width"].is_number_integer()) {
- sendErrorResponse(res, "Missing or invalid 'width' field", 400, "INVALID_PARAMETERS", requestId);
- return;
- }
- if (!requestJson.contains("height") || !requestJson["height"].is_number_integer()) {
- sendErrorResponse(res, "Missing or invalid 'height' field", 400, "INVALID_PARAMETERS", requestId);
- return;
- }
- std::string imageInput = requestJson["image"];
- int targetWidth = requestJson["width"];
- int targetHeight = requestJson["height"];
- // Validate dimensions
- if (targetWidth < 1 || targetWidth > 4096) {
- sendErrorResponse(res, "Width must be between 1 and 4096", 400, "INVALID_DIMENSIONS", requestId);
- return;
- }
- if (targetHeight < 1 || targetHeight > 4096) {
- sendErrorResponse(res, "Height must be between 1 and 4096", 400, "INVALID_DIMENSIONS", requestId);
- return;
- }
- // Load the source image
- auto [imageData, sourceWidth, sourceHeight, sourceChannels, success, loadError] = loadImageFromInput(imageInput);
- if (!success) {
- sendErrorResponse(res, "Failed to load image: " + loadError, 400, "IMAGE_LOAD_ERROR", requestId);
- return;
- }
- // Convert image data to stb_image format for processing
- int channels = 3; // Force RGB
- size_t sourceSize = sourceWidth * sourceHeight * channels;
- std::vector<uint8_t> sourcePixels(sourceSize);
- std::memcpy(sourcePixels.data(), imageData.data(), std::min(imageData.size(), sourceSize));
- // Resize the image using stb_image_resize if available, otherwise use simple scaling
- std::vector<uint8_t> resizedPixels(targetWidth * targetHeight * channels);
- // Simple nearest-neighbor scaling for now (can be improved with better algorithms)
- float xScale = static_cast<float>(sourceWidth) / targetWidth;
- float yScale = static_cast<float>(sourceHeight) / targetHeight;
- for (int y = 0; y < targetHeight; y++) {
- for (int x = 0; x < targetWidth; x++) {
- int sourceX = static_cast<int>(x * xScale);
- int sourceY = static_cast<int>(y * yScale);
- // Clamp to source bounds
- sourceX = std::min(sourceX, sourceWidth - 1);
- sourceY = std::min(sourceY, sourceHeight - 1);
- for (int c = 0; c < channels; c++) {
- resizedPixels[(y * targetWidth + x) * channels + c] =
- sourcePixels[(sourceY * sourceWidth + sourceX) * channels + c];
- }
- }
- }
- // Convert resized image to base64
- std::string base64Data = Utils::base64Encode(resizedPixels);
- // Determine MIME type based on input
- std::string mimeType = "image/jpeg"; // default
- if (Utils::startsWith(imageInput, "data:image/png")) {
- mimeType = "image/png";
- } else if (Utils::startsWith(imageInput, "data:image/gif")) {
- mimeType = "image/gif";
- } else if (Utils::startsWith(imageInput, "data:image/webp")) {
- mimeType = "image/webp";
- } else if (Utils::startsWith(imageInput, "data:image/bmp")) {
- mimeType = "image/bmp";
- }
- // Create data URL format
- std::string dataUrl = "data:" + mimeType + ";base64," + base64Data;
- // Build response
- nlohmann::json response = {
- {"success", true},
- {"original_width", sourceWidth},
- {"original_height", sourceHeight},
- {"resized_width", targetWidth},
- {"resized_height", targetHeight},
- {"mime_type", mimeType},
- {"base64_data", dataUrl},
- {"file_size_bytes", resizedPixels.size()},
- {"request_id", requestId}
- };
- sendJsonResponse(res, response, 200);
- std::cout << "Successfully resized image from " << sourceWidth << "x" << sourceHeight
- << " to " << targetWidth << "x" << targetHeight
- << " (" << resizedPixels.size() << " bytes)" << std::endl;
- } catch (const nlohmann::json::parse_error& e) {
- sendErrorResponse(res, std::string("Invalid JSON: ") + e.what(), 400, "JSON_PARSE_ERROR", requestId);
- } catch (const std::exception& e) {
- std::cerr << "Exception in handleImageResize: " << e.what() << std::endl;
- sendErrorResponse(res, std::string("Failed to resize image: ") + e.what(), 500, "INTERNAL_ERROR", requestId);
- }
- }
- void Server::handleImageCrop(const httplib::Request& req, httplib::Response& res) {
- std::string requestId = generateRequestId();
- try {
- // Parse JSON request body
- nlohmann::json requestJson = nlohmann::json::parse(req.body);
- // Validate required fields
- if (!requestJson.contains("image") || !requestJson["image"].is_string()) {
- sendErrorResponse(res, "Missing or invalid 'image' field", 400, "INVALID_PARAMETERS", requestId);
- return;
- }
- if (!requestJson.contains("x") || !requestJson["x"].is_number_integer()) {
- sendErrorResponse(res, "Missing or invalid 'x' field", 400, "INVALID_PARAMETERS", requestId);
- return;
- }
- if (!requestJson.contains("y") || !requestJson["y"].is_number_integer()) {
- sendErrorResponse(res, "Missing or invalid 'y' field", 400, "INVALID_PARAMETERS", requestId);
- return;
- }
- if (!requestJson.contains("width") || !requestJson["width"].is_number_integer()) {
- sendErrorResponse(res, "Missing or invalid 'width' field", 400, "INVALID_PARAMETERS", requestId);
- return;
- }
- if (!requestJson.contains("height") || !requestJson["height"].is_number_integer()) {
- sendErrorResponse(res, "Missing or invalid 'height' field", 400, "INVALID_PARAMETERS", requestId);
- return;
- }
- std::string imageInput = requestJson["image"];
- int cropX = requestJson["x"];
- int cropY = requestJson["y"];
- int cropWidth = requestJson["width"];
- int cropHeight = requestJson["height"];
- // Load the source image
- auto [imageData, sourceWidth, sourceHeight, sourceChannels, success, loadError] = loadImageFromInput(imageInput);
- if (!success) {
- sendErrorResponse(res, "Failed to load image: " + loadError, 400, "IMAGE_LOAD_ERROR", requestId);
- return;
- }
- // Validate crop dimensions
- if (cropX < 0 || cropY < 0) {
- sendErrorResponse(res, "Crop coordinates must be non-negative", 400, "INVALID_CROP_AREA", requestId);
- return;
- }
- if (cropX + cropWidth > sourceWidth || cropY + cropHeight > sourceHeight) {
- sendErrorResponse(res, "Crop area exceeds image dimensions", 400, "INVALID_CROP_AREA", requestId);
- return;
- }
- if (cropWidth < 1 || cropHeight < 1) {
- sendErrorResponse(res, "Crop width and height must be at least 1", 400, "INVALID_CROP_AREA", requestId);
- return;
- }
- // Convert image data to stb_image format for processing
- int channels = 3; // Force RGB
- size_t sourceSize = sourceWidth * sourceHeight * channels;
- std::vector<uint8_t> sourcePixels(sourceSize);
- std::memcpy(sourcePixels.data(), imageData.data(), std::min(imageData.size(), sourceSize));
- // Crop the image
- std::vector<uint8_t> croppedPixels(cropWidth * cropHeight * channels);
- for (int y = 0; y < cropHeight; y++) {
- for (int x = 0; x < cropWidth; x++) {
- int sourceX = cropX + x;
- int sourceY = cropY + y;
- for (int c = 0; c < channels; c++) {
- croppedPixels[(y * cropWidth + x) * channels + c] =
- sourcePixels[(sourceY * sourceWidth + sourceX) * channels + c];
- }
- }
- }
- // Convert cropped image to base64
- std::string base64Data = Utils::base64Encode(croppedPixels);
- // Determine MIME type based on input
- std::string mimeType = "image/jpeg"; // default
- if (Utils::startsWith(imageInput, "data:image/png")) {
- mimeType = "image/png";
- } else if (Utils::startsWith(imageInput, "data:image/gif")) {
- mimeType = "image/gif";
- } else if (Utils::startsWith(imageInput, "data:image/webp")) {
- mimeType = "image/webp";
- } else if (Utils::startsWith(imageInput, "data:image/bmp")) {
- mimeType = "image/bmp";
- }
- // Create data URL format
- std::string dataUrl = "data:" + mimeType + ";base64," + base64Data;
- // Build response
- nlohmann::json response = {
- {"success", true},
- {"original_width", sourceWidth},
- {"original_height", sourceHeight},
- {"crop_x", cropX},
- {"crop_y", cropY},
- {"cropped_width", cropWidth},
- {"cropped_height", cropHeight},
- {"mime_type", mimeType},
- {"base64_data", dataUrl},
- {"file_size_bytes", croppedPixels.size()},
- {"request_id", requestId}
- };
- sendJsonResponse(res, response, 200);
- std::cout << "Successfully cropped image from " << sourceWidth << "x" << sourceHeight
- << " to " << cropWidth << "x" << cropHeight
- << " at (" << cropX << "," << cropY << ")"
- << " (" << croppedPixels.size() << " bytes)" << std::endl;
- } catch (const nlohmann::json::parse_error& e) {
- sendErrorResponse(res, std::string("Invalid JSON: ") + e.what(), 400, "JSON_PARSE_ERROR", requestId);
- } catch (const std::exception& e) {
- std::cerr << "Exception in handleImageCrop: " << e.what() << std::endl;
- sendErrorResponse(res, std::string("Failed to crop image: ") + e.what(), 500, "INTERNAL_ERROR", requestId);
- }
- }
- void Server::handleDownloadImageFromUrl(const httplib::Request& req, httplib::Response& res) {
- std::string requestId = generateRequestId();
- try {
- // Parse query parameters
- std::string imageUrl = req.get_param_value("url");
- if (imageUrl.empty()) {
- sendErrorResponse(res, "Missing 'url' parameter", 400, "MISSING_URL", requestId);
- return;
- }
- // Basic URL format validation
- if (!Utils::startsWith(imageUrl, "http://") && !Utils::startsWith(imageUrl, "https://")) {
- sendErrorResponse(res, "Invalid URL format. URL must start with http:// or https://", 400, "INVALID_URL_FORMAT", requestId);
- return;
- }
- // Extract filename from URL for content type detection
- std::string filename = imageUrl;
- size_t lastSlash = imageUrl.find_last_of('/');
- if (lastSlash != std::string::npos) {
- filename = imageUrl.substr(lastSlash + 1);
- }
- // Remove query parameters and fragments
- size_t questionMark = filename.find('?');
- if (questionMark != std::string::npos) {
- filename = filename.substr(0, questionMark);
- }
- size_t hashMark = filename.find('#');
- if (hashMark != std::string::npos) {
- filename = filename.substr(0, hashMark);
- }
- // Check if URL has image extension
- std::string extension;
- size_t lastDot = filename.find_last_of('.');
- if (lastDot != std::string::npos) {
- extension = filename.substr(lastDot + 1);
- std::transform(extension.begin(), extension.end(), extension.begin(), ::tolower);
- }
- // Validate image extension
- const std::vector<std::string> validExtensions = {"jpg", "jpeg", "png", "gif", "webp", "bmp"};
- if (extension.empty() || std::find(validExtensions.begin(), validExtensions.end(), extension) == validExtensions.end()) {
- sendErrorResponse(res, "URL must point to an image file with a valid extension: " +
- std::accumulate(validExtensions.begin(), validExtensions.end(), std::string(),
- [](const std::string& a, const std::string& b) {
- return a.empty() ? b : a + ", " + b;
- }), 400, "INVALID_IMAGE_EXTENSION", requestId);
- return;
- }
- // Load image using existing loadImageFromInput function
- auto [imageData, width, height, channels, success, error] = loadImageFromInput(imageUrl);
- if (!success) {
- sendErrorResponse(res, "Failed to download image from URL: " + error, 400, "IMAGE_DOWNLOAD_FAILED", requestId);
- return;
- }
- // Convert image data to base64
- std::string base64Data = Utils::base64Encode(imageData);
- // Determine MIME type based on extension
- std::string mimeType = "image/jpeg"; // default
- if (extension == "png") {
- mimeType = "image/png";
- } else if (extension == "gif") {
- mimeType = "image/gif";
- } else if (extension == "webp") {
- mimeType = "image/webp";
- } else if (extension == "bmp") {
- mimeType = "image/bmp";
- } else if (extension == "jpg" || extension == "jpeg") {
- mimeType = "image/jpeg";
- }
- // Create data URL format
- std::string dataUrl = "data:" + mimeType + ";base64," + base64Data;
- // Build response
- nlohmann::json response = {
- {"success", true},
- {"url", imageUrl},
- {"filename", filename},
- {"width", width},
- {"height", height},
- {"channels", channels},
- {"mime_type", mimeType},
- {"base64_data", dataUrl},
- {"file_size_bytes", imageData.size()},
- {"request_id", requestId}
- };
- sendJsonResponse(res, response, 200);
- std::cout << "Successfully downloaded and encoded image from URL: " << imageUrl
- << " (" << width << "x" << height << ", " << imageData.size() << " bytes)" << std::endl;
- } catch (const nlohmann::json::parse_error& e) {
- sendErrorResponse(res, std::string("Invalid JSON: ") + e.what(), 400, "JSON_PARSE_ERROR", requestId);
- } catch (const std::exception& e) {
- std::cerr << "Exception in handleDownloadImageFromUrl: " << e.what() << std::endl;
- sendErrorResponse(res, std::string("Failed to download image from URL: ") + e.what(), 500, "INTERNAL_ERROR", requestId);
- }
- }
- void Server::sendJsonResponse(httplib::Response& res, const nlohmann::json& json, int status_code) {
- res.set_header("Content-Type", "application/json");
- res.status = status_code;
- res.body = json.dump();
- }
- void Server::sendErrorResponse(httplib::Response& res, const std::string& message, int status_code,
- const std::string& error_code, const std::string& request_id) {
- nlohmann::json errorResponse = {
- {"error", {
- {"message", message},
- {"status_code", status_code},
- {"error_code", error_code},
- {"request_id", request_id},
- {"timestamp", std::chrono::duration_cast<std::chrono::seconds>(
- std::chrono::system_clock::now().time_since_epoch()).count()}
- }}
- };
- sendJsonResponse(res, errorResponse, status_code);
- }
- std::pair<bool, std::string> Server::validateGenerationParameters(const nlohmann::json& params) {
- // Validate required fields
- if (!params.contains("prompt") || !params["prompt"].is_string()) {
- return {false, "Missing or invalid 'prompt' field"};
- }
- const std::string& prompt = params["prompt"];
- if (prompt.empty()) {
- return {false, "Prompt cannot be empty"};
- }
- if (prompt.length() > m_config.maxPromptLength) {
- return {false, "Prompt too long (max " + std::to_string(m_config.maxPromptLength) + " characters)"};
- }
- // Validate negative prompt if present
- if (params.contains("negative_prompt")) {
- if (!params["negative_prompt"].is_string()) {
- return {false, "Invalid 'negative_prompt' field, must be string"};
- }
- if (params["negative_prompt"].get<std::string>().length() > m_config.maxNegativePromptLength) {
- return {false, "Negative prompt too long (max " + std::to_string(m_config.maxNegativePromptLength) + " characters)"};
- }
- }
- // Validate width
- if (params.contains("width")) {
- if (!params["width"].is_number_integer()) {
- return {false, "Invalid 'width' field, must be integer"};
- }
- int width = params["width"];
- if (width < 64 || width > 2048 || width % 64 != 0) {
- return {false, "Width must be between 64 and 2048 and divisible by 64"};
- }
- }
- // Validate height
- if (params.contains("height")) {
- if (!params["height"].is_number_integer()) {
- return {false, "Invalid 'height' field, must be integer"};
- }
- int height = params["height"];
- if (height < 64 || height > 2048 || height % 64 != 0) {
- return {false, "Height must be between 64 and 2048 and divisible by 64"};
- }
- }
- // Validate batch count
- if (params.contains("batch_count")) {
- if (!params["batch_count"].is_number_integer()) {
- return {false, "Invalid 'batch_count' field, must be integer"};
- }
- int batchCount = params["batch_count"];
- if (batchCount < 1 || batchCount > 100) {
- return {false, "Batch count must be between 1 and 100"};
- }
- }
- // Validate steps
- if (params.contains("steps")) {
- if (!params["steps"].is_number_integer()) {
- return {false, "Invalid 'steps' field, must be integer"};
- }
- int steps = params["steps"];
- if (steps < 1 || steps > 150) {
- return {false, "Steps must be between 1 and 150"};
- }
- }
- // Validate CFG scale
- if (params.contains("cfg_scale")) {
- if (!params["cfg_scale"].is_number()) {
- return {false, "Invalid 'cfg_scale' field, must be number"};
- }
- float cfgScale = params["cfg_scale"];
- if (cfgScale < 1.0f || cfgScale > 30.0f) {
- return {false, "CFG scale must be between 1.0 and 30.0"};
- }
- }
- // Validate seed
- if (params.contains("seed")) {
- if (!params["seed"].is_string() && !params["seed"].is_number_integer()) {
- return {false, "Invalid 'seed' field, must be string or integer"};
- }
- }
- // Validate sampling method
- if (params.contains("sampling_method")) {
- if (!params["sampling_method"].is_string()) {
- return {false, "Invalid 'sampling_method' field, must be string"};
- }
- std::string method = params["sampling_method"];
- std::vector<std::string> validMethods = {
- "euler", "euler_a", "heun", "dpm2", "dpm++2s_a", "dpm++2m",
- "dpm++2mv2", "ipndm", "ipndm_v", "lcm", "ddim_trailing", "tcd", "default"
- };
- if (std::find(validMethods.begin(), validMethods.end(), method) == validMethods.end()) {
- return {false, "Invalid sampling method"};
- }
- }
- // Validate scheduler
- if (params.contains("scheduler")) {
- if (!params["scheduler"].is_string()) {
- return {false, "Invalid 'scheduler' field, must be string"};
- }
- std::string scheduler = params["scheduler"];
- std::vector<std::string> validSchedulers = {
- "discrete", "karras", "exponential", "ays", "gits",
- "smoothstep", "sgm_uniform", "simple", "default"
- };
- if (std::find(validSchedulers.begin(), validSchedulers.end(), scheduler) == validSchedulers.end()) {
- return {false, "Invalid scheduler"};
- }
- }
- // Validate strength
- if (params.contains("strength")) {
- if (!params["strength"].is_number()) {
- return {false, "Invalid 'strength' field, must be number"};
- }
- float strength = params["strength"];
- if (strength < 0.0f || strength > 1.0f) {
- return {false, "Strength must be between 0.0 and 1.0"};
- }
- }
- // Validate control strength
- if (params.contains("control_strength")) {
- if (!params["control_strength"].is_number()) {
- return {false, "Invalid 'control_strength' field, must be number"};
- }
- float controlStrength = params["control_strength"];
- if (controlStrength < 0.0f || controlStrength > 1.0f) {
- return {false, "Control strength must be between 0.0 and 1.0"};
- }
- }
- // Validate clip skip
- if (params.contains("clip_skip")) {
- if (!params["clip_skip"].is_number_integer()) {
- return {false, "Invalid 'clip_skip' field, must be integer"};
- }
- int clipSkip = params["clip_skip"];
- if (clipSkip < -1 || clipSkip > 12) {
- return {false, "Clip skip must be between -1 and 12"};
- }
- }
- // Validate threads
- if (params.contains("threads")) {
- if (!params["threads"].is_number_integer()) {
- return {false, "Invalid 'threads' field, must be integer"};
- }
- int threads = params["threads"];
- if (threads < -1 || threads > 32) {
- return {false, "Threads must be between -1 (auto) and 32"};
- }
- }
- return {true, ""};
- }
- SamplingMethod Server::parseSamplingMethod(const std::string& method) {
- if (method == "euler") return SamplingMethod::EULER;
- else if (method == "euler_a") return SamplingMethod::EULER_A;
- else if (method == "heun") return SamplingMethod::HEUN;
- else if (method == "dpm2") return SamplingMethod::DPM2;
- else if (method == "dpm++2s_a") return SamplingMethod::DPMPP2S_A;
- else if (method == "dpm++2m") return SamplingMethod::DPMPP2M;
- else if (method == "dpm++2mv2") return SamplingMethod::DPMPP2MV2;
- else if (method == "ipndm") return SamplingMethod::IPNDM;
- else if (method == "ipndm_v") return SamplingMethod::IPNDM_V;
- else if (method == "lcm") return SamplingMethod::LCM;
- else if (method == "ddim_trailing") return SamplingMethod::DDIM_TRAILING;
- else if (method == "tcd") return SamplingMethod::TCD;
- else return SamplingMethod::DEFAULT;
- }
- Scheduler Server::parseScheduler(const std::string& scheduler) {
- if (scheduler == "discrete") return Scheduler::DISCRETE;
- else if (scheduler == "karras") return Scheduler::KARRAS;
- else if (scheduler == "exponential") return Scheduler::EXPONENTIAL;
- else if (scheduler == "ays") return Scheduler::AYS;
- else if (scheduler == "gits") return Scheduler::GITS;
- else if (scheduler == "smoothstep") return Scheduler::SMOOTHSTEP;
- else if (scheduler == "sgm_uniform") return Scheduler::SGM_UNIFORM;
- else if (scheduler == "simple") return Scheduler::SIMPLE;
- else return Scheduler::DEFAULT;
- }
- std::string Server::generateRequestId() {
- std::random_device rd;
- std::mt19937 gen(rd());
- std::uniform_int_distribution<> dis(100000, 999999);
- return "req_" + std::to_string(dis(gen));
- }
- std::tuple<std::vector<uint8_t>, int, int, int, bool, std::string>
- Server::loadImageFromInput(const std::string& input) {
- std::vector<uint8_t> imageData;
- int width = 0, height = 0, channels = 0;
- // Auto-detect input source type
- // 1. Check if input is a URL (starts with http:// or https://)
- if (Utils::startsWith(input, "http://") || Utils::startsWith(input, "https://")) {
- // Parse URL to extract host and path
- std::string url = input;
- std::string scheme, host, path;
- int port = 80;
- // Determine scheme and port
- if (Utils::startsWith(url, "https://")) {
- scheme = "https";
- port = 443;
- url = url.substr(8); // Remove "https://"
- } else {
- scheme = "http";
- port = 80;
- url = url.substr(7); // Remove "http://"
- }
- // Extract host and path
- size_t slashPos = url.find('/');
- if (slashPos != std::string::npos) {
- host = url.substr(0, slashPos);
- path = url.substr(slashPos);
- } else {
- host = url;
- path = "/";
- }
- // Check for custom port
- size_t colonPos = host.find(':');
- if (colonPos != std::string::npos) {
- try {
- port = std::stoi(host.substr(colonPos + 1));
- host = host.substr(0, colonPos);
- } catch (...) {
- return {imageData, 0, 0, 0, false, "Invalid port in URL"};
- }
- }
- // Download image using httplib
- try {
- httplib::Result res;
- if (scheme == "https") {
- #ifdef CPPHTTPLIB_OPENSSL_SUPPORT
- httplib::SSLClient client(host, port);
- client.set_follow_location(true);
- client.set_connection_timeout(30, 0); // 30 seconds
- client.set_read_timeout(60, 0); // 60 seconds
- res = client.Get(path.c_str());
- #else
- return {imageData, 0, 0, 0, false, "HTTPS not supported (OpenSSL not available)"};
- #endif
- } else {
- httplib::Client client(host, port);
- client.set_follow_location(true);
- client.set_connection_timeout(30, 0); // 30 seconds
- client.set_read_timeout(60, 0); // 60 seconds
- res = client.Get(path.c_str());
- }
- if (!res) {
- return {imageData, 0, 0, 0, false, "Failed to download image from URL: Connection error"};
- }
- if (res->status != 200) {
- return {imageData, 0, 0, 0, false, "Failed to download image from URL: HTTP " + std::to_string(res->status)};
- }
- // Convert response body to vector
- std::vector<uint8_t> downloadedData(res->body.begin(), res->body.end());
- // Load image from memory
- int w, h, c;
- unsigned char* pixels = stbi_load_from_memory(
- downloadedData.data(),
- downloadedData.size(),
- &w, &h, &c,
- 3 // Force RGB
- );
- if (!pixels) {
- return {imageData, 0, 0, 0, false, "Failed to decode image from URL"};
- }
- width = w;
- height = h;
- channels = 3;
- size_t dataSize = width * height * channels;
- imageData.resize(dataSize);
- std::memcpy(imageData.data(), pixels, dataSize);
- stbi_image_free(pixels);
- } catch (const std::exception& e) {
- return {imageData, 0, 0, 0, false, "Failed to download image from URL: " + std::string(e.what())};
- }
- }
- // 2. Check if input is base64 encoded data URI (starts with "data:image")
- else if (Utils::startsWith(input, "data:image")) {
- // Extract base64 data after the comma
- size_t commaPos = input.find(',');
- if (commaPos == std::string::npos) {
- return {imageData, 0, 0, 0, false, "Invalid data URI format"};
- }
- std::string base64Data = input.substr(commaPos + 1);
- std::vector<uint8_t> decodedData = Utils::base64Decode(base64Data);
- // Load image from memory using stb_image
- int w, h, c;
- unsigned char* pixels = stbi_load_from_memory(
- decodedData.data(),
- decodedData.size(),
- &w, &h, &c,
- 3 // Force RGB
- );
- if (!pixels) {
- return {imageData, 0, 0, 0, false, "Failed to decode image from base64 data URI"};
- }
- width = w;
- height = h;
- channels = 3; // We forced RGB
- // Copy pixel data
- size_t dataSize = width * height * channels;
- imageData.resize(dataSize);
- std::memcpy(imageData.data(), pixels, dataSize);
- stbi_image_free(pixels);
- }
- // 3. Check if input is raw base64 (long string without slashes, likely base64)
- else if (input.length() > 100 && input.find('/') == std::string::npos && input.find('.') == std::string::npos) {
- // Likely raw base64 without data URI prefix
- std::vector<uint8_t> decodedData = Utils::base64Decode(input);
- int w, h, c;
- unsigned char* pixels = stbi_load_from_memory(
- decodedData.data(),
- decodedData.size(),
- &w, &h, &c,
- 3 // Force RGB
- );
- if (!pixels) {
- return {imageData, 0, 0, 0, false, "Failed to decode image from base64"};
- }
- width = w;
- height = h;
- channels = 3;
- size_t dataSize = width * height * channels;
- imageData.resize(dataSize);
- std::memcpy(imageData.data(), pixels, dataSize);
- stbi_image_free(pixels);
- }
- // 4. Treat as local file path
- else {
- int w, h, c;
- unsigned char* pixels = stbi_load(input.c_str(), &w, &h, &c, 3);
- if (!pixels) {
- return {imageData, 0, 0, 0, false, "Failed to load image from file: " + input};
- }
- width = w;
- height = h;
- channels = 3;
- size_t dataSize = width * height * channels;
- imageData.resize(dataSize);
- std::memcpy(imageData.data(), pixels, dataSize);
- stbi_image_free(pixels);
- }
- return {imageData, width, height, channels, true, ""};
- }
- std::string Server::samplingMethodToString(SamplingMethod method) {
- switch (method) {
- case SamplingMethod::EULER: return "euler";
- case SamplingMethod::EULER_A: return "euler_a";
- case SamplingMethod::HEUN: return "heun";
- case SamplingMethod::DPM2: return "dpm2";
- case SamplingMethod::DPMPP2S_A: return "dpm++2s_a";
- case SamplingMethod::DPMPP2M: return "dpm++2m";
- case SamplingMethod::DPMPP2MV2: return "dpm++2mv2";
- case SamplingMethod::IPNDM: return "ipndm";
- case SamplingMethod::IPNDM_V: return "ipndm_v";
- case SamplingMethod::LCM: return "lcm";
- case SamplingMethod::DDIM_TRAILING: return "ddim_trailing";
- case SamplingMethod::TCD: return "tcd";
- default: return "default";
- }
- }
- std::string Server::schedulerToString(Scheduler scheduler) {
- switch (scheduler) {
- case Scheduler::DISCRETE: return "discrete";
- case Scheduler::KARRAS: return "karras";
- case Scheduler::EXPONENTIAL: return "exponential";
- case Scheduler::AYS: return "ays";
- case Scheduler::GITS: return "gits";
- case Scheduler::SMOOTHSTEP: return "smoothstep";
- case Scheduler::SGM_UNIFORM: return "sgm_uniform";
- case Scheduler::SIMPLE: return "simple";
- default: return "default";
- }
- }
- uint64_t Server::estimateGenerationTime(const GenerationRequest& request) {
- // Basic estimation based on parameters
- uint64_t baseTime = 1000; // 1 second base time
- // Factor in steps
- baseTime *= request.steps;
- // Factor in resolution
- double resolutionFactor = (request.width * request.height) / (512.0 * 512.0);
- baseTime = static_cast<uint64_t>(baseTime * resolutionFactor);
- // Factor in batch count
- baseTime *= request.batchCount;
- // Adjust for sampling method (some are faster than others)
- switch (request.samplingMethod) {
- case SamplingMethod::LCM:
- baseTime /= 4; // LCM is much faster
- break;
- case SamplingMethod::EULER:
- case SamplingMethod::EULER_A:
- baseTime *= 0.8; // Euler methods are faster
- break;
- case SamplingMethod::DPM2:
- case SamplingMethod::DPMPP2S_A:
- baseTime *= 1.2; // DPM methods are slower
- break;
- default:
- break;
- }
- return baseTime;
- }
- size_t Server::estimateMemoryUsage(const GenerationRequest& request) {
- // Basic memory estimation in bytes
- size_t baseMemory = 1024 * 1024 * 1024; // 1GB base
- // Factor in resolution
- double resolutionFactor = (request.width * request.height) / (512.0 * 512.0);
- baseMemory = static_cast<size_t>(baseMemory * resolutionFactor);
- // Factor in batch count
- baseMemory *= request.batchCount;
- // Additional memory for certain features
- if (request.diffusionFlashAttn) {
- baseMemory += 512 * 1024 * 1024; // Extra 512MB for flash attention
- }
- if (!request.controlNetPath.empty()) {
- baseMemory += 1024 * 1024 * 1024; // Extra 1GB for ControlNet
- }
- return baseMemory;
- }
- // Specialized generation endpoints
- void Server::handleText2Img(const httplib::Request& req, httplib::Response& res) {
- std::string requestId = generateRequestId();
- try {
- if (!m_generationQueue) {
- sendErrorResponse(res, "Generation queue not available", 500, "QUEUE_UNAVAILABLE", requestId);
- return;
- }
- nlohmann::json requestJson = nlohmann::json::parse(req.body);
- // Validate required fields for text2img
- if (!requestJson.contains("prompt") || !requestJson["prompt"].is_string()) {
- sendErrorResponse(res, "Missing or invalid 'prompt' field", 400, "INVALID_PARAMETERS", requestId);
- return;
- }
- // Validate all parameters
- auto [isValid, errorMessage] = validateGenerationParameters(requestJson);
- if (!isValid) {
- sendErrorResponse(res, errorMessage, 400, "INVALID_PARAMETERS", requestId);
- return;
- }
- // Check if any model is loaded
- if (!m_modelManager) {
- sendErrorResponse(res, "Model manager not available", 500, "MODEL_MANAGER_UNAVAILABLE", requestId);
- return;
- }
- // Get currently loaded checkpoint model
- auto allModels = m_modelManager->getAllModels();
- std::string loadedModelName;
- for (const auto& [modelName, modelInfo] : allModels) {
- if (modelInfo.type == ModelType::CHECKPOINT && modelInfo.isLoaded) {
- loadedModelName = modelName;
- break;
- }
- }
- if (loadedModelName.empty()) {
- sendErrorResponse(res, "No checkpoint model loaded. Please load a checkpoint model first using POST /api/models/{hash}/load", 400, "NO_CHECKPOINT_LOADED", requestId);
- return;
- }
- // Create generation request specifically for text2img
- GenerationRequest genRequest;
- genRequest.id = requestId;
- genRequest.modelName = loadedModelName; // Use the currently loaded model
- genRequest.prompt = requestJson["prompt"];
- genRequest.negativePrompt = requestJson.value("negative_prompt", "");
- genRequest.width = requestJson.value("width", 512);
- genRequest.height = requestJson.value("height", 512);
- genRequest.batchCount = requestJson.value("batch_count", 1);
- genRequest.steps = requestJson.value("steps", 20);
- genRequest.cfgScale = requestJson.value("cfg_scale", 7.5f);
- genRequest.seed = requestJson.value("seed", "random");
- // Parse optional parameters
- if (requestJson.contains("sampling_method")) {
- genRequest.samplingMethod = parseSamplingMethod(requestJson["sampling_method"]);
- }
- if (requestJson.contains("scheduler")) {
- genRequest.scheduler = parseScheduler(requestJson["scheduler"]);
- }
- // Set text2img specific defaults
- genRequest.strength = 1.0f; // Full strength for text2img
- // Optional VAE model
- if (requestJson.contains("vae_model") && requestJson["vae_model"].is_string()) {
- std::string vaeModelId = requestJson["vae_model"];
- if (!vaeModelId.empty()) {
- auto vaeInfo = m_modelManager->getModelInfo(vaeModelId);
- if (!vaeInfo.name.empty() && vaeInfo.type == ModelType::VAE) {
- genRequest.vaePath = vaeInfo.path;
- } else {
- sendErrorResponse(res, "VAE model not found or invalid: " + vaeModelId, 400, "INVALID_VAE_MODEL", requestId);
- return;
- }
- }
- }
- // Optional TAESD model
- if (requestJson.contains("taesd_model") && requestJson["taesd_model"].is_string()) {
- std::string taesdModelId = requestJson["taesd_model"];
- if (!taesdModelId.empty()) {
- auto taesdInfo = m_modelManager->getModelInfo(taesdModelId);
- if (!taesdInfo.name.empty() && taesdInfo.type == ModelType::TAESD) {
- genRequest.taesdPath = taesdInfo.path;
- } else {
- sendErrorResponse(res, "TAESD model not found or invalid: " + taesdModelId, 400, "INVALID_TAESD_MODEL", requestId);
- return;
- }
- }
- }
- // Enqueue request
- auto future = m_generationQueue->enqueueRequest(genRequest);
- nlohmann::json params = {
- {"prompt", genRequest.prompt},
- {"negative_prompt", genRequest.negativePrompt},
- {"model", genRequest.modelName},
- {"width", genRequest.width},
- {"height", genRequest.height},
- {"batch_count", genRequest.batchCount},
- {"steps", genRequest.steps},
- {"cfg_scale", genRequest.cfgScale},
- {"seed", genRequest.seed},
- {"sampling_method", samplingMethodToString(genRequest.samplingMethod)},
- {"scheduler", schedulerToString(genRequest.scheduler)}
- };
- // Add VAE/TAESD if specified
- if (!genRequest.vaePath.empty()) {
- params["vae_model"] = requestJson.value("vae_model", "");
- }
- if (!genRequest.taesdPath.empty()) {
- params["taesd_model"] = requestJson.value("taesd_model", "");
- }
- nlohmann::json response = {
- {"request_id", requestId},
- {"status", "queued"},
- {"message", "Text-to-image generation request queued successfully"},
- {"queue_position", m_generationQueue->getQueueSize()},
- {"estimated_time_seconds", estimateGenerationTime(genRequest) / 1000},
- {"estimated_memory_mb", estimateMemoryUsage(genRequest) / (1024 * 1024)},
- {"type", "text2img"},
- {"parameters", params}
- };
- sendJsonResponse(res, response, 202);
- } catch (const nlohmann::json::parse_error& e) {
- sendErrorResponse(res, std::string("Invalid JSON: ") + e.what(), 400, "JSON_PARSE_ERROR", requestId);
- } catch (const std::exception& e) {
- sendErrorResponse(res, std::string("Text-to-image request failed: ") + e.what(), 500, "INTERNAL_ERROR", requestId);
- }
- }
- void Server::handleImg2Img(const httplib::Request& req, httplib::Response& res) {
- std::string requestId = generateRequestId();
- try {
- if (!m_generationQueue) {
- sendErrorResponse(res, "Generation queue not available", 500, "QUEUE_UNAVAILABLE", requestId);
- return;
- }
- nlohmann::json requestJson = nlohmann::json::parse(req.body);
- // Validate required fields for img2img
- if (!requestJson.contains("prompt") || !requestJson["prompt"].is_string()) {
- sendErrorResponse(res, "Missing or invalid 'prompt' field", 400, "INVALID_PARAMETERS", requestId);
- return;
- }
- if (!requestJson.contains("init_image") || !requestJson["init_image"].is_string()) {
- sendErrorResponse(res, "Missing or invalid 'init_image' field", 400, "INVALID_PARAMETERS", requestId);
- return;
- }
- // Validate all parameters
- auto [isValid, errorMessage] = validateGenerationParameters(requestJson);
- if (!isValid) {
- sendErrorResponse(res, errorMessage, 400, "INVALID_PARAMETERS", requestId);
- return;
- }
- // Check if any model is loaded
- if (!m_modelManager) {
- sendErrorResponse(res, "Model manager not available", 500, "MODEL_MANAGER_UNAVAILABLE", requestId);
- return;
- }
- // Get currently loaded checkpoint model
- auto allModels = m_modelManager->getAllModels();
- std::string loadedModelName;
- for (const auto& [modelName, modelInfo] : allModels) {
- if (modelInfo.type == ModelType::CHECKPOINT && modelInfo.isLoaded) {
- loadedModelName = modelName;
- break;
- }
- }
- if (loadedModelName.empty()) {
- sendErrorResponse(res, "No checkpoint model loaded. Please load a checkpoint model first using POST /api/models/{hash}/load", 400, "NO_CHECKPOINT_LOADED", requestId);
- return;
- }
- // Load the init image
- std::string initImageInput = requestJson["init_image"];
- auto [imageData, imgWidth, imgHeight, imgChannels, success, loadError] = loadImageFromInput(initImageInput);
- if (!success) {
- sendErrorResponse(res, "Failed to load init image: " + loadError, 400, "IMAGE_LOAD_ERROR", requestId);
- return;
- }
- // Create generation request specifically for img2img
- GenerationRequest genRequest;
- genRequest.id = requestId;
- genRequest.requestType = GenerationRequest::RequestType::IMG2IMG;
- genRequest.modelName = loadedModelName; // Use the currently loaded model
- genRequest.prompt = requestJson["prompt"];
- genRequest.negativePrompt = requestJson.value("negative_prompt", "");
- genRequest.width = requestJson.value("width", imgWidth); // Default to input image dimensions
- genRequest.height = requestJson.value("height", imgHeight);
- genRequest.batchCount = requestJson.value("batch_count", 1);
- genRequest.steps = requestJson.value("steps", 20);
- genRequest.cfgScale = requestJson.value("cfg_scale", 7.5f);
- genRequest.seed = requestJson.value("seed", "random");
- genRequest.strength = requestJson.value("strength", 0.75f);
- // Set init image data
- genRequest.initImageData = imageData;
- genRequest.initImageWidth = imgWidth;
- genRequest.initImageHeight = imgHeight;
- genRequest.initImageChannels = imgChannels;
- // Parse optional parameters
- if (requestJson.contains("sampling_method")) {
- genRequest.samplingMethod = parseSamplingMethod(requestJson["sampling_method"]);
- }
- if (requestJson.contains("scheduler")) {
- genRequest.scheduler = parseScheduler(requestJson["scheduler"]);
- }
- // Optional VAE model
- if (requestJson.contains("vae_model") && requestJson["vae_model"].is_string()) {
- std::string vaeModelId = requestJson["vae_model"];
- if (!vaeModelId.empty()) {
- auto vaeInfo = m_modelManager->getModelInfo(vaeModelId);
- if (!vaeInfo.name.empty() && vaeInfo.type == ModelType::VAE) {
- genRequest.vaePath = vaeInfo.path;
- } else {
- sendErrorResponse(res, "VAE model not found or invalid: " + vaeModelId, 400, "INVALID_VAE_MODEL", requestId);
- return;
- }
- }
- }
- // Optional TAESD model
- if (requestJson.contains("taesd_model") && requestJson["taesd_model"].is_string()) {
- std::string taesdModelId = requestJson["taesd_model"];
- if (!taesdModelId.empty()) {
- auto taesdInfo = m_modelManager->getModelInfo(taesdModelId);
- if (!taesdInfo.name.empty() && taesdInfo.type == ModelType::TAESD) {
- genRequest.taesdPath = taesdInfo.path;
- } else {
- sendErrorResponse(res, "TAESD model not found or invalid: " + taesdModelId, 400, "INVALID_TAESD_MODEL", requestId);
- return;
- }
- }
- }
- // Enqueue request
- auto future = m_generationQueue->enqueueRequest(genRequest);
- nlohmann::json params = {
- {"prompt", genRequest.prompt},
- {"negative_prompt", genRequest.negativePrompt},
- {"init_image", requestJson["init_image"]},
- {"model", genRequest.modelName},
- {"width", genRequest.width},
- {"height", genRequest.height},
- {"batch_count", genRequest.batchCount},
- {"steps", genRequest.steps},
- {"cfg_scale", genRequest.cfgScale},
- {"seed", genRequest.seed},
- {"strength", genRequest.strength},
- {"sampling_method", samplingMethodToString(genRequest.samplingMethod)},
- {"scheduler", schedulerToString(genRequest.scheduler)}
- };
- // Add VAE/TAESD if specified
- if (!genRequest.vaePath.empty()) {
- params["vae_model"] = requestJson.value("vae_model", "");
- }
- if (!genRequest.taesdPath.empty()) {
- params["taesd_model"] = requestJson.value("taesd_model", "");
- }
- nlohmann::json response = {
- {"request_id", requestId},
- {"status", "queued"},
- {"message", "Image-to-image generation request queued successfully"},
- {"queue_position", m_generationQueue->getQueueSize()},
- {"estimated_time_seconds", estimateGenerationTime(genRequest) / 1000},
- {"estimated_memory_mb", estimateMemoryUsage(genRequest) / (1024 * 1024)},
- {"type", "img2img"},
- {"parameters", params}
- };
- sendJsonResponse(res, response, 202);
- } catch (const nlohmann::json::parse_error& e) {
- sendErrorResponse(res, std::string("Invalid JSON: ") + e.what(), 400, "JSON_PARSE_ERROR", requestId);
- } catch (const std::exception& e) {
- sendErrorResponse(res, std::string("Image-to-image request failed: ") + e.what(), 500, "INTERNAL_ERROR", requestId);
- }
- }
- void Server::handleControlNet(const httplib::Request& req, httplib::Response& res) {
- std::string requestId = generateRequestId();
- try {
- if (!m_generationQueue) {
- sendErrorResponse(res, "Generation queue not available", 500, "QUEUE_UNAVAILABLE", requestId);
- return;
- }
- nlohmann::json requestJson = nlohmann::json::parse(req.body);
- // Validate required fields for ControlNet
- if (!requestJson.contains("prompt") || !requestJson["prompt"].is_string()) {
- sendErrorResponse(res, "Missing or invalid 'prompt' field", 400, "INVALID_PARAMETERS", requestId);
- return;
- }
- if (!requestJson.contains("control_image") || !requestJson["control_image"].is_string()) {
- sendErrorResponse(res, "Missing or invalid 'control_image' field", 400, "INVALID_PARAMETERS", requestId);
- return;
- }
- // Validate all parameters
- auto [isValid, errorMessage] = validateGenerationParameters(requestJson);
- if (!isValid) {
- sendErrorResponse(res, errorMessage, 400, "INVALID_PARAMETERS", requestId);
- return;
- }
- // Check if any model is loaded
- if (!m_modelManager) {
- sendErrorResponse(res, "Model manager not available", 500, "MODEL_MANAGER_UNAVAILABLE", requestId);
- return;
- }
- // Get currently loaded checkpoint model
- auto allModels = m_modelManager->getAllModels();
- std::string loadedModelName;
- for (const auto& [modelName, modelInfo] : allModels) {
- if (modelInfo.type == ModelType::CHECKPOINT && modelInfo.isLoaded) {
- loadedModelName = modelName;
- break;
- }
- }
- if (loadedModelName.empty()) {
- sendErrorResponse(res, "No checkpoint model loaded. Please load a checkpoint model first using POST /api/models/{hash}/load", 400, "NO_CHECKPOINT_LOADED", requestId);
- return;
- }
- // Create generation request specifically for ControlNet
- GenerationRequest genRequest;
- genRequest.id = requestId;
- genRequest.modelName = loadedModelName; // Use the currently loaded model
- genRequest.prompt = requestJson["prompt"];
- genRequest.negativePrompt = requestJson.value("negative_prompt", "");
- genRequest.width = requestJson.value("width", 512);
- genRequest.height = requestJson.value("height", 512);
- genRequest.batchCount = requestJson.value("batch_count", 1);
- genRequest.steps = requestJson.value("steps", 20);
- genRequest.cfgScale = requestJson.value("cfg_scale", 7.5f);
- genRequest.seed = requestJson.value("seed", "random");
- genRequest.controlStrength = requestJson.value("control_strength", 0.9f);
- genRequest.controlNetPath = requestJson.value("control_net_model", "");
- // Parse optional parameters
- if (requestJson.contains("sampling_method")) {
- genRequest.samplingMethod = parseSamplingMethod(requestJson["sampling_method"]);
- }
- if (requestJson.contains("scheduler")) {
- genRequest.scheduler = parseScheduler(requestJson["scheduler"]);
- }
- // Optional VAE model
- if (requestJson.contains("vae_model") && requestJson["vae_model"].is_string()) {
- std::string vaeModelId = requestJson["vae_model"];
- if (!vaeModelId.empty()) {
- auto vaeInfo = m_modelManager->getModelInfo(vaeModelId);
- if (!vaeInfo.name.empty() && vaeInfo.type == ModelType::VAE) {
- genRequest.vaePath = vaeInfo.path;
- } else {
- sendErrorResponse(res, "VAE model not found or invalid: " + vaeModelId, 400, "INVALID_VAE_MODEL", requestId);
- return;
- }
- }
- }
- // Optional TAESD model
- if (requestJson.contains("taesd_model") && requestJson["taesd_model"].is_string()) {
- std::string taesdModelId = requestJson["taesd_model"];
- if (!taesdModelId.empty()) {
- auto taesdInfo = m_modelManager->getModelInfo(taesdModelId);
- if (!taesdInfo.name.empty() && taesdInfo.type == ModelType::TAESD) {
- genRequest.taesdPath = taesdInfo.path;
- } else {
- sendErrorResponse(res, "TAESD model not found or invalid: " + taesdModelId, 400, "INVALID_TAESD_MODEL", requestId);
- return;
- }
- }
- }
- // Store control image path (would be handled in actual implementation)
- genRequest.outputPath = requestJson.value("control_image", "");
- // Enqueue request
- auto future = m_generationQueue->enqueueRequest(genRequest);
- nlohmann::json params = {
- {"prompt", genRequest.prompt},
- {"negative_prompt", genRequest.negativePrompt},
- {"control_image", requestJson["control_image"]},
- {"control_net_model", genRequest.controlNetPath},
- {"model", genRequest.modelName},
- {"width", genRequest.width},
- {"height", genRequest.height},
- {"batch_count", genRequest.batchCount},
- {"steps", genRequest.steps},
- {"cfg_scale", genRequest.cfgScale},
- {"seed", genRequest.seed},
- {"control_strength", genRequest.controlStrength},
- {"sampling_method", samplingMethodToString(genRequest.samplingMethod)},
- {"scheduler", schedulerToString(genRequest.scheduler)}
- };
- // Add VAE/TAESD if specified
- if (!genRequest.vaePath.empty()) {
- params["vae_model"] = requestJson.value("vae_model", "");
- }
- if (!genRequest.taesdPath.empty()) {
- params["taesd_model"] = requestJson.value("taesd_model", "");
- }
- nlohmann::json response = {
- {"request_id", requestId},
- {"status", "queued"},
- {"message", "ControlNet generation request queued successfully"},
- {"queue_position", m_generationQueue->getQueueSize()},
- {"estimated_time_seconds", estimateGenerationTime(genRequest) / 1000},
- {"estimated_memory_mb", estimateMemoryUsage(genRequest) / (1024 * 1024)},
- {"type", "controlnet"},
- {"parameters", params}
- };
- sendJsonResponse(res, response, 202);
- } catch (const nlohmann::json::parse_error& e) {
- sendErrorResponse(res, std::string("Invalid JSON: ") + e.what(), 400, "JSON_PARSE_ERROR", requestId);
- } catch (const std::exception& e) {
- sendErrorResponse(res, std::string("ControlNet request failed: ") + e.what(), 500, "INTERNAL_ERROR", requestId);
- }
- }
- void Server::handleUpscale(const httplib::Request& req, httplib::Response& res) {
- std::string requestId = generateRequestId();
- try {
- if (!m_generationQueue) {
- sendErrorResponse(res, "Generation queue not available", 500, "QUEUE_UNAVAILABLE", requestId);
- return;
- }
- nlohmann::json requestJson = nlohmann::json::parse(req.body);
- // Validate required fields for upscaler
- if (!requestJson.contains("image") || !requestJson["image"].is_string()) {
- sendErrorResponse(res, "Missing or invalid 'image' field", 400, "INVALID_PARAMETERS", requestId);
- return;
- }
- if (!requestJson.contains("esrgan_model") || !requestJson["esrgan_model"].is_string()) {
- sendErrorResponse(res, "Missing or invalid 'esrgan_model' field (model hash or name)", 400, "INVALID_PARAMETERS", requestId);
- return;
- }
- // Check if model manager is available
- if (!m_modelManager) {
- sendErrorResponse(res, "Model manager not available", 500, "MODEL_MANAGER_UNAVAILABLE", requestId);
- return;
- }
- // Get the ESRGAN/upscaler model
- std::string esrganModelId = requestJson["esrgan_model"];
- auto modelInfo = m_modelManager->getModelInfo(esrganModelId);
- if (modelInfo.name.empty()) {
- sendErrorResponse(res, "ESRGAN model not found: " + esrganModelId, 404, "MODEL_NOT_FOUND", requestId);
- return;
- }
- if (modelInfo.type != ModelType::ESRGAN && modelInfo.type != ModelType::UPSCALER) {
- sendErrorResponse(res, "Model is not an ESRGAN/upscaler model", 400, "INVALID_MODEL_TYPE", requestId);
- return;
- }
- // Load the input image
- std::string imageInput = requestJson["image"];
- auto [imageData, imgWidth, imgHeight, imgChannels, success, loadError] = loadImageFromInput(imageInput);
- if (!success) {
- sendErrorResponse(res, "Failed to load image: " + loadError, 400, "IMAGE_LOAD_ERROR", requestId);
- return;
- }
- // Create upscaler request
- GenerationRequest genRequest;
- genRequest.id = requestId;
- genRequest.requestType = GenerationRequest::RequestType::UPSCALER;
- genRequest.esrganPath = modelInfo.path;
- genRequest.upscaleFactor = requestJson.value("upscale_factor", 4);
- genRequest.nThreads = requestJson.value("threads", -1);
- genRequest.offloadParamsToCpu = requestJson.value("offload_to_cpu", false);
- genRequest.diffusionConvDirect = requestJson.value("direct", false);
- // Set input image data
- genRequest.initImageData = imageData;
- genRequest.initImageWidth = imgWidth;
- genRequest.initImageHeight = imgHeight;
- genRequest.initImageChannels = imgChannels;
- // Enqueue request
- auto future = m_generationQueue->enqueueRequest(genRequest);
- nlohmann::json response = {
- {"request_id", requestId},
- {"status", "queued"},
- {"message", "Upscale request queued successfully"},
- {"queue_position", m_generationQueue->getQueueSize()},
- {"type", "upscale"},
- {"parameters", {
- {"esrgan_model", esrganModelId},
- {"upscale_factor", genRequest.upscaleFactor},
- {"input_width", imgWidth},
- {"input_height", imgHeight},
- {"output_width", imgWidth * genRequest.upscaleFactor},
- {"output_height", imgHeight * genRequest.upscaleFactor}
- }}
- };
- sendJsonResponse(res, response, 202);
- } catch (const nlohmann::json::parse_error& e) {
- sendErrorResponse(res, std::string("Invalid JSON: ") + e.what(), 400, "JSON_PARSE_ERROR", requestId);
- } catch (const std::exception& e) {
- sendErrorResponse(res, std::string("Upscale request failed: ") + e.what(), 500, "INTERNAL_ERROR", requestId);
- }
- }
- void Server::handleInpainting(const httplib::Request& req, httplib::Response& res) {
- std::string requestId = generateRequestId();
- try {
- if (!m_generationQueue) {
- sendErrorResponse(res, "Generation queue not available", 500, "QUEUE_UNAVAILABLE", requestId);
- return;
- }
- nlohmann::json requestJson = nlohmann::json::parse(req.body);
- // Validate required fields for inpainting
- if (!requestJson.contains("prompt") || !requestJson["prompt"].is_string()) {
- sendErrorResponse(res, "Missing or invalid 'prompt' field", 400, "INVALID_PARAMETERS", requestId);
- return;
- }
- if (!requestJson.contains("source_image") || !requestJson["source_image"].is_string()) {
- sendErrorResponse(res, "Missing or invalid 'source_image' field", 400, "INVALID_PARAMETERS", requestId);
- return;
- }
- if (!requestJson.contains("mask_image") || !requestJson["mask_image"].is_string()) {
- sendErrorResponse(res, "Missing or invalid 'mask_image' field", 400, "INVALID_PARAMETERS", requestId);
- return;
- }
- // Validate all parameters
- auto [isValid, errorMessage] = validateGenerationParameters(requestJson);
- if (!isValid) {
- sendErrorResponse(res, errorMessage, 400, "INVALID_PARAMETERS", requestId);
- return;
- }
- // Check if any model is loaded
- if (!m_modelManager) {
- sendErrorResponse(res, "Model manager not available", 500, "MODEL_MANAGER_UNAVAILABLE", requestId);
- return;
- }
- // Get currently loaded checkpoint model
- auto allModels = m_modelManager->getAllModels();
- std::string loadedModelName;
- for (const auto& [modelName, modelInfo] : allModels) {
- if (modelInfo.type == ModelType::CHECKPOINT && modelInfo.isLoaded) {
- loadedModelName = modelName;
- break;
- }
- }
- if (loadedModelName.empty()) {
- sendErrorResponse(res, "No checkpoint model loaded. Please load a checkpoint model first using POST /api/models/{hash}/load", 400, "NO_CHECKPOINT_LOADED", requestId);
- return;
- }
- // Load the source image
- std::string sourceImageInput = requestJson["source_image"];
- auto [sourceImageData, sourceImgWidth, sourceImgHeight, sourceImgChannels, sourceSuccess, sourceLoadError] = loadImageFromInput(sourceImageInput);
- if (!sourceSuccess) {
- sendErrorResponse(res, "Failed to load source image: " + sourceLoadError, 400, "IMAGE_LOAD_ERROR", requestId);
- return;
- }
- // Load the mask image
- std::string maskImageInput = requestJson["mask_image"];
- auto [maskImageData, maskImgWidth, maskImgHeight, maskImgChannels, maskSuccess, maskLoadError] = loadImageFromInput(maskImageInput);
- if (!maskSuccess) {
- sendErrorResponse(res, "Failed to load mask image: " + maskLoadError, 400, "MASK_LOAD_ERROR", requestId);
- return;
- }
- // Validate that source and mask images have compatible dimensions
- if (sourceImgWidth != maskImgWidth || sourceImgHeight != maskImgHeight) {
- sendErrorResponse(res, "Source and mask images must have the same dimensions", 400, "DIMENSION_MISMATCH", requestId);
- return;
- }
- // Create generation request specifically for inpainting
- GenerationRequest genRequest;
- genRequest.id = requestId;
- genRequest.requestType = GenerationRequest::RequestType::INPAINTING;
- genRequest.modelName = loadedModelName; // Use the currently loaded model
- genRequest.prompt = requestJson["prompt"];
- genRequest.negativePrompt = requestJson.value("negative_prompt", "");
- genRequest.width = requestJson.value("width", sourceImgWidth); // Default to input image dimensions
- genRequest.height = requestJson.value("height", sourceImgHeight);
- genRequest.batchCount = requestJson.value("batch_count", 1);
- genRequest.steps = requestJson.value("steps", 20);
- genRequest.cfgScale = requestJson.value("cfg_scale", 7.5f);
- genRequest.seed = requestJson.value("seed", "random");
- genRequest.strength = requestJson.value("strength", 0.75f);
- // Set source image data
- genRequest.initImageData = sourceImageData;
- genRequest.initImageWidth = sourceImgWidth;
- genRequest.initImageHeight = sourceImgHeight;
- genRequest.initImageChannels = sourceImgChannels;
- // Set mask image data
- genRequest.maskImageData = maskImageData;
- genRequest.maskImageWidth = maskImgWidth;
- genRequest.maskImageHeight = maskImgHeight;
- genRequest.maskImageChannels = maskImgChannels;
- // Parse optional parameters
- if (requestJson.contains("sampling_method")) {
- genRequest.samplingMethod = parseSamplingMethod(requestJson["sampling_method"]);
- }
- if (requestJson.contains("scheduler")) {
- genRequest.scheduler = parseScheduler(requestJson["scheduler"]);
- }
- // Optional VAE model
- if (requestJson.contains("vae_model") && requestJson["vae_model"].is_string()) {
- std::string vaeModelId = requestJson["vae_model"];
- if (!vaeModelId.empty()) {
- auto vaeInfo = m_modelManager->getModelInfo(vaeModelId);
- if (!vaeInfo.name.empty() && vaeInfo.type == ModelType::VAE) {
- genRequest.vaePath = vaeInfo.path;
- } else {
- sendErrorResponse(res, "VAE model not found or invalid: " + vaeModelId, 400, "INVALID_VAE_MODEL", requestId);
- return;
- }
- }
- }
- // Optional TAESD model
- if (requestJson.contains("taesd_model") && requestJson["taesd_model"].is_string()) {
- std::string taesdModelId = requestJson["taesd_model"];
- if (!taesdModelId.empty()) {
- auto taesdInfo = m_modelManager->getModelInfo(taesdModelId);
- if (!taesdInfo.name.empty() && taesdInfo.type == ModelType::TAESD) {
- genRequest.taesdPath = taesdInfo.path;
- } else {
- sendErrorResponse(res, "TAESD model not found or invalid: " + taesdModelId, 400, "INVALID_TAESD_MODEL", requestId);
- return;
- }
- }
- }
- // Enqueue request
- auto future = m_generationQueue->enqueueRequest(genRequest);
- nlohmann::json params = {
- {"prompt", genRequest.prompt},
- {"negative_prompt", genRequest.negativePrompt},
- {"source_image", requestJson["source_image"]},
- {"mask_image", requestJson["mask_image"]},
- {"model", genRequest.modelName},
- {"width", genRequest.width},
- {"height", genRequest.height},
- {"batch_count", genRequest.batchCount},
- {"steps", genRequest.steps},
- {"cfg_scale", genRequest.cfgScale},
- {"seed", genRequest.seed},
- {"strength", genRequest.strength},
- {"sampling_method", samplingMethodToString(genRequest.samplingMethod)},
- {"scheduler", schedulerToString(genRequest.scheduler)}
- };
- // Add VAE/TAESD if specified
- if (!genRequest.vaePath.empty()) {
- params["vae_model"] = requestJson.value("vae_model", "");
- }
- if (!genRequest.taesdPath.empty()) {
- params["taesd_model"] = requestJson.value("taesd_model", "");
- }
- nlohmann::json response = {
- {"request_id", requestId},
- {"status", "queued"},
- {"message", "Inpainting generation request queued successfully"},
- {"queue_position", m_generationQueue->getQueueSize()},
- {"estimated_time_seconds", estimateGenerationTime(genRequest) / 1000},
- {"estimated_memory_mb", estimateMemoryUsage(genRequest) / (1024 * 1024)},
- {"type", "inpainting"},
- {"parameters", params}
- };
- sendJsonResponse(res, response, 202);
- } catch (const nlohmann::json::parse_error& e) {
- sendErrorResponse(res, std::string("Invalid JSON: ") + e.what(), 400, "JSON_PARSE_ERROR", requestId);
- } catch (const std::exception& e) {
- sendErrorResponse(res, std::string("Inpainting request failed: ") + e.what(), 500, "INTERNAL_ERROR", requestId);
- }
- }
- // Utility endpoints
- void Server::handleSamplers(const httplib::Request& /*req*/, httplib::Response& res) {
- try {
- nlohmann::json samplers = {
- {"samplers", {
- {
- {"name", "euler"},
- {"description", "Euler sampler - fast and simple"},
- {"recommended_steps", 20}
- },
- {
- {"name", "euler_a"},
- {"description", "Euler ancestral sampler - adds randomness"},
- {"recommended_steps", 20}
- },
- {
- {"name", "heun"},
- {"description", "Heun sampler - more accurate but slower"},
- {"recommended_steps", 20}
- },
- {
- {"name", "dpm2"},
- {"description", "DPM2 sampler - second-order DPM"},
- {"recommended_steps", 20}
- },
- {
- {"name", "dpm++2s_a"},
- {"description", "DPM++ 2s ancestral sampler"},
- {"recommended_steps", 20}
- },
- {
- {"name", "dpm++2m"},
- {"description", "DPM++ 2m sampler - multistep"},
- {"recommended_steps", 20}
- },
- {
- {"name", "dpm++2mv2"},
- {"description", "DPM++ 2m v2 sampler - improved multistep"},
- {"recommended_steps", 20}
- },
- {
- {"name", "ipndm"},
- {"description", "IPNDM sampler - improved noise prediction"},
- {"recommended_steps", 20}
- },
- {
- {"name", "ipndm_v"},
- {"description", "IPNDM v sampler - variant of IPNDM"},
- {"recommended_steps", 20}
- },
- {
- {"name", "lcm"},
- {"description", "LCM sampler - Latent Consistency Model, very fast"},
- {"recommended_steps", 4}
- },
- {
- {"name", "ddim_trailing"},
- {"description", "DDIM trailing sampler - deterministic"},
- {"recommended_steps", 20}
- },
- {
- {"name", "tcd"},
- {"description", "TCD sampler - Trajectory Consistency Distillation"},
- {"recommended_steps", 8}
- },
- {
- {"name", "default"},
- {"description", "Use model's default sampler"},
- {"recommended_steps", 20}
- }
- }}
- };
- sendJsonResponse(res, samplers);
- } catch (const std::exception& e) {
- sendErrorResponse(res, std::string("Failed to get samplers: ") + e.what(), 500);
- }
- }
- void Server::handleSchedulers(const httplib::Request& /*req*/, httplib::Response& res) {
- try {
- nlohmann::json schedulers = {
- {"schedulers", {
- {
- {"name", "discrete"},
- {"description", "Discrete scheduler - standard noise schedule"}
- },
- {
- {"name", "karras"},
- {"description", "Karras scheduler - improved noise schedule"}
- },
- {
- {"name", "exponential"},
- {"description", "Exponential scheduler - exponential noise decay"}
- },
- {
- {"name", "ays"},
- {"description", "AYS scheduler - Adaptive Your Scheduler"}
- },
- {
- {"name", "gits"},
- {"description", "GITS scheduler - Generalized Iterative Time Steps"}
- },
- {
- {"name", "smoothstep"},
- {"description", "Smoothstep scheduler - smooth transition function"}
- },
- {
- {"name", "sgm_uniform"},
- {"description", "SGM uniform scheduler - uniform noise schedule"}
- },
- {
- {"name", "simple"},
- {"description", "Simple scheduler - basic linear schedule"}
- },
- {
- {"name", "default"},
- {"description", "Use model's default scheduler"}
- }
- }}
- };
- sendJsonResponse(res, schedulers);
- } catch (const std::exception& e) {
- sendErrorResponse(res, std::string("Failed to get schedulers: ") + e.what(), 500);
- }
- }
- void Server::handleParameters(const httplib::Request& /*req*/, httplib::Response& res) {
- try {
- nlohmann::json parameters = {
- {"parameters", {
- {
- {"name", "prompt"},
- {"type", "string"},
- {"required", true},
- {"description", "Text prompt for image generation"},
- {"min_length", 1},
- {"max_length", 10000},
- {"example", "a beautiful landscape with mountains"}
- },
- {
- {"name", "negative_prompt"},
- {"type", "string"},
- {"required", false},
- {"description", "Negative prompt to guide generation away from"},
- {"min_length", 0},
- {"max_length", 10000},
- {"example", "blurry, low quality, distorted"}
- },
- {
- {"name", "width"},
- {"type", "integer"},
- {"required", false},
- {"description", "Image width in pixels"},
- {"min", 64},
- {"max", 2048},
- {"multiple_of", 64},
- {"default", 512}
- },
- {
- {"name", "height"},
- {"type", "integer"},
- {"required", false},
- {"description", "Image height in pixels"},
- {"min", 64},
- {"max", 2048},
- {"multiple_of", 64},
- {"default", 512}
- },
- {
- {"name", "steps"},
- {"type", "integer"},
- {"required", false},
- {"description", "Number of diffusion steps"},
- {"min", 1},
- {"max", 150},
- {"default", 20}
- },
- {
- {"name", "cfg_scale"},
- {"type", "number"},
- {"required", false},
- {"description", "Classifier-Free Guidance scale"},
- {"min", 1.0},
- {"max", 30.0},
- {"default", 7.5}
- },
- {
- {"name", "seed"},
- {"type", "string|integer"},
- {"required", false},
- {"description", "Seed for generation (use 'random' for random seed)"},
- {"example", "42"}
- },
- {
- {"name", "sampling_method"},
- {"type", "string"},
- {"required", false},
- {"description", "Sampling method to use"},
- {"enum", {"euler", "euler_a", "heun", "dpm2", "dpm++2s_a", "dpm++2m", "dpm++2mv2", "ipndm", "ipndm_v", "lcm", "ddim_trailing", "tcd", "default"}},
- {"default", "default"}
- },
- {
- {"name", "scheduler"},
- {"type", "string"},
- {"required", false},
- {"description", "Scheduler to use"},
- {"enum", {"discrete", "karras", "exponential", "ays", "gits", "smoothstep", "sgm_uniform", "simple", "default"}},
- {"default", "default"}
- },
- {
- {"name", "batch_count"},
- {"type", "integer"},
- {"required", false},
- {"description", "Number of images to generate"},
- {"min", 1},
- {"max", 100},
- {"default", 1}
- },
- {
- {"name", "strength"},
- {"type", "number"},
- {"required", false},
- {"description", "Strength for img2img (0.0-1.0)"},
- {"min", 0.0},
- {"max", 1.0},
- {"default", 0.75}
- },
- {
- {"name", "control_strength"},
- {"type", "number"},
- {"required", false},
- {"description", "ControlNet strength (0.0-1.0)"},
- {"min", 0.0},
- {"max", 1.0},
- {"default", 0.9}
- }
- }},
- {"openapi", {
- {"version", "3.0.0"},
- {"info", {
- {"title", "Stable Diffusion REST API"},
- {"version", "1.0.0"},
- {"description", "Comprehensive REST API for stable-diffusion.cpp functionality"}
- }},
- {"components", {
- {"schemas", {
- {"GenerationRequest", {
- {"type", "object"},
- {"required", {"prompt"}},
- {"properties", {
- {"prompt", {{"type", "string"}, {"description", "Text prompt for generation"}}},
- {"negative_prompt", {{"type", "string"}, {"description", "Negative prompt"}}},
- {"width", {{"type", "integer"}, {"minimum", 64}, {"maximum", 2048}, {"default", 512}}},
- {"height", {{"type", "integer"}, {"minimum", 64}, {"maximum", 2048}, {"default", 512}}},
- {"steps", {{"type", "integer"}, {"minimum", 1}, {"maximum", 150}, {"default", 20}}},
- {"cfg_scale", {{"type", "number"}, {"minimum", 1.0}, {"maximum", 30.0}, {"default", 7.5}}}
- }}
- }}
- }}
- }}
- }}
- };
- sendJsonResponse(res, parameters);
- } catch (const std::exception& e) {
- sendErrorResponse(res, std::string("Failed to get parameters: ") + e.what(), 500);
- }
- }
- void Server::handleValidate(const httplib::Request& req, httplib::Response& res) {
- std::string requestId = generateRequestId();
- try {
- nlohmann::json requestJson = nlohmann::json::parse(req.body);
- // Validate parameters
- auto [isValid, errorMessage] = validateGenerationParameters(requestJson);
- nlohmann::json response = {
- {"request_id", requestId},
- {"valid", isValid},
- {"message", isValid ? "Parameters are valid" : errorMessage},
- {"errors", isValid ? nlohmann::json::array() : nlohmann::json::array({errorMessage})}
- };
- sendJsonResponse(res, response, isValid ? 200 : 400);
- } catch (const nlohmann::json::parse_error& e) {
- sendErrorResponse(res, std::string("Invalid JSON: ") + e.what(), 400, "JSON_PARSE_ERROR", requestId);
- } catch (const std::exception& e) {
- sendErrorResponse(res, std::string("Validation failed: ") + e.what(), 500, "INTERNAL_ERROR", requestId);
- }
- }
- void Server::handleEstimate(const httplib::Request& req, httplib::Response& res) {
- std::string requestId = generateRequestId();
- try {
- nlohmann::json requestJson = nlohmann::json::parse(req.body);
- // Validate parameters first
- auto [isValid, errorMessage] = validateGenerationParameters(requestJson);
- if (!isValid) {
- sendErrorResponse(res, errorMessage, 400, "INVALID_PARAMETERS", requestId);
- return;
- }
- // Create a temporary request to estimate
- GenerationRequest genRequest;
- genRequest.prompt = requestJson["prompt"];
- genRequest.width = requestJson.value("width", 512);
- genRequest.height = requestJson.value("height", 512);
- genRequest.batchCount = requestJson.value("batch_count", 1);
- genRequest.steps = requestJson.value("steps", 20);
- genRequest.diffusionFlashAttn = requestJson.value("diffusion_flash_attn", false);
- genRequest.controlNetPath = requestJson.value("control_net_path", "");
- if (requestJson.contains("sampling_method")) {
- genRequest.samplingMethod = parseSamplingMethod(requestJson["sampling_method"]);
- }
- // Calculate estimates
- uint64_t estimatedTime = estimateGenerationTime(genRequest);
- size_t estimatedMemory = estimateMemoryUsage(genRequest);
- nlohmann::json response = {
- {"request_id", requestId},
- {"estimated_time_seconds", estimatedTime / 1000},
- {"estimated_memory_mb", estimatedMemory / (1024 * 1024)},
- {"parameters", {
- {"resolution", std::to_string(genRequest.width) + "x" + std::to_string(genRequest.height)},
- {"steps", genRequest.steps},
- {"batch_count", genRequest.batchCount},
- {"sampling_method", samplingMethodToString(genRequest.samplingMethod)}
- }}
- };
- sendJsonResponse(res, response);
- } catch (const nlohmann::json::parse_error& e) {
- sendErrorResponse(res, std::string("Invalid JSON: ") + e.what(), 400, "JSON_PARSE_ERROR", requestId);
- } catch (const std::exception& e) {
- sendErrorResponse(res, std::string("Estimation failed: ") + e.what(), 500, "INTERNAL_ERROR", requestId);
- }
- }
- void Server::handleConfig(const httplib::Request& /*req*/, httplib::Response& res) {
- std::string requestId = generateRequestId();
- try {
- // Get current configuration
- nlohmann::json config = {
- {"request_id", requestId},
- {"config", {
- {"server", {
- {"host", m_host},
- {"port", m_port},
- {"max_concurrent_generations", 1}
- }},
- {"generation", {
- {"default_width", 512},
- {"default_height", 512},
- {"default_steps", 20},
- {"default_cfg_scale", 7.5},
- {"max_batch_count", 100},
- {"max_steps", 150},
- {"max_resolution", 2048}
- }},
- {"rate_limiting", {
- {"requests_per_minute", 60},
- {"enabled", true}
- }}
- }}
- };
- sendJsonResponse(res, config);
- } catch (const std::exception& e) {
- sendErrorResponse(res, std::string("Config operation failed: ") + e.what(), 500, "INTERNAL_ERROR", requestId);
- }
- }
- void Server::handleSystem(const httplib::Request& /*req*/, httplib::Response& res) {
- try {
- nlohmann::json system = {
- {"system", {
- {"version", "1.0.0"},
- {"build", "stable-diffusion.cpp-rest"},
- {"uptime", std::chrono::duration_cast<std::chrono::seconds>(
- std::chrono::steady_clock::now().time_since_epoch()).count()},
- {"capabilities", {
- {"text2img", true},
- {"img2img", true},
- {"controlnet", true},
- {"batch_generation", true},
- {"parameter_validation", true},
- {"estimation", true}
- }},
- {"supported_formats", {
- {"input", {"png", "jpg", "jpeg", "webp"}},
- {"output", {"png", "jpg", "jpeg", "webp"}}
- }},
- {"limits", {
- {"max_resolution", 2048},
- {"max_steps", 150},
- {"max_batch_count", 100},
- {"max_prompt_length", 10000}
- }}
- }},
- {"hardware", {
- {"cpu_threads", std::thread::hardware_concurrency()}
- }}
- };
- sendJsonResponse(res, system);
- } catch (const std::exception& e) {
- sendErrorResponse(res, std::string("System info failed: ") + e.what(), 500);
- }
- }
- void Server::handleSystemRestart(const httplib::Request& /*req*/, httplib::Response& res) {
- try {
- nlohmann::json response = {
- {"message", "Server restart initiated. The server will shut down gracefully and exit. Please use a process manager to automatically restart it."},
- {"status", "restarting"}
- };
- sendJsonResponse(res, response);
- // Schedule server stop after response is sent
- // Using a separate thread to allow the response to be sent first
- std::thread([this]() {
- std::this_thread::sleep_for(std::chrono::seconds(1));
- this->stop();
- // Exit with code 42 to signal restart intent to process manager
- std::exit(42);
- }).detach();
- } catch (const std::exception& e) {
- sendErrorResponse(res, std::string("Restart failed: ") + e.what(), 500);
- }
- }
- // Helper methods for model management
- nlohmann::json Server::getModelCapabilities(ModelType type) {
- nlohmann::json capabilities = nlohmann::json::object();
- switch (type) {
- case ModelType::CHECKPOINT:
- capabilities = {
- {"text2img", true},
- {"img2img", true},
- {"inpainting", true},
- {"outpainting", true},
- {"controlnet", true},
- {"lora", true},
- {"vae", true},
- {"sampling_methods", {"euler", "euler_a", "heun", "dpm2", "dpm++2s_a", "dpm++2m", "dpm++2mv2", "ipndm", "ipndm_v", "lcm", "ddim_trailing", "tcd"}},
- {"schedulers", {"discrete", "karras", "exponential", "ays", "gits", "smoothstep", "sgm_uniform", "simple"}},
- {"recommended_resolution", "512x512"},
- {"max_resolution", "2048x2048"},
- {"supports_batch", true}
- };
- break;
- case ModelType::LORA:
- capabilities = {
- {"text2img", true},
- {"img2img", true},
- {"inpainting", true},
- {"controlnet", false},
- {"lora", true},
- {"vae", false},
- {"requires_checkpoint", true},
- {"strength_range", {0.0, 2.0}},
- {"recommended_strength", 1.0}
- };
- break;
- case ModelType::CONTROLNET:
- capabilities = {
- {"text2img", false},
- {"img2img", true},
- {"inpainting", true},
- {"controlnet", true},
- {"requires_checkpoint", true},
- {"control_modes", {"canny", "depth", "pose", "scribble", "hed", "mlsd", "normal", "seg"}},
- {"strength_range", {0.0, 1.0}},
- {"recommended_strength", 0.9}
- };
- break;
- case ModelType::VAE:
- capabilities = {
- {"text2img", false},
- {"img2img", false},
- {"inpainting", false},
- {"vae", true},
- {"requires_checkpoint", true},
- {"encoding", true},
- {"decoding", true},
- {"precision", {"fp16", "fp32"}}
- };
- break;
- case ModelType::EMBEDDING:
- capabilities = {
- {"text2img", true},
- {"img2img", true},
- {"inpainting", true},
- {"embedding", true},
- {"requires_checkpoint", true},
- {"token_count", 1},
- {"compatible_with", {"checkpoint", "lora"}}
- };
- break;
- case ModelType::TAESD:
- capabilities = {
- {"text2img", false},
- {"img2img", false},
- {"inpainting", false},
- {"vae", true},
- {"requires_checkpoint", true},
- {"fast_decoding", true},
- {"real_time", true},
- {"precision", {"fp16", "fp32"}}
- };
- break;
- case ModelType::ESRGAN:
- capabilities = {
- {"text2img", false},
- {"img2img", false},
- {"inpainting", false},
- {"upscaling", true},
- {"scale_factors", {2, 4}},
- {"models", {"ESRGAN", "RealESRGAN", "SwinIR"}},
- {"supports_alpha", false}
- };
- break;
- default:
- capabilities = {
- {"text2img", false},
- {"img2img", false},
- {"inpainting", false},
- {"capabilities", {}}
- };
- break;
- }
- return capabilities;
- }
- nlohmann::json Server::getModelTypeStatistics() {
- if (!m_modelManager) return nlohmann::json::object();
- nlohmann::json stats = nlohmann::json::object();
- auto allModels = m_modelManager->getAllModels();
- // Initialize counters for each type
- std::map<ModelType, int> typeCounts;
- std::map<ModelType, int> loadedCounts;
- std::map<ModelType, size_t> sizeByType;
- for (const auto& pair : allModels) {
- ModelType type = pair.second.type;
- typeCounts[type]++;
- if (pair.second.isLoaded) {
- loadedCounts[type]++;
- }
- sizeByType[type] += pair.second.fileSize;
- }
- // Build statistics JSON
- for (const auto& count : typeCounts) {
- std::string typeName = ModelManager::modelTypeToString(count.first);
- stats[typeName] = {
- {"total_count", count.second},
- {"loaded_count", loadedCounts[count.first]},
- {"total_size_bytes", sizeByType[count.first]},
- {"total_size_mb", sizeByType[count.first] / (1024.0 * 1024.0)},
- {"average_size_mb", count.second > 0 ? (sizeByType[count.first] / (1024.0 * 1024.0)) / count.second : 0.0}
- };
- }
- return stats;
- }
- // Additional helper methods for model management
- nlohmann::json Server::getModelCompatibility(const ModelManager::ModelInfo& modelInfo) {
- nlohmann::json compatibility = {
- {"is_compatible", true},
- {"compatibility_score", 100},
- {"issues", nlohmann::json::array()},
- {"warnings", nlohmann::json::array()},
- {"requirements", {
- {"min_memory_mb", 1024},
- {"recommended_memory_mb", 2048},
- {"supported_formats", {"safetensors", "ckpt", "gguf"}},
- {"required_dependencies", {}}
- }}
- };
- // Check for specific compatibility issues based on model type
- if (modelInfo.type == ModelType::LORA) {
- compatibility["requirements"]["required_dependencies"] = {"checkpoint"};
- } else if (modelInfo.type == ModelType::CONTROLNET) {
- compatibility["requirements"]["required_dependencies"] = {"checkpoint"};
- } else if (modelInfo.type == ModelType::VAE) {
- compatibility["requirements"]["required_dependencies"] = {"checkpoint"};
- }
- return compatibility;
- }
- nlohmann::json Server::getModelRequirements(ModelType type) {
- nlohmann::json requirements = {
- {"min_memory_mb", 1024},
- {"recommended_memory_mb", 2048},
- {"min_disk_space_mb", 1024},
- {"supported_formats", {"safetensors", "ckpt", "gguf"}},
- {"required_dependencies", nlohmann::json::array()},
- {"optional_dependencies", nlohmann::json::array()},
- {"system_requirements", {
- {"cpu_cores", 4},
- {"cpu_architecture", "x86_64"},
- {"os", "Linux/Windows/macOS"},
- {"gpu_memory_mb", 2048},
- {"gpu_compute_capability", "3.5+"}
- }}
- };
- switch (type) {
- case ModelType::CHECKPOINT:
- requirements["min_memory_mb"] = 2048;
- requirements["recommended_memory_mb"] = 4096;
- requirements["min_disk_space_mb"] = 2048;
- requirements["supported_formats"] = {"safetensors", "ckpt", "gguf"};
- break;
- case ModelType::LORA:
- requirements["min_memory_mb"] = 512;
- requirements["recommended_memory_mb"] = 1024;
- requirements["min_disk_space_mb"] = 100;
- requirements["supported_formats"] = {"safetensors", "ckpt"};
- requirements["required_dependencies"] = {"checkpoint"};
- break;
- case ModelType::CONTROLNET:
- requirements["min_memory_mb"] = 1024;
- requirements["recommended_memory_mb"] = 2048;
- requirements["min_disk_space_mb"] = 500;
- requirements["supported_formats"] = {"safetensors", "pth"};
- requirements["required_dependencies"] = {"checkpoint"};
- break;
- case ModelType::VAE:
- requirements["min_memory_mb"] = 512;
- requirements["recommended_memory_mb"] = 1024;
- requirements["min_disk_space_mb"] = 200;
- requirements["supported_formats"] = {"safetensors", "pt", "ckpt", "gguf"};
- requirements["required_dependencies"] = {"checkpoint"};
- break;
- case ModelType::EMBEDDING:
- requirements["min_memory_mb"] = 64;
- requirements["recommended_memory_mb"] = 256;
- requirements["min_disk_space_mb"] = 10;
- requirements["supported_formats"] = {"safetensors", "pt"};
- requirements["required_dependencies"] = {"checkpoint"};
- break;
- case ModelType::TAESD:
- requirements["min_memory_mb"] = 256;
- requirements["recommended_memory_mb"] = 512;
- requirements["min_disk_space_mb"] = 100;
- requirements["supported_formats"] = {"safetensors", "pth", "gguf"};
- requirements["required_dependencies"] = {"checkpoint"};
- break;
- case ModelType::ESRGAN:
- requirements["min_memory_mb"] = 1024;
- requirements["recommended_memory_mb"] = 2048;
- requirements["min_disk_space_mb"] = 500;
- requirements["supported_formats"] = {"pth", "pt"};
- requirements["optional_dependencies"] = {"checkpoint"};
- break;
- default:
- break;
- }
- return requirements;
- }
- nlohmann::json Server::getRecommendedUsage(ModelType type) {
- nlohmann::json usage = {
- {"text2img", false},
- {"img2img", false},
- {"inpainting", false},
- {"controlnet", false},
- {"lora", false},
- {"vae", false},
- {"recommended_resolution", "512x512"},
- {"recommended_steps", 20},
- {"recommended_cfg_scale", 7.5},
- {"recommended_batch_size", 1}
- };
- switch (type) {
- case ModelType::CHECKPOINT:
- usage = {
- {"text2img", true},
- {"img2img", true},
- {"inpainting", true},
- {"controlnet", true},
- {"lora", true},
- {"vae", true},
- {"recommended_resolution", "512x512"},
- {"recommended_steps", 20},
- {"recommended_cfg_scale", 7.5},
- {"recommended_batch_size", 1}
- };
- break;
- case ModelType::LORA:
- usage = {
- {"text2img", true},
- {"img2img", true},
- {"inpainting", true},
- {"controlnet", false},
- {"lora", true},
- {"vae", false},
- {"recommended_strength", 1.0},
- {"recommended_usage", "Style transfer, character customization"}
- };
- break;
- case ModelType::CONTROLNET:
- usage = {
- {"text2img", false},
- {"img2img", true},
- {"inpainting", true},
- {"controlnet", true},
- {"lora", false},
- {"vae", false},
- {"recommended_strength", 0.9},
- {"recommended_usage", "Precise control over output"}
- };
- break;
- case ModelType::VAE:
- usage = {
- {"text2img", false},
- {"img2img", false},
- {"inpainting", false},
- {"controlnet", false},
- {"lora", false},
- {"vae", true},
- {"recommended_usage", "Improved encoding/decoding quality"}
- };
- break;
- case ModelType::EMBEDDING:
- usage = {
- {"text2img", true},
- {"img2img", true},
- {"inpainting", true},
- {"controlnet", false},
- {"lora", false},
- {"vae", false},
- {"embedding", true},
- {"recommended_usage", "Concept control, style words"}
- };
- break;
- case ModelType::TAESD:
- usage = {
- {"text2img", false},
- {"img2img", false},
- {"inpainting", false},
- {"controlnet", false},
- {"lora", false},
- {"vae", true},
- {"recommended_usage", "Real-time decoding"}
- };
- break;
- case ModelType::ESRGAN:
- usage = {
- {"text2img", false},
- {"img2img", false},
- {"inpainting", false},
- {"controlnet", false},
- {"lora", false},
- {"vae", false},
- {"upscaling", true},
- {"recommended_usage", "Image upscaling and quality enhancement"}
- };
- break;
- default:
- break;
- }
- return usage;
- }
- std::string Server::getModelTypeFromDirectoryName(const std::string& dirName) {
- if (dirName == "stable-diffusion" || dirName == "checkpoints") {
- return "checkpoint";
- } else if (dirName == "lora") {
- return "lora";
- } else if (dirName == "controlnet") {
- return "controlnet";
- } else if (dirName == "vae") {
- return "vae";
- } else if (dirName == "taesd") {
- return "taesd";
- } else if (dirName == "esrgan" || dirName == "upscaler") {
- return "esrgan";
- } else if (dirName == "embeddings" || dirName == "textual-inversion") {
- return "embedding";
- } else {
- return "unknown";
- }
- }
- std::string Server::getDirectoryDescription(const std::string& dirName) {
- if (dirName == "stable-diffusion" || dirName == "checkpoints") {
- return "Main stable diffusion model files";
- } else if (dirName == "lora") {
- return "LoRA adapter models for style transfer";
- } else if (dirName == "controlnet") {
- return "ControlNet models for precise control";
- } else if (dirName == "vae") {
- return "VAE models for improved encoding/decoding";
- } else if (dirName == "taesd") {
- return "TAESD models for real-time decoding";
- } else if (dirName == "esrgan" || dirName == "upscaler") {
- return "ESRGAN models for image upscaling";
- } else if (dirName == "embeddings" || dirName == "textual-inversion") {
- return "Text embeddings for concept control";
- } else {
- return "Unknown model directory";
- }
- }
- nlohmann::json Server::getDirectoryContents(const std::string& dirPath) {
- nlohmann::json contents = nlohmann::json::array();
- try {
- if (std::filesystem::exists(dirPath) && std::filesystem::is_directory(dirPath)) {
- for (const auto& entry : std::filesystem::directory_iterator(dirPath)) {
- if (entry.is_regular_file()) {
- nlohmann::json file = {
- {"name", entry.path().filename().string()},
- {"path", entry.path().string()},
- {"size", std::filesystem::file_size(entry.path())},
- {"size_mb", std::filesystem::file_size(entry.path()) / (1024.0 * 1024.0)},
- {"last_modified", std::chrono::duration_cast<std::chrono::seconds>(
- std::filesystem::last_write_time(entry.path()).time_since_epoch()).count()}
- };
- contents.push_back(file);
- }
- }
- }
- } catch (const std::exception& e) {
- // Return empty array if directory access fails
- }
- return contents;
- }
- nlohmann::json Server::getLargestModel(const std::map<std::string, ModelManager::ModelInfo>& allModels) {
- nlohmann::json largest = nlohmann::json::object();
- size_t maxSize = 0;
- std::string largestName;
- for (const auto& pair : allModels) {
- if (pair.second.fileSize > maxSize) {
- maxSize = pair.second.fileSize;
- largestName = pair.second.name;
- }
- }
- if (!largestName.empty()) {
- largest = {
- {"name", largestName},
- {"size", maxSize},
- {"size_mb", maxSize / (1024.0 * 1024.0)},
- {"type", ModelManager::modelTypeToString(allModels.at(largestName).type)}
- };
- }
- return largest;
- }
- nlohmann::json Server::getSmallestModel(const std::map<std::string, ModelManager::ModelInfo>& allModels) {
- nlohmann::json smallest = nlohmann::json::object();
- size_t minSize = SIZE_MAX;
- std::string smallestName;
- for (const auto& pair : allModels) {
- if (pair.second.fileSize < minSize) {
- minSize = pair.second.fileSize;
- smallestName = pair.second.name;
- }
- }
- if (!smallestName.empty()) {
- smallest = {
- {"name", smallestName},
- {"size", minSize},
- {"size_mb", minSize / (1024.0 * 1024.0)},
- {"type", ModelManager::modelTypeToString(allModels.at(smallestName).type)}
- };
- }
- return smallest;
- }
- nlohmann::json Server::validateModelFile(const std::string& modelPath, const std::string& modelType) {
- nlohmann::json validation = {
- {"is_valid", false},
- {"errors", nlohmann::json::array()},
- {"warnings", nlohmann::json::array()},
- {"file_info", nlohmann::json::object()},
- {"compatibility", nlohmann::json::object()},
- {"recommendations", nlohmann::json::array()}
- };
- try {
- if (!std::filesystem::exists(modelPath)) {
- validation["errors"].push_back("File does not exist");
- return validation;
- }
- if (!std::filesystem::is_regular_file(modelPath)) {
- validation["errors"].push_back("Path is not a regular file");
- return validation;
- }
- // Check file extension
- std::string extension = std::filesystem::path(modelPath).extension().string();
- if (extension.empty()) {
- validation["errors"].push_back("Missing file extension");
- return validation;
- }
- // Remove dot and convert to lowercase
- if (extension[0] == '.') {
- extension = extension.substr(1);
- }
- std::transform(extension.begin(), extension.end(), extension.begin(), ::tolower);
- // Validate extension based on model type
- ModelType type = ModelManager::stringToModelType(modelType);
- bool validExtension = false;
- switch (type) {
- case ModelType::CHECKPOINT:
- validExtension = (extension == "safetensors" || extension == "ckpt" || extension == "gguf");
- break;
- case ModelType::LORA:
- validExtension = (extension == "safetensors" || extension == "ckpt");
- break;
- case ModelType::CONTROLNET:
- validExtension = (extension == "safetensors" || extension == "pth");
- break;
- case ModelType::VAE:
- validExtension = (extension == "safetensors" || extension == "pt" || extension == "ckpt" || extension == "gguf");
- break;
- case ModelType::EMBEDDING:
- validExtension = (extension == "safetensors" || extension == "pt");
- break;
- case ModelType::TAESD:
- validExtension = (extension == "safetensors" || extension == "pth" || extension == "gguf");
- break;
- case ModelType::ESRGAN:
- validExtension = (extension == "pth" || extension == "pt");
- break;
- default:
- break;
- }
- if (!validExtension) {
- validation["errors"].push_back("Invalid file extension for model type: " + extension);
- }
- // Check file size
- size_t fileSize = std::filesystem::file_size(modelPath);
- if (fileSize == 0) {
- validation["errors"].push_back("File is empty");
- } else if (fileSize > 8ULL * 1024 * 1024 * 1024) { // 8GB
- validation["warnings"].push_back("Very large file may cause performance issues");
- }
- // Build file info
- validation["file_info"] = {
- {"path", modelPath},
- {"size", fileSize},
- {"size_mb", fileSize / (1024.0 * 1024.0)},
- {"extension", extension},
- {"last_modified", std::chrono::duration_cast<std::chrono::seconds>(
- std::filesystem::last_write_time(modelPath).time_since_epoch()).count()}
- };
- // Check compatibility
- validation["compatibility"] = {
- {"extension_valid", validExtension},
- {"size_appropriate", fileSize <= 4ULL * 1024 * 1024 * 1024}, // 4GB
- {"recommended_format", "safetensors"}
- };
- // Add recommendations
- if (!validExtension) {
- validation["recommendations"].push_back("Convert to SafeTensors format for better security and performance");
- }
- if (fileSize > 2ULL * 1024 * 1024 * 1024) { // 2GB
- validation["recommendations"].push_back("Consider using a smaller model for better performance");
- }
- // If no errors found, mark as valid
- if (validation["errors"].empty()) {
- validation["is_valid"] = true;
- }
- } catch (const std::exception& e) {
- validation["errors"].push_back("Validation failed: " + std::string(e.what()));
- }
- return validation;
- }
- nlohmann::json Server::checkModelCompatibility(const ModelManager::ModelInfo& modelInfo, const std::string& systemInfo) {
- nlohmann::json compatibility = {
- {"is_compatible", true},
- {"compatibility_score", 100},
- {"issues", nlohmann::json::array()},
- {"warnings", nlohmann::json::array()},
- {"requirements", nlohmann::json::object()},
- {"recommendations", nlohmann::json::array()},
- {"system_info", nlohmann::json::object()}
- };
- // Check system compatibility
- if (systemInfo == "auto") {
- compatibility["system_info"] = {
- {"cpu_cores", std::thread::hardware_concurrency()}
- };
- }
- // Check model-specific compatibility issues
- if (modelInfo.type == ModelType::CHECKPOINT) {
- if (modelInfo.fileSize > 4ULL * 1024 * 1024 * 1024) { // 4GB
- compatibility["warnings"].push_back("Large checkpoint model may require significant memory");
- compatibility["compatibility_score"] = 80;
- }
- if (modelInfo.fileSize < 500 * 1024 * 1024) { // 500MB
- compatibility["warnings"].push_back("Small checkpoint model may have limited capabilities");
- compatibility["compatibility_score"] = 85;
- }
- } else if (modelInfo.type == ModelType::LORA) {
- if (modelInfo.fileSize > 500 * 1024 * 1024) { // 500MB
- compatibility["warnings"].push_back("Large LoRA may impact performance");
- compatibility["compatibility_score"] = 75;
- }
- }
- return compatibility;
- }
- nlohmann::json Server::calculateSpecificRequirements(const std::string& modelType, const std::string& resolution, const std::string& batchSize) {
- (void)modelType; // Suppress unused parameter warning
- nlohmann::json specific = {
- {"memory_requirements", nlohmann::json::object()},
- {"performance_impact", nlohmann::json::object()},
- {"quality_expectations", nlohmann::json::object()}
- };
- // Parse resolution
- int width = 512, height = 512;
- try {
- size_t xPos = resolution.find('x');
- if (xPos != std::string::npos) {
- width = std::stoi(resolution.substr(0, xPos));
- height = std::stoi(resolution.substr(xPos + 1));
- }
- } catch (...) {
- // Use defaults if parsing fails
- }
- // Parse batch size
- int batch = 1;
- try {
- batch = std::stoi(batchSize);
- } catch (...) {
- // Use default if parsing fails
- }
- // Calculate memory requirements based on resolution and batch
- size_t pixels = width * height;
- size_t baseMemory = 1024 * 1024 * 1024; // 1GB base
- size_t resolutionMemory = (pixels * 4) / (512 * 512); // Scale based on 512x512
- size_t batchMemory = (batch - 1) * baseMemory * 0.5; // Additional memory for batch
- specific["memory_requirements"] = {
- {"base_memory_mb", baseMemory / (1024 * 1024)},
- {"resolution_memory_mb", resolutionMemory / (1024 * 1024)},
- {"batch_memory_mb", batchMemory / (1024 * 1024)},
- {"total_memory_mb", (baseMemory + resolutionMemory + batchMemory) / (1024 * 1024)}
- };
- // Calculate performance impact
- double performanceFactor = 1.0;
- if (pixels > 512 * 512) {
- performanceFactor = 1.5;
- }
- if (batch > 1) {
- performanceFactor *= 1.2;
- }
- specific["performance_impact"] = {
- {"resolution_factor", pixels > 512 * 512 ? 1.5 : 1.0},
- {"batch_factor", batch > 1 ? 1.2 : 1.0},
- {"overall_factor", performanceFactor}
- };
- return specific;
- }
- // Enhanced model management endpoint implementations
- void Server::handleModelInfo(const httplib::Request& req, httplib::Response& res) {
- std::string requestId = generateRequestId();
- try {
- if (!m_modelManager) {
- sendErrorResponse(res, "Model manager not available", 500, "MODEL_MANAGER_UNAVAILABLE", requestId);
- return;
- }
- // Extract model ID from URL path
- std::string modelId = req.matches[1].str();
- if (modelId.empty()) {
- sendErrorResponse(res, "Missing model ID", 400, "MISSING_MODEL_ID", requestId);
- return;
- }
- // Get model information
- auto modelInfo = m_modelManager->getModelInfo(modelId);
- if (modelInfo.name.empty()) {
- sendErrorResponse(res, "Model not found", 404, "MODEL_NOT_FOUND", requestId);
- return;
- }
- // Build comprehensive model information
- nlohmann::json response = {
- {"model", {
- {"name", modelInfo.name},
- {"path", modelInfo.path},
- {"type", ModelManager::modelTypeToString(modelInfo.type)},
- {"is_loaded", modelInfo.isLoaded},
- {"file_size", modelInfo.fileSize},
- {"file_size_mb", modelInfo.fileSize / (1024.0 * 1024.0)},
- {"description", modelInfo.description},
- {"metadata", modelInfo.metadata},
- {"capabilities", getModelCapabilities(modelInfo.type)},
- {"compatibility", getModelCompatibility(modelInfo)},
- {"requirements", getModelRequirements(modelInfo.type)},
- {"recommended_usage", getRecommendedUsage(modelInfo.type)},
- {"last_modified", std::chrono::duration_cast<std::chrono::seconds>(
- modelInfo.modifiedAt.time_since_epoch()).count()}
- }},
- {"request_id", requestId}
- };
- sendJsonResponse(res, response);
- } catch (const std::exception& e) {
- sendErrorResponse(res, std::string("Failed to get model info: ") + e.what(), 500, "MODEL_INFO_ERROR", requestId);
- }
- }
- void Server::handleLoadModelById(const httplib::Request& req, httplib::Response& res) {
- std::string requestId = generateRequestId();
- try {
- if (!m_modelManager) {
- sendErrorResponse(res, "Model manager not available", 500, "MODEL_MANAGER_UNAVAILABLE", requestId);
- return;
- }
- // Extract model ID from URL path (could be hash or name)
- std::string modelIdentifier = req.matches[1].str();
- if (modelIdentifier.empty()) {
- sendErrorResponse(res, "Missing model identifier", 400, "MISSING_MODEL_ID", requestId);
- return;
- }
- // Try to find by hash first (if it looks like a hash - 10+ hex chars)
- std::string modelId = modelIdentifier;
- if (modelIdentifier.length() >= 10 &&
- std::all_of(modelIdentifier.begin(), modelIdentifier.end(),
- [](char c) { return std::isxdigit(c); })) {
- std::string foundName = m_modelManager->findModelByHash(modelIdentifier);
- if (!foundName.empty()) {
- modelId = foundName;
- std::cout << "Resolved hash " << modelIdentifier << " to model: " << modelId << std::endl;
- }
- }
- // Parse optional parameters from request body
- nlohmann::json requestJson;
- if (!req.body.empty()) {
- try {
- requestJson = nlohmann::json::parse(req.body);
- } catch (const nlohmann::json::parse_error& e) {
- sendErrorResponse(res, std::string("Invalid JSON: ") + e.what(), 400, "JSON_PARSE_ERROR", requestId);
- return;
- }
- }
- // Unload previous model if one is loaded
- std::string previousModel;
- {
- std::lock_guard<std::mutex> lock(m_currentModelMutex);
- previousModel = m_currentlyLoadedModel;
- }
- if (!previousModel.empty() && previousModel != modelId) {
- std::cout << "Unloading previous model: " << previousModel << std::endl;
- m_modelManager->unloadModel(previousModel);
- }
- // Load model
- bool success = m_modelManager->loadModel(modelId);
- if (success) {
- // Update currently loaded model
- {
- std::lock_guard<std::mutex> lock(m_currentModelMutex);
- m_currentlyLoadedModel = modelId;
- }
- auto modelInfo = m_modelManager->getModelInfo(modelId);
- nlohmann::json response = {
- {"status", "success"},
- {"model", {
- {"name", modelInfo.name},
- {"path", modelInfo.path},
- {"type", ModelManager::modelTypeToString(modelInfo.type)},
- {"is_loaded", modelInfo.isLoaded}
- }},
- {"request_id", requestId}
- };
- sendJsonResponse(res, response);
- } else {
- sendErrorResponse(res, "Failed to load model", 400, "MODEL_LOAD_FAILED", requestId);
- }
- } catch (const std::exception& e) {
- sendErrorResponse(res, std::string("Model load failed: ") + e.what(), 500, "MODEL_LOAD_ERROR", requestId);
- }
- }
- void Server::handleUnloadModelById(const httplib::Request& req, httplib::Response& res) {
- std::string requestId = generateRequestId();
- try {
- if (!m_modelManager) {
- sendErrorResponse(res, "Model manager not available", 500, "MODEL_MANAGER_UNAVAILABLE", requestId);
- return;
- }
- // Extract model ID from URL path
- std::string modelId = req.matches[1].str();
- if (modelId.empty()) {
- sendErrorResponse(res, "Missing model ID", 400, "MISSING_MODEL_ID", requestId);
- return;
- }
- // Unload model
- bool success = m_modelManager->unloadModel(modelId);
- if (success) {
- // Clear currently loaded model if it matches
- {
- std::lock_guard<std::mutex> lock(m_currentModelMutex);
- if (m_currentlyLoadedModel == modelId) {
- m_currentlyLoadedModel = "";
- }
- }
- nlohmann::json response = {
- {"status", "success"},
- {"model", {
- {"name", modelId},
- {"is_loaded", false}
- }},
- {"request_id", requestId}
- };
- sendJsonResponse(res, response);
- } else {
- sendErrorResponse(res, "Failed to unload model or model not found", 404, "MODEL_UNLOAD_FAILED", requestId);
- }
- } catch (const std::exception& e) {
- sendErrorResponse(res, std::string("Model unload failed: ") + e.what(), 500, "MODEL_UNLOAD_ERROR", requestId);
- }
- }
- void Server::handleModelTypes(const httplib::Request& /*req*/, httplib::Response& res) {
- std::string requestId = generateRequestId();
- try {
- nlohmann::json types = {
- {"model_types", {
- {
- {"type", "checkpoint"},
- {"description", "Main stable diffusion model files for text-to-image, image-to-image, and inpainting"},
- {"extensions", {"safetensors", "ckpt", "gguf"}},
- {"capabilities", {"text2img", "img2img", "inpainting", "controlnet", "lora", "vae"}},
- {"recommended_for", "General purpose image generation"}
- },
- {
- {"type", "lora"},
- {"description", "LoRA adapter models for style transfer and character customization"},
- {"extensions", {"safetensors", "ckpt"}},
- {"capabilities", {"style_transfer", "character_customization"}},
- {"requires", {"checkpoint"}},
- {"recommended_for", "Style modification and character-specific generation"}
- },
- {
- {"type", "controlnet"},
- {"description", "ControlNet models for precise control over output composition"},
- {"extensions", {"safetensors", "pth"}},
- {"capabilities", {"precise_control", "composition_control"}},
- {"requires", {"checkpoint"}},
- {"recommended_for", "Precise control over image generation"}
- },
- {
- {"type", "vae"},
- {"description", "VAE models for improved encoding and decoding quality"},
- {"extensions", {"safetensors", "pt", "ckpt", "gguf"}},
- {"capabilities", {"encoding", "decoding", "quality_improvement"}},
- {"requires", {"checkpoint"}},
- {"recommended_for", "Improved image quality and encoding"}
- },
- {
- {"type", "embedding"},
- {"description", "Text embeddings for concept control and style words"},
- {"extensions", {"safetensors", "pt"}},
- {"capabilities", {"concept_control", "style_words"}},
- {"requires", {"checkpoint"}},
- {"recommended_for", "Concept control and specific styles"}
- },
- {
- {"type", "taesd"},
- {"description", "TAESD models for real-time decoding"},
- {"extensions", {"safetensors", "pth", "gguf"}},
- {"capabilities", {"real_time_decoding", "fast_preview"}},
- {"requires", {"checkpoint"}},
- {"recommended_for", "Real-time applications and fast previews"}
- },
- {
- {"type", "esrgan"},
- {"description", "ESRGAN models for image upscaling and enhancement"},
- {"extensions", {"pth", "pt"}},
- {"capabilities", {"upscaling", "enhancement", "quality_improvement"}},
- {"recommended_for", "Image upscaling and quality enhancement"}
- }
- }},
- {"request_id", requestId}
- };
- sendJsonResponse(res, types);
- } catch (const std::exception& e) {
- sendErrorResponse(res, std::string("Failed to get model types: ") + e.what(), 500, "MODEL_TYPES_ERROR", requestId);
- }
- }
- void Server::handleModelDirectories(const httplib::Request& /*req*/, httplib::Response& res) {
- std::string requestId = generateRequestId();
- try {
- if (!m_modelManager) {
- sendErrorResponse(res, "Model manager not available", 500, "MODEL_MANAGER_UNAVAILABLE", requestId);
- return;
- }
- std::string modelsDir = m_modelManager->getModelsDirectory();
- nlohmann::json directories = nlohmann::json::array();
- // Define expected model directories
- std::vector<std::string> modelDirs = {
- "stable-diffusion", "checkpoints", "lora", "controlnet",
- "vae", "taesd", "esrgan", "embeddings"
- };
- for (const auto& dirName : modelDirs) {
- std::string dirPath = modelsDir + "/" + dirName;
- std::string type = getModelTypeFromDirectoryName(dirName);
- std::string description = getDirectoryDescription(dirName);
- nlohmann::json dirInfo = {
- {"name", dirName},
- {"path", dirPath},
- {"type", type},
- {"description", description},
- {"exists", std::filesystem::exists(dirPath) && std::filesystem::is_directory(dirPath)},
- {"contents", getDirectoryContents(dirPath)}
- };
- directories.push_back(dirInfo);
- }
- nlohmann::json response = {
- {"models_directory", modelsDir},
- {"directories", directories},
- {"request_id", requestId}
- };
- sendJsonResponse(res, response);
- } catch (const std::exception& e) {
- sendErrorResponse(res, std::string("Failed to get model directories: ") + e.what(), 500, "MODEL_DIRECTORIES_ERROR", requestId);
- }
- }
- void Server::handleRefreshModels(const httplib::Request& /*req*/, httplib::Response& res) {
- std::string requestId = generateRequestId();
- try {
- if (!m_modelManager) {
- sendErrorResponse(res, "Model manager not available", 500, "MODEL_MANAGER_UNAVAILABLE", requestId);
- return;
- }
- // Force refresh of model cache
- bool success = m_modelManager->scanModelsDirectory();
- if (success) {
- nlohmann::json response = {
- {"status", "success"},
- {"message", "Model cache refreshed successfully"},
- {"models_found", m_modelManager->getAvailableModelsCount()},
- {"models_loaded", m_modelManager->getLoadedModelsCount()},
- {"models_directory", m_modelManager->getModelsDirectory()},
- {"request_id", requestId}
- };
- sendJsonResponse(res, response);
- } else {
- sendErrorResponse(res, "Failed to refresh model cache", 500, "MODEL_REFRESH_FAILED", requestId);
- }
- } catch (const std::exception& e) {
- sendErrorResponse(res, std::string("Model refresh failed: ") + e.what(), 500, "MODEL_REFRESH_ERROR", requestId);
- }
- }
- void Server::handleHashModels(const httplib::Request& req, httplib::Response& res) {
- std::string requestId = generateRequestId();
- try {
- if (!m_generationQueue || !m_modelManager) {
- sendErrorResponse(res, "Services not available", 500, "SERVICE_UNAVAILABLE", requestId);
- return;
- }
- // Parse request body
- nlohmann::json requestJson;
- if (!req.body.empty()) {
- requestJson = nlohmann::json::parse(req.body);
- }
- HashRequest hashReq;
- hashReq.id = requestId;
- hashReq.forceRehash = requestJson.value("force_rehash", false);
- if (requestJson.contains("models") && requestJson["models"].is_array()) {
- for (const auto& model : requestJson["models"]) {
- hashReq.modelNames.push_back(model.get<std::string>());
- }
- }
- // Enqueue hash request
- auto future = m_generationQueue->enqueueHashRequest(hashReq);
- nlohmann::json response = {
- {"request_id", requestId},
- {"status", "queued"},
- {"message", "Hash job queued successfully"},
- {"models_to_hash", hashReq.modelNames.empty() ? "all_unhashed" : std::to_string(hashReq.modelNames.size())}
- };
- sendJsonResponse(res, response, 202);
- } catch (const nlohmann::json::parse_error& e) {
- sendErrorResponse(res, std::string("Invalid JSON: ") + e.what(), 400, "JSON_PARSE_ERROR", requestId);
- } catch (const std::exception& e) {
- sendErrorResponse(res, std::string("Hash request failed: ") + e.what(), 500, "INTERNAL_ERROR", requestId);
- }
- }
- void Server::handleConvertModel(const httplib::Request& req, httplib::Response& res) {
- std::string requestId = generateRequestId();
- try {
- if (!m_generationQueue || !m_modelManager) {
- sendErrorResponse(res, "Services not available", 500, "SERVICE_UNAVAILABLE", requestId);
- return;
- }
- // Parse request body
- nlohmann::json requestJson;
- try {
- requestJson = nlohmann::json::parse(req.body);
- } catch (const nlohmann::json::parse_error& e) {
- sendErrorResponse(res, std::string("Invalid JSON: ") + e.what(), 400, "JSON_PARSE_ERROR", requestId);
- return;
- }
- // Validate required fields
- if (!requestJson.contains("model_name")) {
- sendErrorResponse(res, "Missing required field: model_name", 400, "MISSING_FIELD", requestId);
- return;
- }
- if (!requestJson.contains("quantization_type")) {
- sendErrorResponse(res, "Missing required field: quantization_type", 400, "MISSING_FIELD", requestId);
- return;
- }
- std::string modelName = requestJson["model_name"].get<std::string>();
- std::string quantizationType = requestJson["quantization_type"].get<std::string>();
- // Validate quantization type
- const std::vector<std::string> validTypes = {"f32", "f16", "q4_0", "q4_1", "q5_0", "q5_1", "q8_0", "q2_K", "q3_K", "q4_K"};
- if (std::find(validTypes.begin(), validTypes.end(), quantizationType) == validTypes.end()) {
- sendErrorResponse(res, "Invalid quantization_type. Valid types: f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0, q2_K, q3_K, q4_K",
- 400, "INVALID_QUANTIZATION_TYPE", requestId);
- return;
- }
- // Get model info to find the full path
- auto modelInfo = m_modelManager->getModelInfo(modelName);
- if (modelInfo.name.empty()) {
- sendErrorResponse(res, "Model not found: " + modelName, 404, "MODEL_NOT_FOUND", requestId);
- return;
- }
- // Check if model is already GGUF
- if (modelInfo.fullPath.find(".gguf") != std::string::npos) {
- sendErrorResponse(res, "Model is already in GGUF format. Cannot convert GGUF to GGUF.",
- 400, "ALREADY_GGUF", requestId);
- return;
- }
- // Build output path
- std::string outputPath = requestJson.value("output_path", "");
- if (outputPath.empty()) {
- // Generate default output path: model_name_quantization.gguf
- namespace fs = std::filesystem;
- fs::path inputPath(modelInfo.fullPath);
- std::string baseName = inputPath.stem().string();
- std::string outputDir = inputPath.parent_path().string();
- outputPath = outputDir + "/" + baseName + "_" + quantizationType + ".gguf";
- }
- // Create conversion request
- ConversionRequest convReq;
- convReq.id = requestId;
- convReq.modelName = modelName;
- convReq.modelPath = modelInfo.fullPath;
- convReq.outputPath = outputPath;
- convReq.quantizationType = quantizationType;
- // Enqueue conversion request
- auto future = m_generationQueue->enqueueConversionRequest(convReq);
- nlohmann::json response = {
- {"request_id", requestId},
- {"status", "queued"},
- {"message", "Model conversion queued successfully"},
- {"model_name", modelName},
- {"input_path", modelInfo.fullPath},
- {"output_path", outputPath},
- {"quantization_type", quantizationType}
- };
- sendJsonResponse(res, response, 202);
- } catch (const std::exception& e) {
- sendErrorResponse(res, std::string("Conversion request failed: ") + e.what(), 500, "INTERNAL_ERROR", requestId);
- }
- }
- void Server::handleModelStats(const httplib::Request& /*req*/, httplib::Response& res) {
- std::string requestId = generateRequestId();
- try {
- if (!m_modelManager) {
- sendErrorResponse(res, "Model manager not available", 500, "MODEL_MANAGER_UNAVAILABLE", requestId);
- return;
- }
- auto allModels = m_modelManager->getAllModels();
- nlohmann::json response = {
- {"statistics", {
- {"total_models", allModels.size()},
- {"loaded_models", m_modelManager->getLoadedModelsCount()},
- {"available_models", m_modelManager->getAvailableModelsCount()},
- {"model_types", getModelTypeStatistics()},
- {"largest_model", getLargestModel(allModels)},
- {"smallest_model", getSmallestModel(allModels)}
- }},
- {"request_id", requestId}
- };
- sendJsonResponse(res, response);
- } catch (const std::exception& e) {
- sendErrorResponse(res, std::string("Failed to get model stats: ") + e.what(), 500, "MODEL_STATS_ERROR", requestId);
- }
- }
- void Server::handleBatchModels(const httplib::Request& req, httplib::Response& res) {
- std::string requestId = generateRequestId();
- try {
- if (!m_modelManager) {
- sendErrorResponse(res, "Model manager not available", 500, "MODEL_MANAGER_UNAVAILABLE", requestId);
- return;
- }
- // Parse JSON request body
- nlohmann::json requestJson = nlohmann::json::parse(req.body);
- if (!requestJson.contains("operation") || !requestJson["operation"].is_string()) {
- sendErrorResponse(res, "Missing or invalid 'operation' field", 400, "INVALID_OPERATION", requestId);
- return;
- }
- if (!requestJson.contains("models") || !requestJson["models"].is_array()) {
- sendErrorResponse(res, "Missing or invalid 'models' field", 400, "INVALID_MODELS", requestId);
- return;
- }
- std::string operation = requestJson["operation"];
- nlohmann::json models = requestJson["models"];
- nlohmann::json results = nlohmann::json::array();
- for (const auto& model : models) {
- if (!model.is_string()) {
- results.push_back({
- {"model", model},
- {"success", false},
- {"error", "Invalid model name"}
- });
- continue;
- }
- std::string modelName = model;
- bool success = false;
- std::string error = "";
- if (operation == "load") {
- success = m_modelManager->loadModel(modelName);
- if (!success) error = "Failed to load model";
- } else if (operation == "unload") {
- success = m_modelManager->unloadModel(modelName);
- if (!success) error = "Failed to unload model";
- } else {
- error = "Unsupported operation";
- }
- results.push_back({
- {"model", modelName},
- {"success", success},
- {"error", error.empty() ? nlohmann::json(nullptr) : nlohmann::json(error)}
- });
- }
- nlohmann::json response = {
- {"operation", operation},
- {"results", results},
- {"successful_count", std::count_if(results.begin(), results.end(),
- [](const nlohmann::json& result) { return result["success"].get<bool>(); })},
- {"failed_count", std::count_if(results.begin(), results.end(),
- [](const nlohmann::json& result) { return !result["success"].get<bool>(); })},
- {"request_id", requestId}
- };
- sendJsonResponse(res, response);
- } catch (const nlohmann::json::parse_error& e) {
- sendErrorResponse(res, std::string("Invalid JSON: ") + e.what(), 400, "JSON_PARSE_ERROR", requestId);
- } catch (const std::exception& e) {
- sendErrorResponse(res, std::string("Batch operation failed: ") + e.what(), 500, "BATCH_OPERATION_ERROR", requestId);
- }
- }
- void Server::handleValidateModel(const httplib::Request& req, httplib::Response& res) {
- std::string requestId = generateRequestId();
- try {
- // Parse JSON request body
- nlohmann::json requestJson = nlohmann::json::parse(req.body);
- if (!requestJson.contains("model_path") || !requestJson["model_path"].is_string()) {
- sendErrorResponse(res, "Missing or invalid 'model_path' field", 400, "INVALID_MODEL_PATH", requestId);
- return;
- }
- std::string modelPath = requestJson["model_path"];
- std::string modelType = requestJson.value("model_type", "checkpoint");
- // Validate model file
- nlohmann::json validation = validateModelFile(modelPath, modelType);
- nlohmann::json response = {
- {"validation", validation},
- {"request_id", requestId}
- };
- sendJsonResponse(res, response);
- } catch (const nlohmann::json::parse_error& e) {
- sendErrorResponse(res, std::string("Invalid JSON: ") + e.what(), 400, "JSON_PARSE_ERROR", requestId);
- } catch (const std::exception& e) {
- sendErrorResponse(res, std::string("Model validation failed: ") + e.what(), 500, "MODEL_VALIDATION_ERROR", requestId);
- }
- }
- // HTTP handler: reports compatibility information for a named model.
- //
- // Request body: a JSON object with
- //   - "model_name"  (string, required)
- //   - "system_info" (string, optional, defaults to "auto")
- //
- // Responses:
- //   - success: JSON with "model", "compatibility" and "request_id"
- //   - 400 when the body is not valid JSON, "model_name" is missing or not a
- //     string, or "system_info" is present but not a string
- //   - 404 when getModelInfo() returns an entry with an empty name
- //   - 500 when the model manager is unset or any other std::exception escapes
- void Server::handleCheckCompatibility(const httplib::Request& req, httplib::Response& res) {
-     std::string requestId = generateRequestId();
-     try {
-         if (!m_modelManager) {
-             sendErrorResponse(res, "Model manager not available", 500, "MODEL_MANAGER_UNAVAILABLE", requestId);
-             return;
-         }
-         // Parse JSON request body
-         nlohmann::json requestJson = nlohmann::json::parse(req.body);
-         if (!requestJson.contains("model_name") || !requestJson["model_name"].is_string()) {
-             sendErrorResponse(res, "Missing or invalid 'model_name' field", 400, "INVALID_MODEL_NAME", requestId);
-             return;
-         }
-         // value() throws nlohmann::json::type_error when the key exists with a
-         // non-string value; that would fall into the generic handler below and be
-         // reported as a 500. It is a client mistake, so reject it here with a 400.
-         if (requestJson.contains("system_info") && !requestJson["system_info"].is_string()) {
-             sendErrorResponse(res, "Invalid 'system_info' field: expected a string", 400, "INVALID_SYSTEM_INFO", requestId);
-             return;
-         }
-         std::string modelName = requestJson["model_name"];
-         std::string systemInfo = requestJson.value("system_info", "auto");
-         // Get model information; an empty name is treated as "model not found"
-         auto modelInfo = m_modelManager->getModelInfo(modelName);
-         if (modelInfo.name.empty()) {
-             sendErrorResponse(res, "Model not found", 404, "MODEL_NOT_FOUND", requestId);
-             return;
-         }
-         // Check compatibility
-         nlohmann::json compatibility = checkModelCompatibility(modelInfo, systemInfo);
-         nlohmann::json response = {
-             {"model", modelName},
-             {"compatibility", compatibility},
-             {"request_id", requestId}
-         };
-         sendJsonResponse(res, response);
-     } catch (const nlohmann::json::parse_error& e) {
-         sendErrorResponse(res, std::string("Invalid JSON: ") + e.what(), 400, "JSON_PARSE_ERROR", requestId);
-     } catch (const std::exception& e) {
-         sendErrorResponse(res, std::string("Compatibility check failed: ") + e.what(), 500, "COMPATIBILITY_CHECK_ERROR", requestId);
-     }
- }
- // HTTP handler: returns resource requirements for a model type/configuration.
- //
- // Request body: a JSON object whose fields are all optional:
- //   - "model_type" (string, default "checkpoint")
- //   - "resolution" (string, default "512x512")
- //   - "batch_size" (string or integer, default "1"); an integer is converted
- //     to its decimal string before being passed on and echoed back
- //
- // Responses:
- //   - success: JSON with "model_type", "configuration", "specific_requirements",
- //     "general_requirements" and "request_id"
- //   - 400 when the body is not valid JSON, is not a JSON object, or a field has
- //     an unsupported type
- //   - 500 when any other std::exception escapes
- void Server::handleModelRequirements(const httplib::Request& req, httplib::Response& res) {
-     std::string requestId = generateRequestId();
-     try {
-         // Parse JSON request body
-         nlohmann::json requestJson = nlohmann::json::parse(req.body);
-         // value() throws type_error on non-object JSON (array, number, null, ...);
-         // report that as a client error instead of letting it surface as a 500.
-         if (!requestJson.is_object()) {
-             sendErrorResponse(res, "Request body must be a JSON object", 400, "INVALID_REQUEST_BODY", requestId);
-             return;
-         }
-         // True when the key is absent or holds a string, i.e. value() cannot throw.
-         const auto isMissingOrString = [&requestJson](const char* key) {
-             return !requestJson.contains(key) || requestJson[key].is_string();
-         };
-         if (!isMissingOrString("model_type")) {
-             sendErrorResponse(res, "Invalid 'model_type' field: expected a string", 400, "INVALID_MODEL_TYPE", requestId);
-             return;
-         }
-         if (!isMissingOrString("resolution")) {
-             sendErrorResponse(res, "Invalid 'resolution' field: expected a string", 400, "INVALID_RESOLUTION", requestId);
-             return;
-         }
-         std::string modelType = requestJson.value("model_type", "checkpoint");
-         std::string resolution = requestJson.value("resolution", "512x512");
-         // Clients naturally send batch_size as a JSON number; accept that form too
-         // and normalise it to the string representation used downstream.
-         std::string batchSize = "1";
-         if (requestJson.contains("batch_size")) {
-             const nlohmann::json& batchField = requestJson["batch_size"];
-             if (batchField.is_string()) {
-                 batchSize = batchField.get<std::string>();
-             } else if (batchField.is_number_integer()) {
-                 batchSize = std::to_string(batchField.get<long long>());
-             } else {
-                 sendErrorResponse(res, "Invalid 'batch_size' field: expected a string or integer", 400, "INVALID_BATCH_SIZE", requestId);
-                 return;
-             }
-         }
-         // Calculate specific requirements
-         nlohmann::json requirements = calculateSpecificRequirements(modelType, resolution, batchSize);
-         // Get general requirements for model type
-         ModelType type = ModelManager::stringToModelType(modelType);
-         nlohmann::json generalRequirements = getModelRequirements(type);
-         nlohmann::json response = {
-             {"model_type", modelType},
-             {"configuration", {
-                 {"resolution", resolution},
-                 {"batch_size", batchSize}
-             }},
-             {"specific_requirements", requirements},
-             {"general_requirements", generalRequirements},
-             {"request_id", requestId}
-         };
-         sendJsonResponse(res, response);
-     } catch (const nlohmann::json::parse_error& e) {
-         sendErrorResponse(res, std::string("Invalid JSON: ") + e.what(), 400, "JSON_PARSE_ERROR", requestId);
-     } catch (const std::exception& e) {
-         sendErrorResponse(res, std::string("Requirements calculation failed: ") + e.what(), 500, "REQUIREMENTS_ERROR", requestId);
-     }
- }
- // Body of the server thread: probes the port, then blocks in listen().
- //
- // Steps:
- //   1. Create a throw-away TCP socket and try to bind it to INADDR_ANY:<port>.
- //      On bind failure, print diagnostics, clear m_isRunning, set
- //      m_startupFailed and return without starting the HTTP server.
- //   2. Set m_isRunning to true and call m_httpServer->listen(host, port), which
- //      blocks until the server stops.
- //   3. Clear m_isRunning once listen() returns or an exception is caught.
- //
- // NOTE(review): the probe always binds INADDR_ANY, not `host`, so its result
- // may not match what listen() sees for a specific interface - confirm intent.
- // NOTE(review): m_startupFailed is only set on the probe failure path, not when
- // listen() returns false or an exception is caught - verify against the code
- // that waits on these flags.
- //
- // @param host Address passed to m_httpServer->listen().
- // @param port TCP port used for both the probe and listen().
- void Server::serverThreadFunction(const std::string& host, int port) {
-     try {
-         std::cout << "Server thread starting, attempting to bind to " << host << ":" << port << std::endl;
-         // Check if port is available before attempting to bind
-         std::cout << "Checking if port " << port << " is available..." << std::endl;
-         // Try to create a test socket to check if port is in use
-         int test_socket = socket(AF_INET, SOCK_STREAM, 0);
-         if (test_socket >= 0) {
-             // Set SO_REUSEADDR to avoid TIME_WAIT issues
-             int opt = 1;
-             if (setsockopt(test_socket, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt)) < 0) {
-                 std::cerr << "Warning: Failed to set SO_REUSEADDR on test socket: " << strerror(errno) << std::endl;
-             }
-             // Also set SO_REUSEPORT if available (for better concurrent binding handling)
- #ifdef SO_REUSEPORT
-             int reuseport = 1;
-             if (setsockopt(test_socket, SOL_SOCKET, SO_REUSEPORT, &reuseport, sizeof(reuseport)) < 0) {
-                 std::cerr << "Warning: Failed to set SO_REUSEPORT on test socket: " << strerror(errno) << std::endl;
-             }
- #endif
-             // Value-initialize so every byte (including sin_zero padding) is zero
-             // before the structure is handed to bind().
-             struct sockaddr_in addr{};
-             addr.sin_family = AF_INET;
-             addr.sin_port = htons(port);
-             addr.sin_addr.s_addr = INADDR_ANY;
-             // Try to bind to the port
-             if (bind(test_socket, reinterpret_cast<struct sockaddr*>(&addr), sizeof(addr)) < 0) {
-                 // Save errno first: close() below may overwrite it.
-                 const int bindErrno = errno;
-                 close(test_socket);
-                 std::cerr << "ERROR: Port " << port << " is already in use! Cannot start server." << std::endl;
-                 // Show the real cause too; bind() can also fail for reasons other
-                 // than an occupied port (for example insufficient privileges).
-                 std::cerr << "System error from bind(): " << strerror(bindErrno) << std::endl;
-                 std::cerr << "This could be due to:" << std::endl;
-                 std::cerr << "1. Another instance is already running on this port" << std::endl;
-                 std::cerr << "2. A previous instance crashed and the socket is in TIME_WAIT state" << std::endl;
-                 std::cerr << "3. The port is being used by another application" << std::endl;
-                 std::cerr << std::endl;
-                 std::cerr << "Solutions:" << std::endl;
-                 std::cerr << "- Wait 30-60 seconds for TIME_WAIT to expire (if server crashed)" << std::endl;
-                 std::cerr << "- Kill any existing processes: sudo lsof -ti:" << port << " | xargs kill -9" << std::endl;
-                 std::cerr << "- Use a different port with -p <port>" << std::endl;
-                 m_isRunning.store(false);
-                 m_startupFailed.store(true);
-                 return;
-             }
-             close(test_socket);
-             std::cout << "Port " << port << " is available, proceeding with server startup..." << std::endl;
-         } else {
-             // The probe could not run, so make no claim about availability and
-             // let listen() below be the authority.
-             std::cerr << "Warning: Could not create test socket, skipping port availability check: " << strerror(errno) << std::endl;
-         }
-         std::cout << "Calling listen()..." << std::endl;
-         // We need to set m_isRunning after successful bind but before blocking
-         // cpp-httplib doesn't provide a callback, so we set it optimistically
-         // and clear it if listen() returns false
-         m_isRunning.store(true);
-         bool listenResult = m_httpServer->listen(host.c_str(), port);
-         std::cout << "listen() returned: " << (listenResult ? "true" : "false") << std::endl;
-         // If we reach here, server has stopped (either normally or due to error)
-         m_isRunning.store(false);
-         if (!listenResult) {
-             std::cerr << "Server listen failed! This usually means port is in use or permission denied." << std::endl;
-         }
-     } catch (const std::exception& e) {
-         std::cerr << "Exception in server thread: " << e.what() << std::endl;
-         m_isRunning.store(false);
-     }
- }
|