server.cpp 204 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798991001011021031041051061071081091101111121131141151161171181191201211221231241251261271281291301311321331341351361371381391401411421431441451461471481491501511521531541551561571581591601611621631641651661671681691701711721731741751761771781791801811821831841851861871881891901911921931941951961971981992002012022032042052062072082092102112122132142152162172182192202212222232242252262272282292302312322332342352362372382392402412422432442452462472482492502512522532542552562572582592602612622632642652662672682692702712722732742752762772782792802812822832842852862872882892902912922932942952962972982993003013023033043053063073083093103113123133143153163173183193203213223233243253263273283293303313323333343353363373383393403413423433443453463473483493503513523533543553563573583593603613623633643653663673683693703713723733743753763773783793803813823833843853863873883893903913923933943953963973983994004014024034044054064074084094104114124134144154164174184194204214224234244254264274284294304314324334344354364374384394404414424434444454464474484494504514524534544554564574584594604614624634644654664674684694704714724734744754764774784794804814824834844854864874884894904914924934944954964974984995005015025035045055065075085095105115125135145155165175185195205215225235245255265275285295305315325335345355365375385395405415425435445455465475485495505515525535545555565575585595605615625635645655665675685695705715725735745755765775785795805815825835845855865875885895905915925935945955965975985996006016026036046056066076086096106116126136146156166176186196206216226236246256266276286296306316326336346356366376386396406416426436446456466476486496506516526536546556566576586596606616626636646656666676686696706716726736746756766776786796806816826836846856866876886896906916926936946956966976986997007017027037047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642164316441645164616471648164916501651165216531654165516561657165816591660166116621663166416651666166716681669167016711672167316741675167616771678167916801681168216831684168516861687168816891690169116921693169416951696169716981699170017011702170317041705170617071708170917101711171217131714171517161717171817191720172117221723172417251726172717281729173017311732173317341735173617371738173917401741174217431744174517461747174817491750175117521753175417551756175717581759176017611762176317641765176617671768176917701771177217731774177517761777177817791780178117821783178417851786178717881789179017911792179317941795179617971798179918001801180218031804180518061807180818091810181118121813181418151816181718181819182018211822182318241825182618271828182918301831183218331834183518361837183818391840184118421843184418451846184718481849185018511852185318541855185618571858185918601861186218631864186518661867186818691870187118721873187418751876187718781879188018811882188318841885188618871888188918901891189218931894189518961897189818991900190119021903190419051906190719081909191019111912191319141915191619171918191919201921192219231924192519261927192819291930193119321933193419351936193719381939194019411942194319441945194619471948194919501951195219531954195519561957195819591960196119621963196419651966196719681969197019711972197319741975197619771978197919801981198219831984198519861987198819891990199119921993199419951996199719981999200020012002200320042005200620072008200920102011201220132014201520162017201820192020202120222023202420252026202720282029203020312032203320342035203620372038203920402041204220432044204520462047204820492050205120522053205420552056205720582059206020612062206320642065206620672068206920702071207220732074207520762077207820792080208120822083208420852086208720882089209020912092209320942095209620972098209921002101210221032104210521062107210821092110211121122113211421152116211721182119212021212122212321242125212621272128212921302131213221332134213521362137213821392140214121422143214421452146214721482149215021512152215321542155215621572158215921602161216221632164216521662167216821692170217121722173217421752176217721782179218021812182218321842185218621872188218921902191219221932194219521962197219821992200220122022203220422052206220722082209221022112212221322142215221622172218221922202221222222232224222522262227222822292230223122322233223422352236223722382239224022412242224322442245224622472248224922502251225222532254225522562257225822592260226122622263226422652266226722682269227022712272227322742275227622772278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410241124122413241424152416241724182419242024212422242324242425242624272428242924302431243224332434243524362437243824392440244124422443244424452446244724482449245024512452245324542455245624572458245924602461246224632464246524662467246824692470247124722473247424752476247724782479248024812482248324842485248624872488248924902491249224932494249524962497249824992500250125022503250425052506250725082509251025112512251325142515251625172518251925202521252225232524252525262527252825292530253125322533253425352536253725382539254025412542254325442545254625472548254925502551255225532554255525562557255825592560256125622563256425652566256725682569257025712572257325742575257625772578257925802581258225832584258525862587258825892590259125922593259425952596259725982599260026012602260326042605260626072608260926102611261226132614261526162617261826192620262126222623262426252626262726282629263026312632263326342635263626372638263926402641264226432644264526462647264826492650265126522653265426552656265726582659266026612662266326642665266626672668266926702671267226732674267526762677267826792680268126822683268426852686268726882689269026912692269326942695269626972698269927002701270227032704270527062707270827092710271127122713271427152716271727182719272027212722272327242725272627272728272927302731273227332734273527362737273827392740274127422743274427452746274727482749275027512752275327542755275627572758275927602761276227632764276527662767276827692770277127722773277427752776277727782779278027812782278327842785278627872788278927902791279227932794279527962797279827992800280128022803280428052806280728082809281028112812281328142815281628172818281928202821282228232824282528262827282828292830283128322833283428352836283728382839284028412842284328442845284628472848284928502851285228532854285528562857285828592860286128622863286428652866286728682869287028712872287328742875287628772878287928802881288228832884288528862887288828892890289128922893289428952896289728982899290029012902290329042905290629072908290929102911291229132914291529162917291829192920292129222923292429252926292729282929293029312932293329342935293629372938293929402941294229432944294529462947294829492950295129522953295429552956295729582959296029612962296329642965296629672968296929702971297229732974297529762977297829792980298129822983298429852986298729882989299029912992299329942995299629972998299930003001300230033004300530063007300830093010301130123013301430153016301730183019302030213022302330243025302630273028302930303031303230333034303530363037303830393040304130423043304430453046304730483049305030513052305330543055305630573058305930603061306230633064306530663067306830693070307130723073307430753076307730783079308030813082308330843085308630873088308930903091309230933094309530963097309830993100310131023103310431053106310731083109311031113112311331143115311631173118311931203121312231233124312531263127312831293130313131323133313431353136313731383139314031413142314331443145314631473148314931503151315231533154315531563157315831593160316131623163316431653166316731683169317031713172317331743175317631773178317931803181318231833184318531863187318831893190319131923193319431953196319731983199320032013202320332043205320632073208320932103211321232133214321532163217321832193220322132223223322432253226322732283229323032313232323332343235323632373238323932403241324232433244324532463247324832493250325132523253325432553256325732583259326032613262326332643265326632673268326932703271327232733274327532763277327832793280328132823283328432853286328732883289329032913292329332943295329632973298329933003301330233033304330533063307330833093310331133123313331433153316331733183319332033213322332333243325332633273328332933303331333233333334333533363337333833393340334133423343334433453346334733483349335033513352335333543355335633573358335933603361336233633364336533663367336833693370337133723373337433753376337733783379338033813382338333843385338633873388338933903391339233933394339533963397339833993400340134023403340434053406340734083409341034113412341334143415341634173418341934203421342234233424342534263427342834293430343134323433343434353436343734383439344034413442344334443445344634473448344934503451345234533454345534563457345834593460346134623463346434653466346734683469347034713472347334743475347634773478347934803481348234833484348534863487348834893490349134923493349434953496349734983499350035013502350335043505350635073508350935103511351235133514351535163517351835193520352135223523352435253526352735283529353035313532353335343535353635373538353935403541354235433544354535463547354835493550355135523553355435553556355735583559356035613562356335643565356635673568356935703571357235733574357535763577357835793580358135823583358435853586358735883589359035913592359335943595359635973598359936003601360236033604360536063607360836093610361136123613361436153616361736183619362036213622362336243625362636273628362936303631363236333634363536363637363836393640364136423643364436453646364736483649365036513652365336543655365636573658365936603661366236633664366536663667366836693670367136723673367436753676367736783679368036813682368336843685368636873688368936903691369236933694369536963697369836993700370137023703370437053706370737083709371037113712371337143715371637173718371937203721372237233724372537263727372837293730373137323733373437353736373737383739374037413742374337443745374637473748374937503751375237533754375537563757375837593760376137623763376437653766376737683769377037713772377337743775377637773778377937803781378237833784378537863787378837893790379137923793379437953796379737983799380038013802380338043805380638073808380938103811381238133814381538163817381838193820382138223823382438253826382738283829383038313832383338343835383638373838383938403841384238433844384538463847384838493850385138523853385438553856385738583859386038613862386338643865386638673868386938703871387238733874387538763877387838793880388138823883388438853886388738883889389038913892389338943895389638973898389939003901390239033904390539063907390839093910391139123913391439153916391739183919392039213922392339243925392639273928392939303931393239333934393539363937393839393940394139423943394439453946394739483949395039513952395339543955395639573958395939603961396239633964396539663967396839693970397139723973397439753976397739783979398039813982398339843985398639873988398939903991399239933994399539963997399839994000400140024003400440054006400740084009401040114012401340144015401640174018401940204021402240234024402540264027402840294030403140324033403440354036403740384039404040414042404340444045404640474048404940504051405240534054405540564057405840594060406140624063406440654066406740684069407040714072407340744075407640774078407940804081408240834084408540864087408840894090409140924093409440954096409740984099410041014102410341044105410641074108410941104111411241134114411541164117411841194120412141224123412441254126412741284129413041314132413341344135413641374138413941404141414241434144414541464147414841494150415141524153415441554156415741584159416041614162416341644165416641674168416941704171417241734174417541764177417841794180418141824183418441854186418741884189419041914192419341944195419641974198419942004201420242034204420542064207420842094210421142124213421442154216421742184219422042214222422342244225422642274228422942304231423242334234423542364237423842394240424142424243424442454246424742484249425042514252425342544255425642574258425942604261426242634264426542664267426842694270427142724273427442754276427742784279428042814282428342844285428642874288428942904291429242934294429542964297429842994300430143024303430443054306430743084309431043114312431343144315431643174318431943204321432243234324432543264327432843294330433143324333433443354336433743384339434043414342434343444345434643474348434943504351435243534354435543564357435843594360436143624363436443654366436743684369437043714372437343744375437643774378437943804381438243834384438543864387438843894390439143924393439443954396439743984399440044014402440344044405440644074408440944104411441244134414441544164417441844194420442144224423442444254426442744284429443044314432443344344435443644374438443944404441444244434444444544464447444844494450445144524453445444554456445744584459446044614462446344644465446644674468446944704471447244734474447544764477447844794480448144824483448444854486448744884489449044914492449344944495449644974498449945004501450245034504450545064507450845094510451145124513451445154516451745184519452045214522452345244525452645274528452945304531453245334534453545364537453845394540454145424543454445454546454745484549455045514552455345544555455645574558455945604561456245634564456545664567456845694570457145724573457445754576457745784579458045814582458345844585458645874588458945904591459245934594459545964597459845994600460146024603460446054606460746084609461046114612461346144615461646174618461946204621462246234624462546264627462846294630463146324633463446354636463746384639464046414642464346444645464646474648464946504651465246534654465546564657465846594660466146624663466446654666466746684669467046714672467346744675467646774678467946804681468246834684468546864687468846894690469146924693469446954696469746984699470047014702470347044705470647074708470947104711471247134714471547164717471847194720472147224723472447254726472747284729473047314732473347344735473647374738473947404741474247434744474547464747474847494750475147524753475447554756475747584759476047614762476347644765476647674768476947704771477247734774477547764777477847794780478147824783478447854786478747884789479047914792479347944795479647974798479948004801480248034804480548064807480848094810481148124813481448154816481748184819482048214822482348244825482648274828482948304831483248334834483548364837483848394840484148424843484448454846484748484849485048514852485348544855485648574858485948604861486248634864486548664867486848694870487148724873487448754876487748784879488048814882488348844885488648874888488948904891489248934894489548964897489848994900490149024903490449054906490749084909491049114912491349144915491649174918491949204921492249234924492549264927492849294930493149324933493449354936493749384939494049414942494349444945494649474948494949504951495249534954495549564957495849594960496149624963496449654966496749684969497049714972497349744975497649774978497949804981498249834984498549864987498849894990499149924993499449954996499749984999500050015002500350045005500650075008500950105011501250135014501550165017501850195020502150225023502450255026502750285029503050315032503350345035503650375038503950405041504250435044504550465047504850495050505150525053505450555056505750585059506050615062506350645065506650675068506950705071507250735074507550765077507850795080508150825083508450855086508750885089509050915092509350945095509650975098509951005101510251035104510551065107510851095110511151125113511451155116511751185119512051215122512351245125512651275128
  1. #include "server.h"
  2. #include "model_manager.h"
  3. #include "generation_queue.h"
  4. #include "utils.h"
  5. #include "auth_middleware.h"
  6. #include "user_manager.h"
  7. #include <httplib.h>
  8. #include <nlohmann/json.hpp>
  9. #include <iostream>
  10. #include <sstream>
  11. #include <fstream>
  12. #include <chrono>
  13. #include <random>
  14. #include <algorithm>
  15. #include <thread>
  16. #include <filesystem>
  17. // Include stb_image for loading images (implementation is in generation_queue.cpp)
  18. #include "../stable-diffusion.cpp-src/thirdparty/stb_image.h"
  19. #include <sys/socket.h>
  20. #include <netinet/in.h>
  21. #include <unistd.h>
  22. #include <arpa/inet.h>
  23. Server::Server(ModelManager* modelManager, GenerationQueue* generationQueue, const std::string& outputDir, const std::string& uiDir)
  24. : m_modelManager(modelManager)
  25. , m_generationQueue(generationQueue)
  26. , m_isRunning(false)
  27. , m_startupFailed(false)
  28. , m_port(8080)
  29. , m_outputDir(outputDir)
  30. , m_uiDir(uiDir)
  31. , m_userManager(nullptr)
  32. , m_authMiddleware(nullptr)
  33. {
  34. m_httpServer = std::make_unique<httplib::Server>();
  35. }
  36. Server::~Server() {
  37. stop();
  38. }
  39. bool Server::start(const std::string& host, int port) {
  40. if (m_isRunning.load()) {
  41. return false;
  42. }
  43. m_host = host;
  44. m_port = port;
  45. // Validate host and port
  46. if (host.empty() || (port < 1 || port > 65535)) {
  47. return false;
  48. }
  49. // Set up CORS headers
  50. setupCORS();
  51. // Register API endpoints
  52. registerEndpoints();
  53. // Reset startup flags
  54. m_startupFailed.store(false);
  55. // Start server in a separate thread
  56. m_serverThread = std::thread(&Server::serverThreadFunction, this, host, port);
  57. // Wait for server to actually start and bind to the port
  58. // Give more time for server to actually start and bind
  59. for (int i = 0; i < 100; i++) { // Wait up to 10 seconds
  60. std::this_thread::sleep_for(std::chrono::milliseconds(100));
  61. // Check if startup failed early
  62. if (m_startupFailed.load()) {
  63. if (m_serverThread.joinable()) {
  64. m_serverThread.join();
  65. }
  66. return false;
  67. }
  68. if (m_isRunning.load()) {
  69. // Give it a moment more to ensure server is fully started
  70. std::this_thread::sleep_for(std::chrono::milliseconds(500));
  71. if (m_isRunning.load()) {
  72. return true;
  73. }
  74. }
  75. }
  76. if (m_isRunning.load()) {
  77. return true;
  78. } else {
  79. if (m_serverThread.joinable()) {
  80. m_serverThread.join();
  81. }
  82. return false;
  83. }
  84. }
  85. void Server::stop() {
  86. // Use atomic check to ensure thread safety
  87. bool wasRunning = m_isRunning.exchange(false);
  88. if (!wasRunning) {
  89. return; // Already stopped
  90. }
  91. if (m_httpServer) {
  92. m_httpServer->stop();
  93. // Give the server a moment to stop the blocking listen call
  94. std::this_thread::sleep_for(std::chrono::milliseconds(100));
  95. // If server thread is still running, try to force unblock the listen call
  96. // by making a quick connection to the server port
  97. if (m_serverThread.joinable()) {
  98. try {
  99. // Create a quick connection to interrupt the blocking listen
  100. httplib::Client client("127.0.0.1", m_port);
  101. client.set_connection_timeout(0, 500000); // 0.5 seconds
  102. client.set_read_timeout(0, 500000); // 0.5 seconds
  103. client.set_write_timeout(0, 500000); // 0.5 seconds
  104. auto res = client.Get("/api/health");
  105. // We don't care about the response, just trying to unblock
  106. } catch (...) {
  107. // Ignore any connection errors - we're just trying to unblock
  108. }
  109. }
  110. }
  111. if (m_serverThread.joinable()) {
  112. m_serverThread.join();
  113. }
  114. }
  115. bool Server::isRunning() const {
  116. return m_isRunning.load();
  117. }
  118. void Server::waitForStop() {
  119. if (m_serverThread.joinable()) {
  120. m_serverThread.join();
  121. }
  122. }
  123. void Server::registerEndpoints() {
  124. // Register authentication endpoints first (before applying middleware)
  125. registerAuthEndpoints();
  126. // Health check endpoint (public)
  127. m_httpServer->Get("/api/health", [this](const httplib::Request& req, httplib::Response& res) {
  128. handleHealthCheck(req, res);
  129. });
  130. // API status endpoint (public)
  131. m_httpServer->Get("/api/status", [this](const httplib::Request& req, httplib::Response& res) {
  132. handleApiStatus(req, res);
  133. });
  134. // Apply authentication middleware to protected endpoints
  135. auto withAuth = [this](std::function<void(const httplib::Request&, httplib::Response&)> handler) {
  136. return [this, handler](const httplib::Request& req, httplib::Response& res) {
  137. if (m_authMiddleware) {
  138. AuthContext authContext = m_authMiddleware->authenticate(req, res);
  139. if (!authContext.authenticated) {
  140. m_authMiddleware->sendAuthError(res, authContext.errorMessage, authContext.errorCode);
  141. return;
  142. }
  143. }
  144. handler(req, res);
  145. };
  146. };
  147. // Specialized generation endpoints (protected)
  148. m_httpServer->Post("/api/generate/text2img", withAuth([this](const httplib::Request& req, httplib::Response& res) {
  149. handleText2Img(req, res);
  150. }));
  151. m_httpServer->Post("/api/generate/img2img", withAuth([this](const httplib::Request& req, httplib::Response& res) {
  152. handleImg2Img(req, res);
  153. }));
  154. m_httpServer->Post("/api/generate/controlnet", withAuth([this](const httplib::Request& req, httplib::Response& res) {
  155. handleControlNet(req, res);
  156. }));
  157. m_httpServer->Post("/api/generate/upscale", withAuth([this](const httplib::Request& req, httplib::Response& res) {
  158. handleUpscale(req, res);
  159. }));
  160. m_httpServer->Post("/api/generate/inpainting", withAuth([this](const httplib::Request& req, httplib::Response& res) {
  161. handleInpainting(req, res);
  162. }));
  163. // Utility endpoints (now protected - require authentication)
  164. m_httpServer->Get("/api/samplers", withAuth([this](const httplib::Request& req, httplib::Response& res) {
  165. handleSamplers(req, res);
  166. }));
  167. m_httpServer->Get("/api/schedulers", withAuth([this](const httplib::Request& req, httplib::Response& res) {
  168. handleSchedulers(req, res);
  169. }));
  170. m_httpServer->Get("/api/parameters", withAuth([this](const httplib::Request& req, httplib::Response& res) {
  171. handleParameters(req, res);
  172. }));
  173. m_httpServer->Post("/api/validate", [this](const httplib::Request& req, httplib::Response& res) {
  174. handleValidate(req, res);
  175. });
  176. m_httpServer->Post("/api/estimate", withAuth([this](const httplib::Request& req, httplib::Response& res) {
  177. handleEstimate(req, res);
  178. }));
  179. m_httpServer->Get("/api/config", withAuth([this](const httplib::Request& req, httplib::Response& res) {
  180. handleConfig(req, res);
  181. }));
  182. m_httpServer->Get("/api/system", withAuth([this](const httplib::Request& req, httplib::Response& res) {
  183. handleSystem(req, res);
  184. }));
  185. m_httpServer->Post("/api/system/restart", withAuth([this](const httplib::Request& req, httplib::Response& res) {
  186. handleSystemRestart(req, res);
  187. }));
  188. // Models list endpoint (now protected - require authentication)
  189. m_httpServer->Get("/api/models", withAuth([this](const httplib::Request& req, httplib::Response& res) {
  190. handleModelsList(req, res);
  191. }));
  192. // Model-specific endpoints
  193. m_httpServer->Get("/api/models/(.*)", [this](const httplib::Request& req, httplib::Response& res) {
  194. handleModelInfo(req, res);
  195. });
  196. m_httpServer->Post("/api/models/(.*)/load", withAuth([this](const httplib::Request& req, httplib::Response& res) {
  197. handleLoadModelById(req, res);
  198. }));
  199. m_httpServer->Post("/api/models/(.*)/unload", withAuth([this](const httplib::Request& req, httplib::Response& res) {
  200. handleUnloadModelById(req, res);
  201. }));
  202. // Model management endpoints (now protected - require authentication)
  203. m_httpServer->Get("/api/models/types", withAuth([this](const httplib::Request& req, httplib::Response& res) {
  204. handleModelTypes(req, res);
  205. }));
  206. m_httpServer->Get("/api/models/directories", withAuth([this](const httplib::Request& req, httplib::Response& res) {
  207. handleModelDirectories(req, res);
  208. }));
  209. m_httpServer->Post("/api/models/refresh", withAuth([this](const httplib::Request& req, httplib::Response& res) {
  210. handleRefreshModels(req, res);
  211. }));
  212. m_httpServer->Post("/api/models/hash", withAuth([this](const httplib::Request& req, httplib::Response& res) {
  213. handleHashModels(req, res);
  214. }));
  215. m_httpServer->Post("/api/models/convert", withAuth([this](const httplib::Request& req, httplib::Response& res) {
  216. handleConvertModel(req, res);
  217. }));
  218. m_httpServer->Get("/api/models/stats", withAuth([this](const httplib::Request& req, httplib::Response& res) {
  219. handleModelStats(req, res);
  220. }));
  221. m_httpServer->Post("/api/models/batch", withAuth([this](const httplib::Request& req, httplib::Response& res) {
  222. handleBatchModels(req, res);
  223. }));
  224. // Model validation endpoints (already protected with withAuth)
  225. m_httpServer->Post("/api/models/validate", withAuth([this](const httplib::Request& req, httplib::Response& res) {
  226. handleValidateModel(req, res);
  227. }));
  228. m_httpServer->Post("/api/models/compatible", withAuth([this](const httplib::Request& req, httplib::Response& res) {
  229. handleCheckCompatibility(req, res);
  230. }));
  231. m_httpServer->Post("/api/models/requirements", withAuth([this](const httplib::Request& req, httplib::Response& res) {
  232. handleModelRequirements(req, res);
  233. }));
  234. // Queue status endpoint (now protected - require authentication)
  235. m_httpServer->Get("/api/queue/status", withAuth([this](const httplib::Request& req, httplib::Response& res) {
  236. handleQueueStatus(req, res);
  237. }));
  238. // Download job output file endpoint (must be before job status endpoint to match more specific pattern first)
  239. // Note: This endpoint is public to allow frontend to display generated images without authentication
  240. m_httpServer->Get("/api/queue/job/(.*)/output/(.*)", [this](const httplib::Request& req, httplib::Response& res) {
  241. handleDownloadOutput(req, res);
  242. });
  243. // Get job output by job ID endpoint (public to allow frontend to display generated images without authentication)
  244. m_httpServer->Get("/api/v1/jobs/(.*)/output", [this](const httplib::Request& req, httplib::Response& res) {
  245. handleJobOutput(req, res);
  246. });
  247. // Download image from URL endpoint (public for CORS-free image handling)
  248. m_httpServer->Get("/api/image/download", [this](const httplib::Request& req, httplib::Response& res) {
  249. handleDownloadImageFromUrl(req, res);
  250. });
  251. // Image resize endpoint (protected)
  252. m_httpServer->Post("/api/image/resize", withAuth([this](const httplib::Request& req, httplib::Response& res) {
  253. handleImageResize(req, res);
  254. }));
  255. // Image crop endpoint (protected)
  256. m_httpServer->Post("/api/image/crop", withAuth([this](const httplib::Request& req, httplib::Response& res) {
  257. handleImageCrop(req, res);
  258. }));
  259. // Job status endpoint (now protected - require authentication)
  260. m_httpServer->Get("/api/queue/job/(.*)", withAuth([this](const httplib::Request& req, httplib::Response& res) {
  261. handleJobStatus(req, res);
  262. }));
  263. // Cancel job endpoint (protected)
  264. m_httpServer->Post("/api/queue/cancel", withAuth([this](const httplib::Request& req, httplib::Response& res) {
  265. handleCancelJob(req, res);
  266. }));
  267. // Clear queue endpoint (protected)
  268. m_httpServer->Post("/api/queue/clear", withAuth([this](const httplib::Request& req, httplib::Response& res) {
  269. handleClearQueue(req, res);
  270. }));
  271. // Serve static web UI files if uiDir is configured
  272. if (!m_uiDir.empty() && std::filesystem::exists(m_uiDir)) {
  273. std::cout << "Serving static UI files from: " << m_uiDir << " at /ui" << std::endl;
  274. // Read UI version from version.nlohmann::json if available
  275. std::string uiVersion = "unknown";
  276. std::string versionFilePath = m_uiDir + "/version.nlohmann::json";
  277. if (std::filesystem::exists(versionFilePath)) {
  278. try {
  279. std::ifstream versionFile(versionFilePath);
  280. if (versionFile.is_open()) {
  281. nlohmann::json versionData = nlohmann::json::parse(versionFile);
  282. if (versionData.contains("version")) {
  283. uiVersion = versionData["version"].get<std::string>();
  284. }
  285. versionFile.close();
  286. }
  287. } catch (const std::exception& e) {
  288. std::cerr << "Failed to read UI version: " << e.what() << std::endl;
  289. }
  290. }
  291. std::cout << "UI version: " << uiVersion << std::endl;
  292. // Serve dynamic config.js that provides runtime configuration to the web UI
  293. m_httpServer->Get("/ui/config.js", [this, uiVersion](const httplib::Request& /*req*/, httplib::Response& res) {
  294. // Generate JavaScript configuration with current server settings
  295. std::ostringstream configJs;
  296. configJs << "// Auto-generated configuration\n"
  297. << "window.__SERVER_CONFIG__ = {\n"
  298. << " apiUrl: 'http://" << m_host << ":" << m_port << "',\n"
  299. << " apiBasePath: '/api',\n"
  300. << " host: '" << m_host << "',\n"
  301. << " port: " << m_port << ",\n"
  302. << " uiVersion: '" << uiVersion << "',\n";
  303. // Add authentication method information
  304. if (m_authMiddleware) {
  305. auto authConfig = m_authMiddleware->getConfig();
  306. std::string authMethod = "none";
  307. switch (authConfig.authMethod) {
  308. case AuthMethod::UNIX:
  309. authMethod = "unix";
  310. break;
  311. case AuthMethod::JWT:
  312. authMethod = "jwt";
  313. break;
  314. default:
  315. authMethod = "none";
  316. break;
  317. }
  318. configJs << " authMethod: '" << authMethod << "',\n"
  319. << " authEnabled: " << (authConfig.authMethod != AuthMethod::NONE ? "true" : "false") << "\n";
  320. } else {
  321. configJs << " authMethod: 'none',\n"
  322. << " authEnabled: false\n";
  323. }
  324. configJs << "};\n";
  325. // No cache for config.js - always fetch fresh
  326. res.set_header("Cache-Control", "no-cache, no-store, must-revalidate");
  327. res.set_header("Pragma", "no-cache");
  328. res.set_header("Expires", "0");
  329. res.set_content(configJs.str(), "application/javascript");
  330. });
  331. // Set up file request handler for caching static assets
  332. m_httpServer->set_file_request_handler([uiVersion](const httplib::Request& req, httplib::Response& res) {
  333. // Add cache headers based on file type and version
  334. std::string path = req.path;
  335. // For versioned static assets (.js, .css, images), use long cache
  336. if (path.find("/_next/") != std::string::npos ||
  337. path.find(".js") != std::string::npos ||
  338. path.find(".css") != std::string::npos ||
  339. path.find(".png") != std::string::npos ||
  340. path.find(".jpg") != std::string::npos ||
  341. path.find(".svg") != std::string::npos ||
  342. path.find(".ico") != std::string::npos ||
  343. path.find(".woff") != std::string::npos ||
  344. path.find(".woff2") != std::string::npos ||
  345. path.find(".ttf") != std::string::npos) {
  346. // Long cache (1 year) for static assets
  347. res.set_header("Cache-Control", "public, max-age=31536000, immutable");
  348. // Add ETag based on UI version for cache validation
  349. res.set_header("ETag", "\"" + uiVersion + "\"");
  350. // Check If-None-Match for conditional requests
  351. if (req.has_header("If-None-Match")) {
  352. std::string clientETag = req.get_header_value("If-None-Match");
  353. if (clientETag == "\"" + uiVersion + "\"") {
  354. res.status = 304; // Not Modified
  355. return;
  356. }
  357. }
  358. } else if (path.find(".html") != std::string::npos || path == "/ui/" || path == "/ui") {
  359. // HTML files should revalidate but can be cached briefly
  360. res.set_header("Cache-Control", "public, max-age=0, must-revalidate");
  361. res.set_header("ETag", "\"" + uiVersion + "\"");
  362. }
  363. });
  364. // Create a handler for UI routes with authentication check
  365. auto uiHandler = [this](const httplib::Request& req, httplib::Response& res) {
  366. // Check if authentication is enabled
  367. if (m_authMiddleware) {
  368. auto authConfig = m_authMiddleware->getConfig();
  369. if (authConfig.authMethod != AuthMethod::NONE) {
  370. // Authentication is enabled, check if user is authenticated
  371. AuthContext authContext = m_authMiddleware->authenticate(req, res);
  372. // For Unix auth, we need to check if the user is authenticated
  373. // The authenticateUnix function will return a guest context for UI requests
  374. // when no Authorization header is present, but we still need to show the login page
  375. if (!authContext.authenticated) {
  376. // Check if this is a request for a static asset (JS, CSS, images)
  377. // These should be served even without authentication to allow the login page to work
  378. bool isStaticAsset = false;
  379. std::string path = req.path;
  380. if (path.find(".js") != std::string::npos ||
  381. path.find(".css") != std::string::npos ||
  382. path.find(".png") != std::string::npos ||
  383. path.find(".jpg") != std::string::npos ||
  384. path.find(".jpeg") != std::string::npos ||
  385. path.find(".svg") != std::string::npos ||
  386. path.find(".ico") != std::string::npos ||
  387. path.find("/_next/") != std::string::npos) {
  388. isStaticAsset = true;
  389. }
  390. // For static assets, allow them to be served without authentication
  391. if (isStaticAsset) {
  392. // Continue to serve the file
  393. } else {
  394. // For HTML requests, redirect to login page
  395. if (req.path.find(".html") != std::string::npos ||
  396. req.path == "/ui/" || req.path == "/ui") {
  397. // Serve the login page instead of the requested page
  398. std::string loginPagePath = m_uiDir + "/login.html";
  399. if (std::filesystem::exists(loginPagePath)) {
  400. std::ifstream loginFile(loginPagePath);
  401. if (loginFile.is_open()) {
  402. std::string content((std::istreambuf_iterator<char>(loginFile)),
  403. std::istreambuf_iterator<char>());
  404. res.set_content(content, "text/html");
  405. return;
  406. }
  407. }
  408. // If login.html doesn't exist, serve a simple login page
  409. std::string simpleLoginPage = R"(
  410. <!DOCTYPE html>
  411. <html>
  412. <head>
  413. <title>Login Required</title>
  414. <style>
  415. body { font-family: Arial, sans-serif; max-width: 500px; margin: 100px auto; padding: 20px; }
  416. .form-group { margin-bottom: 15px; }
  417. label { display: block; margin-bottom: 5px; }
  418. input { width: 100%; padding: 8px; box-sizing: border-box; }
  419. button { background-color: #007bff; color: white; padding: 10px 15px; border: none; cursor: pointer; }
  420. .error { color: red; margin-top: 10px; }
  421. </style>
  422. </head>
  423. <body>
  424. <h1>Login Required</h1>
  425. <p>Please enter your username to continue.</p>
  426. <form id="loginForm">
  427. <div class="form-group">
  428. <label for="username">Username:</label>
  429. <input type="text" id="username" name="username" required>
  430. </div>
  431. <button type="submit">Login</button>
  432. </form>
  433. <div id="error" class="error"></div>
  434. <script>
  435. document.getElementById('loginForm').addEventListener('submit', async (e) => {
  436. e.preventDefault();
  437. const username = document.getElementById('username').value;
  438. const errorDiv = document.getElementById('error');
  439. try {
  440. const response = await fetch('/api/auth/login', {
  441. method: 'POST',
  442. headers: { 'Content-Type': 'application/nlohmann::json' },
  443. body: JSON.stringify({ username })
  444. });
  445. if (response.ok) {
  446. const data = await response.nlohmann::json();
  447. localStorage.setItem('auth_token', data.token);
  448. localStorage.setItem('unix_user', username);
  449. window.location.reload();
  450. } else {
  451. const error = await response.nlohmann::json();
  452. errorDiv.textContent = error.message || 'Login failed';
  453. }
  454. } catch (err) {
  455. errorDiv.textContent = 'Login failed: ' + err.message;
  456. }
  457. });
  458. </script>
  459. </body>
  460. </html>
  461. )";
  462. res.set_content(simpleLoginPage, "text/html");
  463. return;
  464. } else {
  465. // For non-HTML files, return unauthorized
  466. m_authMiddleware->sendAuthError(res, "Authentication required", "AUTH_REQUIRED");
  467. return;
  468. }
  469. }
  470. }
  471. }
  472. }
  473. // If we get here, either auth is disabled or user is authenticated
  474. // Serve the requested file
  475. std::string filePath = req.path.substr(3); // Remove "/ui" prefix
  476. if (filePath.empty() || filePath == "/") {
  477. filePath = "/index.html";
  478. }
  479. std::string fullPath = m_uiDir + filePath;
  480. if (std::filesystem::exists(fullPath) && std::filesystem::is_regular_file(fullPath)) {
  481. std::ifstream file(fullPath, std::ios::binary);
  482. if (file.is_open()) {
  483. std::string content((std::istreambuf_iterator<char>(file)),
  484. std::istreambuf_iterator<char>());
  485. // Determine content type based on file extension
  486. std::string contentType = "text/plain";
  487. if (filePath.find(".html") != std::string::npos) {
  488. contentType = "text/html";
  489. } else if (filePath.find(".js") != std::string::npos) {
  490. contentType = "application/javascript";
  491. } else if (filePath.find(".css") != std::string::npos) {
  492. contentType = "text/css";
  493. } else if (filePath.find(".png") != std::string::npos) {
  494. contentType = "image/png";
  495. } else if (filePath.find(".jpg") != std::string::npos || filePath.find(".jpeg") != std::string::npos) {
  496. contentType = "image/jpeg";
  497. } else if (filePath.find(".svg") != std::string::npos) {
  498. contentType = "image/svg+xml";
  499. }
  500. res.set_content(content, contentType);
  501. } else {
  502. res.status = 404;
  503. res.set_content("File not found", "text/plain");
  504. }
  505. } else {
  506. // For SPA routing, if the file doesn't exist, serve index.html
  507. // This allows Next.js to handle client-side routing
  508. std::string indexPath = m_uiDir + "/index.html";
  509. if (std::filesystem::exists(indexPath)) {
  510. std::ifstream indexFile(indexPath, std::ios::binary);
  511. if (indexFile.is_open()) {
  512. std::string content((std::istreambuf_iterator<char>(indexFile)),
  513. std::istreambuf_iterator<char>());
  514. res.set_content(content, "text/html");
  515. } else {
  516. res.status = 404;
  517. res.set_content("File not found", "text/plain");
  518. }
  519. } else {
  520. res.status = 404;
  521. res.set_content("File not found", "text/plain");
  522. }
  523. }
  524. };
  525. // Set up UI routes with authentication
  526. m_httpServer->Get("/ui/.*", uiHandler);
  527. // Redirect /ui to /ui/ to ensure proper routing
  528. m_httpServer->Get("/ui", [](const httplib::Request& /*req*/, httplib::Response& res) {
  529. res.set_redirect("/ui/");
  530. });
  531. }
  532. }
  533. void Server::setAuthComponents(std::shared_ptr<UserManager> userManager, std::shared_ptr<AuthMiddleware> authMiddleware) {
  534. m_userManager = userManager;
  535. m_authMiddleware = authMiddleware;
  536. }
  537. void Server::registerAuthEndpoints() {
  538. // Login endpoint
  539. m_httpServer->Post("/api/auth/login", [this](const httplib::Request& req, httplib::Response& res) {
  540. handleLogin(req, res);
  541. });
  542. // Logout endpoint
  543. m_httpServer->Post("/api/auth/logout", [this](const httplib::Request& req, httplib::Response& res) {
  544. handleLogout(req, res);
  545. });
  546. // Token validation endpoint
  547. m_httpServer->Get("/api/auth/validate", [this](const httplib::Request& req, httplib::Response& res) {
  548. handleValidateToken(req, res);
  549. });
  550. // Refresh token endpoint
  551. m_httpServer->Post("/api/auth/refresh", [this](const httplib::Request& req, httplib::Response& res) {
  552. handleRefreshToken(req, res);
  553. });
  554. // Get current user endpoint
  555. m_httpServer->Get("/api/auth/me", [this](const httplib::Request& req, httplib::Response& res) {
  556. handleGetCurrentUser(req, res);
  557. });
  558. }
  559. void Server::handleLogin(const httplib::Request& req, httplib::Response& res) {
  560. std::string requestId = generateRequestId();
  561. try {
  562. if (!m_userManager || !m_authMiddleware) {
  563. sendErrorResponse(res, "Authentication system not available", 500, "AUTH_UNAVAILABLE", requestId);
  564. return;
  565. }
  566. // Parse request body
  567. nlohmann::json requestJson;
  568. try {
  569. requestJson = nlohmann::json::parse(req.body);
  570. } catch (const nlohmann::json::parse_error& e) {
  571. sendErrorResponse(res, std::string("Invalid JSON: ") + e.what(), 400, "JSON_PARSE_ERROR", requestId);
  572. return;
  573. }
  574. // Check if using Unix authentication
  575. if (m_authMiddleware->getConfig().authMethod == AuthMethod::UNIX) {
  576. // For Unix auth, get username and password from request body
  577. std::string username = requestJson.value("username", "");
  578. std::string password = requestJson.value("password", "");
  579. if (username.empty()) {
  580. sendErrorResponse(res, "Missing username", 400, "MISSING_USERNAME", requestId);
  581. return;
  582. }
  583. // Check if PAM is enabled - if so, password is required
  584. if (m_userManager->isPamAuthEnabled() && password.empty()) {
  585. sendErrorResponse(res, "Password is required for Unix authentication", 400, "MISSING_PASSWORD", requestId);
  586. return;
  587. }
  588. // Authenticate Unix user (with or without password depending on PAM)
  589. auto result = m_userManager->authenticateUnix(username, password);
  590. if (!result.success) {
  591. sendErrorResponse(res, result.errorMessage, 401, "UNIX_AUTH_FAILED", requestId);
  592. return;
  593. }
  594. // Generate simple token for Unix auth
  595. std::string token = "unix_token_" + std::to_string(std::chrono::duration_cast<std::chrono::seconds>(
  596. std::chrono::system_clock::now().time_since_epoch()).count()) + "_" + username;
  597. nlohmann::json response = {
  598. {"token", token},
  599. {"user", {
  600. {"id", result.userId},
  601. {"username", result.username},
  602. {"role", result.role},
  603. {"permissions", result.permissions}
  604. }},
  605. {"message", "Unix authentication successful"}
  606. };
  607. sendJsonResponse(res, response);
  608. return;
  609. }
  610. // For non-Unix auth, validate required fields
  611. if (!requestJson.contains("username") || !requestJson.contains("password")) {
  612. sendErrorResponse(res, "Missing username or password", 400, "MISSING_CREDENTIALS", requestId);
  613. return;
  614. }
  615. std::string username = requestJson["username"];
  616. std::string password = requestJson["password"];
  617. // Authenticate user
  618. auto result = m_userManager->authenticateUser(username, password);
  619. if (!result.success) {
  620. sendErrorResponse(res, result.errorMessage, 401, "INVALID_CREDENTIALS", requestId);
  621. return;
  622. }
  623. // Generate JWT token if using JWT auth
  624. std::string token;
  625. if (m_authMiddleware->getConfig().authMethod == AuthMethod::JWT) {
  626. // For now, create a simple token (in a real implementation, use JWT)
  627. token = "token_" + std::to_string(std::chrono::duration_cast<std::chrono::seconds>(
  628. std::chrono::system_clock::now().time_since_epoch()).count()) + "_" + username;
  629. }
  630. nlohmann::json response = {
  631. {"token", token},
  632. {"user", {
  633. {"id", result.userId},
  634. {"username", result.username},
  635. {"role", result.role},
  636. {"permissions", result.permissions}
  637. }},
  638. {"message", "Login successful"}
  639. };
  640. sendJsonResponse(res, response);
  641. } catch (const std::exception& e) {
  642. sendErrorResponse(res, std::string("Login failed: ") + e.what(), 500, "LOGIN_ERROR", requestId);
  643. }
  644. }
  645. void Server::handleLogout(const httplib::Request& /*req*/, httplib::Response& res) {
  646. std::string requestId = generateRequestId();
  647. try {
  648. // For now, just return success (in a real implementation, invalidate the token)
  649. nlohmann::json response = {
  650. {"message", "Logout successful"}
  651. };
  652. sendJsonResponse(res, response);
  653. } catch (const std::exception& e) {
  654. sendErrorResponse(res, std::string("Logout failed: ") + e.what(), 500, "LOGOUT_ERROR", requestId);
  655. }
  656. }
  657. void Server::handleValidateToken(const httplib::Request& req, httplib::Response& res) {
  658. std::string requestId = generateRequestId();
  659. try {
  660. if (!m_userManager || !m_authMiddleware) {
  661. sendErrorResponse(res, "Authentication system not available", 500, "AUTH_UNAVAILABLE", requestId);
  662. return;
  663. }
  664. // Extract token from header
  665. std::string authHeader = req.get_header_value("Authorization");
  666. if (authHeader.empty()) {
  667. sendErrorResponse(res, "Missing authorization token", 401, "MISSING_TOKEN", requestId);
  668. return;
  669. }
  670. // Simple token validation (in a real implementation, validate JWT)
  671. // For now, just check if it starts with "token_"
  672. if (authHeader.find("Bearer ") != 0) {
  673. sendErrorResponse(res, "Invalid authorization header format", 401, "INVALID_HEADER", requestId);
  674. return;
  675. }
  676. std::string token = authHeader.substr(7); // Remove "Bearer "
  677. if (token.find("token_") != 0) {
  678. sendErrorResponse(res, "Invalid token", 401, "INVALID_TOKEN", requestId);
  679. return;
  680. }
  681. // Extract username from token (simple format: token_timestamp_username)
  682. size_t last_underscore = token.find_last_of('_');
  683. if (last_underscore == std::string::npos) {
  684. sendErrorResponse(res, "Invalid token format", 401, "INVALID_TOKEN", requestId);
  685. return;
  686. }
  687. std::string username = token.substr(last_underscore + 1);
  688. // Get user info
  689. auto userInfo = m_userManager->getUserInfoByUsername(username);
  690. if (userInfo.id.empty()) {
  691. sendErrorResponse(res, "User not found", 401, "USER_NOT_FOUND", requestId);
  692. return;
  693. }
  694. nlohmann::json response = {
  695. {"user", {
  696. {"id", userInfo.id},
  697. {"username", userInfo.username},
  698. {"role", userInfo.role},
  699. {"permissions", userInfo.permissions}
  700. }},
  701. {"valid", true}
  702. };
  703. sendJsonResponse(res, response);
  704. } catch (const std::exception& e) {
  705. sendErrorResponse(res, std::string("Token validation failed: ") + e.what(), 500, "VALIDATION_ERROR", requestId);
  706. }
  707. }
  708. void Server::handleRefreshToken(const httplib::Request& /*req*/, httplib::Response& res) {
  709. std::string requestId = generateRequestId();
  710. try {
  711. // For now, just return a new token (in a real implementation, refresh JWT)
  712. nlohmann::json response = {
  713. {"token", "new_token_" + std::to_string(std::chrono::duration_cast<std::chrono::seconds>(
  714. std::chrono::system_clock::now().time_since_epoch()).count())},
  715. {"message", "Token refreshed successfully"}
  716. };
  717. sendJsonResponse(res, response);
  718. } catch (const std::exception& e) {
  719. sendErrorResponse(res, std::string("Token refresh failed: ") + e.what(), 500, "REFRESH_ERROR", requestId);
  720. }
  721. }
  722. void Server::handleGetCurrentUser(const httplib::Request& req, httplib::Response& res) {
  723. std::string requestId = generateRequestId();
  724. try {
  725. if (!m_userManager || !m_authMiddleware) {
  726. sendErrorResponse(res, "Authentication system not available", 500, "AUTH_UNAVAILABLE", requestId);
  727. return;
  728. }
  729. // Authenticate the request
  730. AuthContext authContext = m_authMiddleware->authenticate(req, res);
  731. if (!authContext.authenticated) {
  732. sendErrorResponse(res, "Authentication required", 401, "AUTH_REQUIRED", requestId);
  733. return;
  734. }
  735. nlohmann::json response = {
  736. {"user", {
  737. {"id", authContext.userId},
  738. {"username", authContext.username},
  739. {"role", authContext.role},
  740. {"permissions", authContext.permissions}
  741. }}
  742. };
  743. sendJsonResponse(res, response);
  744. } catch (const std::exception& e) {
  745. sendErrorResponse(res, std::string("Get current user failed: ") + e.what(), 500, "USER_ERROR", requestId);
  746. }
  747. }
  748. void Server::setupCORS() {
  749. // Use post-routing handler to set CORS headers after the response is generated
  750. // This ensures we don't duplicate headers that may be set by other handlers
  751. m_httpServer->set_post_routing_handler([](const httplib::Request& /*req*/, httplib::Response& res) {
  752. // Only add CORS headers if they haven't been set already
  753. if (!res.has_header("Access-Control-Allow-Origin")) {
  754. res.set_header("Access-Control-Allow-Origin", "*");
  755. }
  756. if (!res.has_header("Access-Control-Allow-Methods")) {
  757. res.set_header("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS");
  758. }
  759. if (!res.has_header("Access-Control-Allow-Headers")) {
  760. res.set_header("Access-Control-Allow-Headers", "Content-Type, Authorization");
  761. }
  762. });
  763. // Handle OPTIONS requests for CORS preflight (API endpoints only)
  764. m_httpServer->Options("/api/.*", [](const httplib::Request&, httplib::Response& res) {
  765. res.set_header("Access-Control-Allow-Origin", "*");
  766. res.set_header("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS");
  767. res.set_header("Access-Control-Allow-Headers", "Content-Type, Authorization");
  768. res.status = 200;
  769. });
  770. }
  771. void Server::handleHealthCheck(const httplib::Request& /*req*/, httplib::Response& res) {
  772. try {
  773. nlohmann::json response = {
  774. {"status", "healthy"},
  775. {"timestamp", std::chrono::duration_cast<std::chrono::seconds>(
  776. std::chrono::system_clock::now().time_since_epoch()).count()},
  777. {"version", "1.0.0"}
  778. };
  779. sendJsonResponse(res, response);
  780. } catch (const std::exception& e) {
  781. sendErrorResponse(res, std::string("Health check failed: ") + e.what(), 500);
  782. }
  783. }
  784. void Server::handleApiStatus(const httplib::Request& /*req*/, httplib::Response& res) {
  785. try {
  786. nlohmann::json response = {
  787. {"server", {
  788. {"running", m_isRunning.load()},
  789. {"host", m_host},
  790. {"port", m_port}
  791. }},
  792. {"generation_queue", {
  793. {"running", m_generationQueue ? m_generationQueue->isRunning() : false},
  794. {"queue_size", m_generationQueue ? m_generationQueue->getQueueSize() : 0},
  795. {"active_generations", m_generationQueue ? m_generationQueue->getActiveGenerations() : 0}
  796. }},
  797. {"models", {
  798. {"loaded_count", m_modelManager ? m_modelManager->getLoadedModelsCount() : 0},
  799. {"available_count", m_modelManager ? m_modelManager->getAvailableModelsCount() : 0}
  800. }}
  801. };
  802. sendJsonResponse(res, response);
  803. } catch (const std::exception& e) {
  804. sendErrorResponse(res, std::string("Status check failed: ") + e.what(), 500);
  805. }
  806. }
  807. void Server::handleModelsList(const httplib::Request& req, httplib::Response& res) {
  808. std::string requestId = generateRequestId();
  809. try {
  810. if (!m_modelManager) {
  811. sendErrorResponse(res, "Model manager not available", 500, "MODEL_MANAGER_UNAVAILABLE", requestId);
  812. return;
  813. }
  814. // Parse query parameters for enhanced filtering
  815. std::string typeFilter = req.get_param_value("type");
  816. std::string searchQuery = req.get_param_value("search");
  817. std::string sortBy = req.get_param_value("sort_by");
  818. std::string sortOrder = req.get_param_value("sort_order");
  819. std::string dateFilter = req.get_param_value("date");
  820. std::string sizeFilter = req.get_param_value("size");
  821. // Pagination parameters - only apply if limit is explicitly provided
  822. int page = 1;
  823. int limit = 50;
  824. bool usePagination = false;
  825. try {
  826. if (!req.get_param_value("limit").empty()) {
  827. limit = std::stoi(req.get_param_value("limit"));
  828. // Special case: limit<=0 means return all models (no pagination)
  829. if (limit <= 0) {
  830. usePagination = false;
  831. limit = INT_MAX; // Set to very large number to effectively disable pagination
  832. } else {
  833. usePagination = true;
  834. if (!req.get_param_value("page").empty()) {
  835. page = std::stoi(req.get_param_value("page"));
  836. if (page < 1) page = 1;
  837. }
  838. }
  839. }
  840. } catch (const std::exception& e) {
  841. sendErrorResponse(res, "Invalid pagination parameters", 400, "INVALID_PAGINATION", requestId);
  842. return;
  843. }
  844. // Filter parameters
  845. bool includeLoaded = req.get_param_value("loaded") == "true";
  846. bool includeUnloaded = req.get_param_value("unloaded") == "true";
  847. (void)req.get_param_value("include_metadata"); // unused but kept for API compatibility
  848. (void)req.get_param_value("include_thumbnails"); // unused but kept for API compatibility
  849. // Get all models
  850. auto allModels = m_modelManager->getAllModels();
  851. nlohmann::json models = nlohmann::json::array();
  852. // Apply filters and build response
  853. for (const auto& pair : allModels) {
  854. const auto& modelInfo = pair.second;
  855. // Apply type filter
  856. if (!typeFilter.empty()) {
  857. ModelType filterType = ModelManager::stringToModelType(typeFilter);
  858. if (modelInfo.type != filterType) continue;
  859. }
  860. // Apply loaded/unloaded filters
  861. if (includeLoaded && !modelInfo.isLoaded) continue;
  862. if (includeUnloaded && modelInfo.isLoaded) continue;
  863. // Apply search filter (case-insensitive search in name and description)
  864. if (!searchQuery.empty()) {
  865. std::string searchLower = searchQuery;
  866. std::transform(searchLower.begin(), searchLower.end(), searchLower.begin(), ::tolower);
  867. std::string nameLower = modelInfo.name;
  868. std::transform(nameLower.begin(), nameLower.end(), nameLower.begin(), ::tolower);
  869. std::string descLower = modelInfo.description;
  870. std::transform(descLower.begin(), descLower.end(), descLower.begin(), ::tolower);
  871. if (nameLower.find(searchLower) == std::string::npos &&
  872. descLower.find(searchLower) == std::string::npos) {
  873. continue;
  874. }
  875. }
  876. // Apply date filter (simplified - expects "recent", "old", or YYYY-MM-DD)
  877. if (!dateFilter.empty()) {
  878. auto now = std::filesystem::file_time_type::clock::now();
  879. auto modelTime = modelInfo.modifiedAt;
  880. auto duration = std::chrono::duration_cast<std::chrono::hours>(now - modelTime).count();
  881. if (dateFilter == "recent" && duration > 24 * 7) continue; // Older than 1 week
  882. if (dateFilter == "old" && duration < 24 * 30) continue; // Newer than 1 month
  883. }
  884. // Apply size filter (expects "small", "medium", "large", or size in MB)
  885. if (!sizeFilter.empty()) {
  886. double sizeMB = modelInfo.fileSize / (1024.0 * 1024.0);
  887. if (sizeFilter == "small" && sizeMB > 1024) continue; // > 1GB
  888. if (sizeFilter == "medium" && (sizeMB < 1024 || sizeMB > 4096)) continue; // < 1GB or > 4GB
  889. if (sizeFilter == "large" && sizeMB < 4096) continue; // < 4GB
  890. // Try to parse as specific size in MB
  891. try {
  892. double maxSizeMB = std::stod(sizeFilter);
  893. if (sizeMB > maxSizeMB) continue;
  894. } catch (...) {
  895. // Ignore if parsing fails
  896. }
  897. }
  898. // Build model JSON with only essential information
  899. nlohmann::json modelJson = {
  900. {"name", modelInfo.name},
  901. {"type", ModelManager::modelTypeToString(modelInfo.type)},
  902. {"file_size", modelInfo.fileSize},
  903. {"file_size_mb", modelInfo.fileSize / (1024.0 * 1024.0)},
  904. {"sha256", modelInfo.sha256.empty() ? nullptr : nlohmann::json(modelInfo.sha256)},
  905. {"sha256_short", (modelInfo.sha256.empty() || modelInfo.sha256.length() < 10) ? nullptr : nlohmann::json(modelInfo.sha256.substr(0, 10))}
  906. };
  907. // Add architecture information if available (checkpoints only)
  908. if (!modelInfo.architecture.empty()) {
  909. modelJson["architecture"] = modelInfo.architecture;
  910. modelJson["recommended_vae"] = modelInfo.recommendedVAE.empty() ? nullptr : nlohmann::json(modelInfo.recommendedVAE);
  911. if (modelInfo.recommendedWidth > 0) {
  912. modelJson["recommended_width"] = modelInfo.recommendedWidth;
  913. }
  914. if (modelInfo.recommendedHeight > 0) {
  915. modelJson["recommended_height"] = modelInfo.recommendedHeight;
  916. }
  917. if (modelInfo.recommendedSteps > 0) {
  918. modelJson["recommended_steps"] = modelInfo.recommendedSteps;
  919. }
  920. if (!modelInfo.recommendedSampler.empty()) {
  921. modelJson["recommended_sampler"] = modelInfo.recommendedSampler;
  922. }
  923. if (!modelInfo.requiredModels.empty()) {
  924. modelJson["required_models"] = modelInfo.requiredModels;
  925. }
  926. if (!modelInfo.missingModels.empty()) {
  927. modelJson["missing_models"] = modelInfo.missingModels;
  928. modelJson["has_missing_dependencies"] = true;
  929. } else {
  930. modelJson["has_missing_dependencies"] = false;
  931. }
  932. }
  933. models.push_back(modelJson);
  934. }
  935. // Apply sorting
  936. if (!sortBy.empty()) {
  937. std::sort(models.begin(), models.end(), [&sortBy, &sortOrder](const nlohmann::json& a, const nlohmann::json& b) {
  938. bool ascending = sortOrder != "desc";
  939. if (sortBy == "name") {
  940. return ascending ? a["name"] < b["name"] : a["name"] > b["name"];
  941. } else if (sortBy == "size") {
  942. return ascending ? a["file_size"] < b["file_size"] : a["file_size"] > b["file_size"];
  943. } else if (sortBy == "date") {
  944. return ascending ? a["last_modified"] < b["last_modified"] : a["last_modified"] > b["last_modified"];
  945. } else if (sortBy == "type") {
  946. return ascending ? a["type"] < b["type"] : a["type"] > b["type"];
  947. } else if (sortBy == "loaded") {
  948. return ascending ? a["is_loaded"] < b["is_loaded"] : a["is_loaded"] > b["is_loaded"];
  949. }
  950. return false;
  951. });
  952. }
  953. // Apply pagination only if limit parameter was provided
  954. int totalCount = models.size();
  955. nlohmann::json paginatedModels = nlohmann::json::array();
  956. nlohmann::json paginationInfo = nlohmann::json::object();
  957. if (usePagination) {
  958. // Apply pagination
  959. int totalPages = (totalCount + limit - 1) / limit;
  960. int startIndex = (page - 1) * limit;
  961. int endIndex = std::min(startIndex + limit, totalCount);
  962. for (int i = startIndex; i < endIndex; ++i) {
  963. paginatedModels.push_back(models[i]);
  964. }
  965. paginationInfo = {
  966. {"page", page},
  967. {"limit", limit},
  968. {"total_count", totalCount},
  969. {"total_pages", totalPages},
  970. {"has_next", page < totalPages},
  971. {"has_prev", page > 1}
  972. };
  973. } else {
  974. // Return all models without pagination
  975. paginatedModels = models;
  976. paginationInfo = {
  977. {"page", 1},
  978. {"limit", totalCount},
  979. {"total_count", totalCount},
  980. {"total_pages", 1},
  981. {"has_next", false},
  982. {"has_prev", false}
  983. };
  984. }
  985. // Build comprehensive response
  986. nlohmann::json response = {
  987. {"models", paginatedModels},
  988. {"pagination", paginationInfo},
  989. {"filters_applied", {
  990. {"type", typeFilter.empty() ? nlohmann::json(nullptr) : nlohmann::json(typeFilter)},
  991. {"search", searchQuery.empty() ? nlohmann::json(nullptr) : nlohmann::json(searchQuery)},
  992. {"date", dateFilter.empty() ? nlohmann::json(nullptr) : nlohmann::json(dateFilter)},
  993. {"size", sizeFilter.empty() ? nlohmann::json(nullptr) : nlohmann::json(sizeFilter)},
  994. {"loaded", includeLoaded ? nlohmann::json(true) : nlohmann::json(nullptr)},
  995. {"unloaded", includeUnloaded ? nlohmann::json(true) : nlohmann::json(nullptr)}
  996. }},
  997. {"sorting", {
  998. {"sort_by", sortBy.empty() ? "name" : nlohmann::json(sortBy)},
  999. {"sort_order", sortOrder.empty() ? "asc" : nlohmann::json(sortOrder)}
  1000. }},
  1001. {"statistics", {
  1002. {"loaded_count", m_modelManager->getLoadedModelsCount()},
  1003. {"available_count", m_modelManager->getAvailableModelsCount()}
  1004. }},
  1005. {"request_id", requestId}
  1006. };
  1007. sendJsonResponse(res, response);
  1008. } catch (const std::exception& e) {
  1009. sendErrorResponse(res, std::string("Failed to list models: ") + e.what(), 500, "MODEL_LIST_ERROR", requestId);
  1010. }
  1011. }
  1012. void Server::handleQueueStatus(const httplib::Request& /*req*/, httplib::Response& res) {
  1013. try {
  1014. if (!m_generationQueue) {
  1015. sendErrorResponse(res, "Generation queue not available", 500);
  1016. return;
  1017. }
  1018. // Get detailed queue status
  1019. auto jobs = m_generationQueue->getQueueStatus();
  1020. // Convert jobs to JSON
  1021. nlohmann::json jobsJson = nlohmann::json::array();
  1022. for (const auto& job : jobs) {
  1023. std::string statusStr;
  1024. switch (job.status) {
  1025. case GenerationStatus::QUEUED: statusStr = "queued"; break;
  1026. case GenerationStatus::PROCESSING: statusStr = "processing"; break;
  1027. case GenerationStatus::COMPLETED: statusStr = "completed"; break;
  1028. case GenerationStatus::FAILED: statusStr = "failed"; break;
  1029. }
  1030. // Convert time points to timestamps
  1031. auto queuedTime = std::chrono::duration_cast<std::chrono::milliseconds>(
  1032. job.queuedTime.time_since_epoch()).count();
  1033. auto startTime = std::chrono::duration_cast<std::chrono::milliseconds>(
  1034. job.startTime.time_since_epoch()).count();
  1035. auto endTime = std::chrono::duration_cast<std::chrono::milliseconds>(
  1036. job.endTime.time_since_epoch()).count();
  1037. jobsJson.push_back({
  1038. {"id", job.id},
  1039. {"status", statusStr},
  1040. {"prompt", job.prompt},
  1041. {"queued_time", queuedTime},
  1042. {"start_time", startTime > 0 ? nlohmann::json(startTime) : nlohmann::json(nullptr)},
  1043. {"end_time", endTime > 0 ? nlohmann::json(endTime) : nlohmann::json(nullptr)},
  1044. {"position", job.position},
  1045. {"progress", job.progress}
  1046. });
  1047. }
  1048. nlohmann::json response = {
  1049. {"queue", {
  1050. {"size", m_generationQueue->getQueueSize()},
  1051. {"active_generations", m_generationQueue->getActiveGenerations()},
  1052. {"running", m_generationQueue->isRunning()},
  1053. {"jobs", jobsJson}
  1054. }}
  1055. };
  1056. sendJsonResponse(res, response);
  1057. } catch (const std::exception& e) {
  1058. sendErrorResponse(res, std::string("Queue status check failed: ") + e.what(), 500);
  1059. }
  1060. }
  1061. void Server::handleJobStatus(const httplib::Request& req, httplib::Response& res) {
  1062. try {
  1063. if (!m_generationQueue) {
  1064. sendErrorResponse(res, "Generation queue not available", 500);
  1065. return;
  1066. }
  1067. // Extract job ID from URL path
  1068. std::string jobId = req.matches[1].str();
  1069. if (jobId.empty()) {
  1070. sendErrorResponse(res, "Missing job ID", 400);
  1071. return;
  1072. }
  1073. // Get job information
  1074. auto jobInfo = m_generationQueue->getJobInfo(jobId);
  1075. if (jobInfo.id.empty()) {
  1076. sendErrorResponse(res, "Job not found", 404);
  1077. return;
  1078. }
  1079. // Convert status to string
  1080. std::string statusStr;
  1081. switch (jobInfo.status) {
  1082. case GenerationStatus::QUEUED: statusStr = "queued"; break;
  1083. case GenerationStatus::PROCESSING: statusStr = "processing"; break;
  1084. case GenerationStatus::COMPLETED: statusStr = "completed"; break;
  1085. case GenerationStatus::FAILED: statusStr = "failed"; break;
  1086. }
  1087. // Convert time points to timestamps
  1088. auto queuedTime = std::chrono::duration_cast<std::chrono::milliseconds>(
  1089. jobInfo.queuedTime.time_since_epoch()).count();
  1090. auto startTime = std::chrono::duration_cast<std::chrono::milliseconds>(
  1091. jobInfo.startTime.time_since_epoch()).count();
  1092. auto endTime = std::chrono::duration_cast<std::chrono::milliseconds>(
  1093. jobInfo.endTime.time_since_epoch()).count();
  1094. // Create download URLs for output files
  1095. nlohmann::json outputUrls = nlohmann::json::array();
  1096. for (const auto& filePath : jobInfo.outputFiles) {
  1097. // Extract filename from full path
  1098. std::filesystem::path p(filePath);
  1099. std::string filename = p.filename().string();
  1100. // Create download URL
  1101. std::string url = "/api/queue/job/" + jobInfo.id + "/output/" + filename;
  1102. nlohmann::json fileInfo = {
  1103. {"filename", filename},
  1104. {"url", url},
  1105. {"path", filePath}
  1106. };
  1107. outputUrls.push_back(fileInfo);
  1108. }
  1109. nlohmann::json response = {
  1110. {"job", {
  1111. {"id", jobInfo.id},
  1112. {"status", statusStr},
  1113. {"prompt", jobInfo.prompt},
  1114. {"queued_time", queuedTime},
  1115. {"start_time", startTime > 0 ? nlohmann::json(startTime) : nlohmann::json(nullptr)},
  1116. {"end_time", endTime > 0 ? nlohmann::json(endTime) : nlohmann::json(nullptr)},
  1117. {"position", jobInfo.position},
  1118. {"outputs", outputUrls},
  1119. {"error_message", jobInfo.errorMessage},
  1120. {"progress", jobInfo.progress}
  1121. }}
  1122. };
  1123. sendJsonResponse(res, response);
  1124. } catch (const std::exception& e) {
  1125. sendErrorResponse(res, std::string("Job status check failed: ") + e.what(), 500);
  1126. }
  1127. }
  1128. void Server::handleCancelJob(const httplib::Request& req, httplib::Response& res) {
  1129. try {
  1130. if (!m_generationQueue) {
  1131. sendErrorResponse(res, "Generation queue not available", 500);
  1132. return;
  1133. }
  1134. // Parse JSON request body
  1135. nlohmann::json requestJson = nlohmann::json::parse(req.body);
  1136. // Validate required fields
  1137. if (!requestJson.contains("job_id") || !requestJson["job_id"].is_string()) {
  1138. sendErrorResponse(res, "Missing or invalid 'job_id' field", 400);
  1139. return;
  1140. }
  1141. std::string jobId = requestJson["job_id"];
  1142. // Try to cancel the job
  1143. bool cancelled = m_generationQueue->cancelJob(jobId);
  1144. if (cancelled) {
  1145. nlohmann::json response = {
  1146. {"status", "success"},
  1147. {"message", "Job cancelled successfully"},
  1148. {"job_id", jobId}
  1149. };
  1150. sendJsonResponse(res, response);
  1151. } else {
  1152. nlohmann::json response = {
  1153. {"status", "error"},
  1154. {"message", "Job not found or already processing"},
  1155. {"job_id", jobId}
  1156. };
  1157. sendJsonResponse(res, response, 404);
  1158. }
  1159. } catch (const nlohmann::json::parse_error& e) {
  1160. sendErrorResponse(res, std::string("Invalid JSON: ") + e.what(), 400);
  1161. } catch (const std::exception& e) {
  1162. sendErrorResponse(res, std::string("Job cancellation failed: ") + e.what(), 500);
  1163. }
  1164. }
  1165. void Server::handleClearQueue(const httplib::Request& /*req*/, httplib::Response& res) {
  1166. try {
  1167. if (!m_generationQueue) {
  1168. sendErrorResponse(res, "Generation queue not available", 500);
  1169. return;
  1170. }
  1171. // Clear the queue
  1172. m_generationQueue->clearQueue();
  1173. nlohmann::json response = {
  1174. {"status", "success"},
  1175. {"message", "Queue cleared successfully"}
  1176. };
  1177. sendJsonResponse(res, response);
  1178. } catch (const std::exception& e) {
  1179. sendErrorResponse(res, std::string("Queue clear failed: ") + e.what(), 500);
  1180. }
  1181. }
  1182. void Server::handleDownloadOutput(const httplib::Request& req, httplib::Response& res) {
  1183. try {
  1184. // Extract job ID and filename from URL path
  1185. if (req.matches.size() < 3) {
  1186. sendErrorResponse(res, "Invalid request: job ID and filename required", 400, "INVALID_REQUEST", "");
  1187. return;
  1188. }
  1189. std::string jobId = req.matches[1];
  1190. std::string filename = req.matches[2];
  1191. // Validate inputs
  1192. if (jobId.empty() || filename.empty()) {
  1193. sendErrorResponse(res, "Job ID and filename cannot be empty", 400, "INVALID_PARAMETERS", "");
  1194. return;
  1195. }
  1196. // Construct absolute file path using the same logic as when saving:
  1197. // {outputDir}/{jobId}/{filename}
  1198. std::string fullPath = std::filesystem::absolute(m_outputDir + "/" + jobId + "/" + filename).string();
  1199. // Log the request for debugging
  1200. std::cout << "Image download request: jobId=" << jobId << ", filename=" << filename
  1201. << ", fullPath=" << fullPath << std::endl;
  1202. // Check if file exists
  1203. if (!std::filesystem::exists(fullPath)) {
  1204. std::cerr << "Output file not found: " << fullPath << std::endl;
  1205. sendErrorResponse(res, "Output file not found: " + filename, 404, "FILE_NOT_FOUND", "");
  1206. return;
  1207. }
  1208. // Check file size to detect zero-byte files
  1209. auto fileSize = std::filesystem::file_size(fullPath);
  1210. if (fileSize == 0) {
  1211. std::cerr << "Output file is zero bytes: " << fullPath << std::endl;
  1212. sendErrorResponse(res, "Output file is empty (corrupted generation)", 500, "EMPTY_FILE", "");
  1213. return;
  1214. }
  1215. // Check if file is accessible
  1216. std::ifstream file(fullPath, std::ios::binary);
  1217. if (!file.is_open()) {
  1218. std::cerr << "Failed to open output file: " << fullPath << std::endl;
  1219. sendErrorResponse(res, "Output file not accessible", 500, "FILE_ACCESS_ERROR", "");
  1220. return;
  1221. }
  1222. // Read file contents
  1223. std::string fileContent;
  1224. try {
  1225. fileContent = std::string(
  1226. std::istreambuf_iterator<char>(file),
  1227. std::istreambuf_iterator<char>()
  1228. );
  1229. file.close();
  1230. } catch (const std::exception& e) {
  1231. std::cerr << "Failed to read file content: " << e.what() << std::endl;
  1232. sendErrorResponse(res, "Failed to read file content", 500, "FILE_READ_ERROR", "");
  1233. return;
  1234. }
  1235. // Verify we actually read data
  1236. if (fileContent.empty()) {
  1237. std::cerr << "File content is empty after read: " << fullPath << std::endl;
  1238. sendErrorResponse(res, "File content is empty after read", 500, "EMPTY_CONTENT", "");
  1239. return;
  1240. }
  1241. // Determine content type based on file extension
  1242. std::string contentType = "application/octet-stream";
  1243. if (Utils::endsWith(filename, ".png")) {
  1244. contentType = "image/png";
  1245. } else if (Utils::endsWith(filename, ".jpg") || Utils::endsWith(filename, ".jpeg")) {
  1246. contentType = "image/jpeg";
  1247. } else if (Utils::endsWith(filename, ".mp4")) {
  1248. contentType = "video/mp4";
  1249. } else if (Utils::endsWith(filename, ".gif")) {
  1250. contentType = "image/gif";
  1251. } else if (Utils::endsWith(filename, ".webp")) {
  1252. contentType = "image/webp";
  1253. }
  1254. // Set response headers for proper browser handling
  1255. res.set_header("Content-Type", contentType);
  1256. res.set_header("Content-Length", std::to_string(fileContent.length()));
  1257. res.set_header("Cache-Control", "public, max-age=3600"); // Cache for 1 hour
  1258. res.set_header("Access-Control-Allow-Origin", "*"); // CORS for image access
  1259. // Uncomment if you want to force download instead of inline display:
  1260. // res.set_header("Content-Disposition", "attachment; filename=\"" + filename + "\"");
  1261. // Set the content
  1262. res.set_content(fileContent, contentType);
  1263. res.status = 200;
  1264. std::cout << "Successfully served image: " << filename << " (" << fileContent.length() << " bytes)" << std::endl;
  1265. } catch (const std::exception& e) {
  1266. std::cerr << "Exception in handleDownloadOutput: " << e.what() << std::endl;
  1267. sendErrorResponse(res, std::string("Failed to download file: ") + e.what(), 500, "DOWNLOAD_ERROR", "");
  1268. }
  1269. }
  1270. void Server::handleJobOutput(const httplib::Request& req, httplib::Response& res) {
  1271. std::string requestId = generateRequestId();
  1272. try {
  1273. // Extract job ID from URL path
  1274. if (req.matches.size() < 2) {
  1275. sendErrorResponse(res, "Invalid request: job ID required", 400, "INVALID_REQUEST", requestId);
  1276. return;
  1277. }
  1278. std::string jobId = req.matches[1].str();
  1279. // Validate job ID
  1280. if (jobId.empty()) {
  1281. sendErrorResponse(res, "Job ID cannot be empty", 400, "INVALID_PARAMETERS", requestId);
  1282. return;
  1283. }
  1284. // Log the request for debugging
  1285. std::cout << "Job output request: jobId=" << jobId << std::endl;
  1286. // Get job information to check if it exists and is completed
  1287. if (!m_generationQueue) {
  1288. sendErrorResponse(res, "Generation queue not available", 500, "QUEUE_UNAVAILABLE", requestId);
  1289. return;
  1290. }
  1291. auto jobInfo = m_generationQueue->getJobInfo(jobId);
  1292. if (jobInfo.id.empty()) {
  1293. sendErrorResponse(res, "Job not found", 404, "JOB_NOT_FOUND", requestId);
  1294. return;
  1295. }
  1296. // Check if job is completed
  1297. if (jobInfo.status != GenerationStatus::COMPLETED) {
  1298. std::string statusStr;
  1299. switch (jobInfo.status) {
  1300. case GenerationStatus::QUEUED: statusStr = "queued"; break;
  1301. case GenerationStatus::PROCESSING: statusStr = "processing"; break;
  1302. case GenerationStatus::FAILED: statusStr = "failed"; break;
  1303. default: statusStr = "unknown"; break;
  1304. }
  1305. nlohmann::json response = {
  1306. {"error", {
  1307. {"message", "Job not completed yet"},
  1308. {"status_code", 400},
  1309. {"error_code", "JOB_NOT_COMPLETED"},
  1310. {"request_id", requestId},
  1311. {"timestamp", std::chrono::duration_cast<std::chrono::seconds>(
  1312. std::chrono::system_clock::now().time_since_epoch()).count()},
  1313. {"job_status", statusStr}
  1314. }}
  1315. };
  1316. sendJsonResponse(res, response, 400);
  1317. return;
  1318. }
  1319. // Check if job has output files
  1320. if (jobInfo.outputFiles.empty()) {
  1321. sendErrorResponse(res, "No output files found for completed job", 404, "NO_OUTPUT_FILES", requestId);
  1322. return;
  1323. }
  1324. // For simplicity, return the first output file
  1325. // In a more complex implementation, we could return all files or allow file selection
  1326. std::string firstOutputFile = jobInfo.outputFiles[0];
  1327. // Extract filename from full path
  1328. std::filesystem::path filePath(firstOutputFile);
  1329. std::string filename = filePath.filename().string();
  1330. // Construct absolute file path
  1331. std::string fullPath = std::filesystem::absolute(firstOutputFile).string();
  1332. // Check if file exists
  1333. if (!std::filesystem::exists(fullPath)) {
  1334. std::cerr << "Output file not found: " << fullPath << std::endl;
  1335. sendErrorResponse(res, "Output file not found: " + filename, 404, "FILE_NOT_FOUND", requestId);
  1336. return;
  1337. }
  1338. // Check file size to detect zero-byte files
  1339. auto fileSize = std::filesystem::file_size(fullPath);
  1340. if (fileSize == 0) {
  1341. std::cerr << "Output file is zero bytes: " << fullPath << std::endl;
  1342. sendErrorResponse(res, "Output file is empty (corrupted generation)", 500, "EMPTY_FILE", requestId);
  1343. return;
  1344. }
  1345. // Check if file is accessible
  1346. std::ifstream file(fullPath, std::ios::binary);
  1347. if (!file.is_open()) {
  1348. std::cerr << "Failed to open output file: " << fullPath << std::endl;
  1349. sendErrorResponse(res, "Output file not accessible", 500, "FILE_ACCESS_ERROR", requestId);
  1350. return;
  1351. }
  1352. // Read file contents
  1353. std::string fileContent;
  1354. try {
  1355. fileContent = std::string(
  1356. std::istreambuf_iterator<char>(file),
  1357. std::istreambuf_iterator<char>()
  1358. );
  1359. file.close();
  1360. } catch (const std::exception& e) {
  1361. std::cerr << "Failed to read file content: " << e.what() << std::endl;
  1362. sendErrorResponse(res, "Failed to read file content", 500, "FILE_READ_ERROR", requestId);
  1363. return;
  1364. }
  1365. // Verify we actually read data
  1366. if (fileContent.empty()) {
  1367. std::cerr << "File content is empty after read: " << fullPath << std::endl;
  1368. sendErrorResponse(res, "File content is empty after read", 500, "EMPTY_CONTENT", requestId);
  1369. return;
  1370. }
  1371. // Determine content type based on file extension
  1372. std::string contentType = "application/octet-stream";
  1373. if (Utils::endsWith(filename, ".png")) {
  1374. contentType = "image/png";
  1375. } else if (Utils::endsWith(filename, ".jpg") || Utils::endsWith(filename, ".jpeg")) {
  1376. contentType = "image/jpeg";
  1377. } else if (Utils::endsWith(filename, ".mp4")) {
  1378. contentType = "video/mp4";
  1379. } else if (Utils::endsWith(filename, ".gif")) {
  1380. contentType = "image/gif";
  1381. } else if (Utils::endsWith(filename, ".webp")) {
  1382. contentType = "image/webp";
  1383. }
  1384. // Set response headers for proper browser handling
  1385. res.set_header("Content-Type", contentType);
  1386. res.set_header("Content-Length", std::to_string(fileContent.length()));
  1387. res.set_header("Cache-Control", "public, max-age=3600"); // Cache for 1 hour
  1388. res.set_header("Access-Control-Allow-Origin", "*"); // CORS for image access
  1389. // Set additional metadata headers
  1390. res.set_header("X-Job-ID", jobId);
  1391. res.set_header("X-Filename", filename);
  1392. res.set_header("X-File-Size", std::to_string(fileSize));
  1393. // If there are multiple files, indicate this
  1394. if (jobInfo.outputFiles.size() > 1) {
  1395. res.set_header("X-Total-Files", std::to_string(jobInfo.outputFiles.size()));
  1396. res.set_header("X-File-Index", "1");
  1397. }
  1398. // Set the content
  1399. res.set_content(fileContent, contentType);
  1400. res.status = 200;
  1401. std::cout << "Successfully served job output: jobId=" << jobId
  1402. << ", filename=" << filename
  1403. << " (" << fileContent.length() << " bytes)" << std::endl;
  1404. } catch (const std::exception& e) {
  1405. std::cerr << "Exception in handleJobOutput: " << e.what() << std::endl;
  1406. sendErrorResponse(res, std::string("Failed to get job output: ") + e.what(), 500, "OUTPUT_ERROR", requestId);
  1407. }
  1408. }
  1409. void Server::handleImageResize(const httplib::Request& req, httplib::Response& res) {
  1410. std::string requestId = generateRequestId();
  1411. try {
  1412. // Parse JSON request body
  1413. nlohmann::json requestJson = nlohmann::json::parse(req.body);
  1414. // Validate required fields
  1415. if (!requestJson.contains("image") || !requestJson["image"].is_string()) {
  1416. sendErrorResponse(res, "Missing or invalid 'image' field", 400, "INVALID_PARAMETERS", requestId);
  1417. return;
  1418. }
  1419. if (!requestJson.contains("width") || !requestJson["width"].is_number_integer()) {
  1420. sendErrorResponse(res, "Missing or invalid 'width' field", 400, "INVALID_PARAMETERS", requestId);
  1421. return;
  1422. }
  1423. if (!requestJson.contains("height") || !requestJson["height"].is_number_integer()) {
  1424. sendErrorResponse(res, "Missing or invalid 'height' field", 400, "INVALID_PARAMETERS", requestId);
  1425. return;
  1426. }
  1427. std::string imageInput = requestJson["image"];
  1428. int targetWidth = requestJson["width"];
  1429. int targetHeight = requestJson["height"];
  1430. // Validate dimensions
  1431. if (targetWidth < 1 || targetWidth > 4096) {
  1432. sendErrorResponse(res, "Width must be between 1 and 4096", 400, "INVALID_DIMENSIONS", requestId);
  1433. return;
  1434. }
  1435. if (targetHeight < 1 || targetHeight > 4096) {
  1436. sendErrorResponse(res, "Height must be between 1 and 4096", 400, "INVALID_DIMENSIONS", requestId);
  1437. return;
  1438. }
  1439. // Load the source image
  1440. auto [imageData, sourceWidth, sourceHeight, sourceChannels, success, loadError] = loadImageFromInput(imageInput);
  1441. if (!success) {
  1442. sendErrorResponse(res, "Failed to load image: " + loadError, 400, "IMAGE_LOAD_ERROR", requestId);
  1443. return;
  1444. }
  1445. // Convert image data to stb_image format for processing
  1446. int channels = 3; // Force RGB
  1447. size_t sourceSize = sourceWidth * sourceHeight * channels;
  1448. std::vector<uint8_t> sourcePixels(sourceSize);
  1449. std::memcpy(sourcePixels.data(), imageData.data(), std::min(imageData.size(), sourceSize));
  1450. // Resize the image using stb_image_resize if available, otherwise use simple scaling
  1451. std::vector<uint8_t> resizedPixels(targetWidth * targetHeight * channels);
  1452. // Simple nearest-neighbor scaling for now (can be improved with better algorithms)
  1453. float xScale = static_cast<float>(sourceWidth) / targetWidth;
  1454. float yScale = static_cast<float>(sourceHeight) / targetHeight;
  1455. for (int y = 0; y < targetHeight; y++) {
  1456. for (int x = 0; x < targetWidth; x++) {
  1457. int sourceX = static_cast<int>(x * xScale);
  1458. int sourceY = static_cast<int>(y * yScale);
  1459. // Clamp to source bounds
  1460. sourceX = std::min(sourceX, sourceWidth - 1);
  1461. sourceY = std::min(sourceY, sourceHeight - 1);
  1462. for (int c = 0; c < channels; c++) {
  1463. resizedPixels[(y * targetWidth + x) * channels + c] =
  1464. sourcePixels[(sourceY * sourceWidth + sourceX) * channels + c];
  1465. }
  1466. }
  1467. }
  1468. // Convert resized image to base64
  1469. std::string base64Data = Utils::base64Encode(resizedPixels);
  1470. // Determine MIME type based on input
  1471. std::string mimeType = "image/jpeg"; // default
  1472. if (Utils::startsWith(imageInput, "data:image/png")) {
  1473. mimeType = "image/png";
  1474. } else if (Utils::startsWith(imageInput, "data:image/gif")) {
  1475. mimeType = "image/gif";
  1476. } else if (Utils::startsWith(imageInput, "data:image/webp")) {
  1477. mimeType = "image/webp";
  1478. } else if (Utils::startsWith(imageInput, "data:image/bmp")) {
  1479. mimeType = "image/bmp";
  1480. }
  1481. // Create data URL format
  1482. std::string dataUrl = "data:" + mimeType + ";base64," + base64Data;
  1483. // Build response
  1484. nlohmann::json response = {
  1485. {"success", true},
  1486. {"original_width", sourceWidth},
  1487. {"original_height", sourceHeight},
  1488. {"resized_width", targetWidth},
  1489. {"resized_height", targetHeight},
  1490. {"mime_type", mimeType},
  1491. {"base64_data", dataUrl},
  1492. {"file_size_bytes", resizedPixels.size()},
  1493. {"request_id", requestId}
  1494. };
  1495. sendJsonResponse(res, response, 200);
  1496. std::cout << "Successfully resized image from " << sourceWidth << "x" << sourceHeight
  1497. << " to " << targetWidth << "x" << targetHeight
  1498. << " (" << resizedPixels.size() << " bytes)" << std::endl;
  1499. } catch (const nlohmann::json::parse_error& e) {
  1500. sendErrorResponse(res, std::string("Invalid JSON: ") + e.what(), 400, "JSON_PARSE_ERROR", requestId);
  1501. } catch (const std::exception& e) {
  1502. std::cerr << "Exception in handleImageResize: " << e.what() << std::endl;
  1503. sendErrorResponse(res, std::string("Failed to resize image: ") + e.what(), 500, "INTERNAL_ERROR", requestId);
  1504. }
  1505. }
  1506. void Server::handleImageCrop(const httplib::Request& req, httplib::Response& res) {
  1507. std::string requestId = generateRequestId();
  1508. try {
  1509. // Parse JSON request body
  1510. nlohmann::json requestJson = nlohmann::json::parse(req.body);
  1511. // Validate required fields
  1512. if (!requestJson.contains("image") || !requestJson["image"].is_string()) {
  1513. sendErrorResponse(res, "Missing or invalid 'image' field", 400, "INVALID_PARAMETERS", requestId);
  1514. return;
  1515. }
  1516. if (!requestJson.contains("x") || !requestJson["x"].is_number_integer()) {
  1517. sendErrorResponse(res, "Missing or invalid 'x' field", 400, "INVALID_PARAMETERS", requestId);
  1518. return;
  1519. }
  1520. if (!requestJson.contains("y") || !requestJson["y"].is_number_integer()) {
  1521. sendErrorResponse(res, "Missing or invalid 'y' field", 400, "INVALID_PARAMETERS", requestId);
  1522. return;
  1523. }
  1524. if (!requestJson.contains("width") || !requestJson["width"].is_number_integer()) {
  1525. sendErrorResponse(res, "Missing or invalid 'width' field", 400, "INVALID_PARAMETERS", requestId);
  1526. return;
  1527. }
  1528. if (!requestJson.contains("height") || !requestJson["height"].is_number_integer()) {
  1529. sendErrorResponse(res, "Missing or invalid 'height' field", 400, "INVALID_PARAMETERS", requestId);
  1530. return;
  1531. }
  1532. std::string imageInput = requestJson["image"];
  1533. int cropX = requestJson["x"];
  1534. int cropY = requestJson["y"];
  1535. int cropWidth = requestJson["width"];
  1536. int cropHeight = requestJson["height"];
  1537. // Load the source image
  1538. auto [imageData, sourceWidth, sourceHeight, sourceChannels, success, loadError] = loadImageFromInput(imageInput);
  1539. if (!success) {
  1540. sendErrorResponse(res, "Failed to load image: " + loadError, 400, "IMAGE_LOAD_ERROR", requestId);
  1541. return;
  1542. }
  1543. // Validate crop dimensions
  1544. if (cropX < 0 || cropY < 0) {
  1545. sendErrorResponse(res, "Crop coordinates must be non-negative", 400, "INVALID_CROP_AREA", requestId);
  1546. return;
  1547. }
  1548. if (cropX + cropWidth > sourceWidth || cropY + cropHeight > sourceHeight) {
  1549. sendErrorResponse(res, "Crop area exceeds image dimensions", 400, "INVALID_CROP_AREA", requestId);
  1550. return;
  1551. }
  1552. if (cropWidth < 1 || cropHeight < 1) {
  1553. sendErrorResponse(res, "Crop width and height must be at least 1", 400, "INVALID_CROP_AREA", requestId);
  1554. return;
  1555. }
  1556. // Convert image data to stb_image format for processing
  1557. int channels = 3; // Force RGB
  1558. size_t sourceSize = sourceWidth * sourceHeight * channels;
  1559. std::vector<uint8_t> sourcePixels(sourceSize);
  1560. std::memcpy(sourcePixels.data(), imageData.data(), std::min(imageData.size(), sourceSize));
  1561. // Crop the image
  1562. std::vector<uint8_t> croppedPixels(cropWidth * cropHeight * channels);
  1563. for (int y = 0; y < cropHeight; y++) {
  1564. for (int x = 0; x < cropWidth; x++) {
  1565. int sourceX = cropX + x;
  1566. int sourceY = cropY + y;
  1567. for (int c = 0; c < channels; c++) {
  1568. croppedPixels[(y * cropWidth + x) * channels + c] =
  1569. sourcePixels[(sourceY * sourceWidth + sourceX) * channels + c];
  1570. }
  1571. }
  1572. }
  1573. // Convert cropped image to base64
  1574. std::string base64Data = Utils::base64Encode(croppedPixels);
  1575. // Determine MIME type based on input
  1576. std::string mimeType = "image/jpeg"; // default
  1577. if (Utils::startsWith(imageInput, "data:image/png")) {
  1578. mimeType = "image/png";
  1579. } else if (Utils::startsWith(imageInput, "data:image/gif")) {
  1580. mimeType = "image/gif";
  1581. } else if (Utils::startsWith(imageInput, "data:image/webp")) {
  1582. mimeType = "image/webp";
  1583. } else if (Utils::startsWith(imageInput, "data:image/bmp")) {
  1584. mimeType = "image/bmp";
  1585. }
  1586. // Create data URL format
  1587. std::string dataUrl = "data:" + mimeType + ";base64," + base64Data;
  1588. // Build response
  1589. nlohmann::json response = {
  1590. {"success", true},
  1591. {"original_width", sourceWidth},
  1592. {"original_height", sourceHeight},
  1593. {"crop_x", cropX},
  1594. {"crop_y", cropY},
  1595. {"cropped_width", cropWidth},
  1596. {"cropped_height", cropHeight},
  1597. {"mime_type", mimeType},
  1598. {"base64_data", dataUrl},
  1599. {"file_size_bytes", croppedPixels.size()},
  1600. {"request_id", requestId}
  1601. };
  1602. sendJsonResponse(res, response, 200);
  1603. std::cout << "Successfully cropped image from " << sourceWidth << "x" << sourceHeight
  1604. << " to " << cropWidth << "x" << cropHeight
  1605. << " at (" << cropX << "," << cropY << ")"
  1606. << " (" << croppedPixels.size() << " bytes)" << std::endl;
  1607. } catch (const nlohmann::json::parse_error& e) {
  1608. sendErrorResponse(res, std::string("Invalid JSON: ") + e.what(), 400, "JSON_PARSE_ERROR", requestId);
  1609. } catch (const std::exception& e) {
  1610. std::cerr << "Exception in handleImageCrop: " << e.what() << std::endl;
  1611. sendErrorResponse(res, std::string("Failed to crop image: ") + e.what(), 500, "INTERNAL_ERROR", requestId);
  1612. }
  1613. }
  1614. void Server::handleDownloadImageFromUrl(const httplib::Request& req, httplib::Response& res) {
  1615. std::string requestId = generateRequestId();
  1616. try {
  1617. // Parse query parameters
  1618. std::string imageUrl = req.get_param_value("url");
  1619. if (imageUrl.empty()) {
  1620. sendErrorResponse(res, "Missing 'url' parameter", 400, "MISSING_URL", requestId);
  1621. return;
  1622. }
  1623. // Basic URL format validation
  1624. if (!Utils::startsWith(imageUrl, "http://") && !Utils::startsWith(imageUrl, "https://")) {
  1625. sendErrorResponse(res, "Invalid URL format. URL must start with http:// or https://", 400, "INVALID_URL_FORMAT", requestId);
  1626. return;
  1627. }
  1628. // Extract filename from URL for content type detection
  1629. std::string filename = imageUrl;
  1630. size_t lastSlash = imageUrl.find_last_of('/');
  1631. if (lastSlash != std::string::npos) {
  1632. filename = imageUrl.substr(lastSlash + 1);
  1633. }
  1634. // Remove query parameters and fragments
  1635. size_t questionMark = filename.find('?');
  1636. if (questionMark != std::string::npos) {
  1637. filename = filename.substr(0, questionMark);
  1638. }
  1639. size_t hashMark = filename.find('#');
  1640. if (hashMark != std::string::npos) {
  1641. filename = filename.substr(0, hashMark);
  1642. }
  1643. // Check if URL has image extension
  1644. std::string extension;
  1645. size_t lastDot = filename.find_last_of('.');
  1646. if (lastDot != std::string::npos) {
  1647. extension = filename.substr(lastDot + 1);
  1648. std::transform(extension.begin(), extension.end(), extension.begin(), ::tolower);
  1649. }
  1650. // Validate image extension
  1651. const std::vector<std::string> validExtensions = {"jpg", "jpeg", "png", "gif", "webp", "bmp"};
  1652. if (extension.empty() || std::find(validExtensions.begin(), validExtensions.end(), extension) == validExtensions.end()) {
  1653. sendErrorResponse(res, "URL must point to an image file with a valid extension: " +
  1654. std::accumulate(validExtensions.begin(), validExtensions.end(), std::string(),
  1655. [](const std::string& a, const std::string& b) {
  1656. return a.empty() ? b : a + ", " + b;
  1657. }), 400, "INVALID_IMAGE_EXTENSION", requestId);
  1658. return;
  1659. }
  1660. // Load image using existing loadImageFromInput function
  1661. auto [imageData, width, height, channels, success, error] = loadImageFromInput(imageUrl);
  1662. if (!success) {
  1663. sendErrorResponse(res, "Failed to download image from URL: " + error, 400, "IMAGE_DOWNLOAD_FAILED", requestId);
  1664. return;
  1665. }
  1666. // Convert image data to base64
  1667. std::string base64Data = Utils::base64Encode(imageData);
  1668. // Determine MIME type based on extension
  1669. std::string mimeType = "image/jpeg"; // default
  1670. if (extension == "png") {
  1671. mimeType = "image/png";
  1672. } else if (extension == "gif") {
  1673. mimeType = "image/gif";
  1674. } else if (extension == "webp") {
  1675. mimeType = "image/webp";
  1676. } else if (extension == "bmp") {
  1677. mimeType = "image/bmp";
  1678. } else if (extension == "jpg" || extension == "jpeg") {
  1679. mimeType = "image/jpeg";
  1680. }
  1681. // Create data URL format
  1682. std::string dataUrl = "data:" + mimeType + ";base64," + base64Data;
  1683. // Build response
  1684. nlohmann::json response = {
  1685. {"success", true},
  1686. {"url", imageUrl},
  1687. {"filename", filename},
  1688. {"width", width},
  1689. {"height", height},
  1690. {"channels", channels},
  1691. {"mime_type", mimeType},
  1692. {"base64_data", dataUrl},
  1693. {"file_size_bytes", imageData.size()},
  1694. {"request_id", requestId}
  1695. };
  1696. sendJsonResponse(res, response, 200);
  1697. std::cout << "Successfully downloaded and encoded image from URL: " << imageUrl
  1698. << " (" << width << "x" << height << ", " << imageData.size() << " bytes)" << std::endl;
  1699. } catch (const nlohmann::json::parse_error& e) {
  1700. sendErrorResponse(res, std::string("Invalid JSON: ") + e.what(), 400, "JSON_PARSE_ERROR", requestId);
  1701. } catch (const std::exception& e) {
  1702. std::cerr << "Exception in handleDownloadImageFromUrl: " << e.what() << std::endl;
  1703. sendErrorResponse(res, std::string("Failed to download image from URL: ") + e.what(), 500, "INTERNAL_ERROR", requestId);
  1704. }
  1705. }
  1706. void Server::sendJsonResponse(httplib::Response& res, const nlohmann::json& json, int status_code) {
  1707. res.set_header("Content-Type", "application/json");
  1708. res.status = status_code;
  1709. res.body = json.dump();
  1710. }
  1711. void Server::sendErrorResponse(httplib::Response& res, const std::string& message, int status_code,
  1712. const std::string& error_code, const std::string& request_id) {
  1713. nlohmann::json errorResponse = {
  1714. {"error", {
  1715. {"message", message},
  1716. {"status_code", status_code},
  1717. {"error_code", error_code},
  1718. {"request_id", request_id},
  1719. {"timestamp", std::chrono::duration_cast<std::chrono::seconds>(
  1720. std::chrono::system_clock::now().time_since_epoch()).count()}
  1721. }}
  1722. };
  1723. sendJsonResponse(res, errorResponse, status_code);
  1724. }
  1725. std::pair<bool, std::string> Server::validateGenerationParameters(const nlohmann::json& params) {
  1726. // Validate required fields
  1727. if (!params.contains("prompt") || !params["prompt"].is_string()) {
  1728. return {false, "Missing or invalid 'prompt' field"};
  1729. }
  1730. const std::string& prompt = params["prompt"];
  1731. if (prompt.empty()) {
  1732. return {false, "Prompt cannot be empty"};
  1733. }
  1734. if (prompt.length() > 10000) {
  1735. return {false, "Prompt too long (max 10000 characters)"};
  1736. }
  1737. // Validate negative prompt if present
  1738. if (params.contains("negative_prompt")) {
  1739. if (!params["negative_prompt"].is_string()) {
  1740. return {false, "Invalid 'negative_prompt' field, must be string"};
  1741. }
  1742. if (params["negative_prompt"].get<std::string>().length() > 10000) {
  1743. return {false, "Negative prompt too long (max 10000 characters)"};
  1744. }
  1745. }
  1746. // Validate width
  1747. if (params.contains("width")) {
  1748. if (!params["width"].is_number_integer()) {
  1749. return {false, "Invalid 'width' field, must be integer"};
  1750. }
  1751. int width = params["width"];
  1752. if (width < 64 || width > 2048 || width % 64 != 0) {
  1753. return {false, "Width must be between 64 and 2048 and divisible by 64"};
  1754. }
  1755. }
  1756. // Validate height
  1757. if (params.contains("height")) {
  1758. if (!params["height"].is_number_integer()) {
  1759. return {false, "Invalid 'height' field, must be integer"};
  1760. }
  1761. int height = params["height"];
  1762. if (height < 64 || height > 2048 || height % 64 != 0) {
  1763. return {false, "Height must be between 64 and 2048 and divisible by 64"};
  1764. }
  1765. }
  1766. // Validate batch count
  1767. if (params.contains("batch_count")) {
  1768. if (!params["batch_count"].is_number_integer()) {
  1769. return {false, "Invalid 'batch_count' field, must be integer"};
  1770. }
  1771. int batchCount = params["batch_count"];
  1772. if (batchCount < 1 || batchCount > 100) {
  1773. return {false, "Batch count must be between 1 and 100"};
  1774. }
  1775. }
  1776. // Validate steps
  1777. if (params.contains("steps")) {
  1778. if (!params["steps"].is_number_integer()) {
  1779. return {false, "Invalid 'steps' field, must be integer"};
  1780. }
  1781. int steps = params["steps"];
  1782. if (steps < 1 || steps > 150) {
  1783. return {false, "Steps must be between 1 and 150"};
  1784. }
  1785. }
  1786. // Validate CFG scale
  1787. if (params.contains("cfg_scale")) {
  1788. if (!params["cfg_scale"].is_number()) {
  1789. return {false, "Invalid 'cfg_scale' field, must be number"};
  1790. }
  1791. float cfgScale = params["cfg_scale"];
  1792. if (cfgScale < 1.0f || cfgScale > 30.0f) {
  1793. return {false, "CFG scale must be between 1.0 and 30.0"};
  1794. }
  1795. }
  1796. // Validate seed
  1797. if (params.contains("seed")) {
  1798. if (!params["seed"].is_string() && !params["seed"].is_number_integer()) {
  1799. return {false, "Invalid 'seed' field, must be string or integer"};
  1800. }
  1801. }
  1802. // Validate sampling method
  1803. if (params.contains("sampling_method")) {
  1804. if (!params["sampling_method"].is_string()) {
  1805. return {false, "Invalid 'sampling_method' field, must be string"};
  1806. }
  1807. std::string method = params["sampling_method"];
  1808. std::vector<std::string> validMethods = {
  1809. "euler", "euler_a", "heun", "dpm2", "dpm++2s_a", "dpm++2m",
  1810. "dpm++2mv2", "ipndm", "ipndm_v", "lcm", "ddim_trailing", "tcd", "default"
  1811. };
  1812. if (std::find(validMethods.begin(), validMethods.end(), method) == validMethods.end()) {
  1813. return {false, "Invalid sampling method"};
  1814. }
  1815. }
  1816. // Validate scheduler
  1817. if (params.contains("scheduler")) {
  1818. if (!params["scheduler"].is_string()) {
  1819. return {false, "Invalid 'scheduler' field, must be string"};
  1820. }
  1821. std::string scheduler = params["scheduler"];
  1822. std::vector<std::string> validSchedulers = {
  1823. "discrete", "karras", "exponential", "ays", "gits",
  1824. "smoothstep", "sgm_uniform", "simple", "default"
  1825. };
  1826. if (std::find(validSchedulers.begin(), validSchedulers.end(), scheduler) == validSchedulers.end()) {
  1827. return {false, "Invalid scheduler"};
  1828. }
  1829. }
  1830. // Validate strength
  1831. if (params.contains("strength")) {
  1832. if (!params["strength"].is_number()) {
  1833. return {false, "Invalid 'strength' field, must be number"};
  1834. }
  1835. float strength = params["strength"];
  1836. if (strength < 0.0f || strength > 1.0f) {
  1837. return {false, "Strength must be between 0.0 and 1.0"};
  1838. }
  1839. }
  1840. // Validate control strength
  1841. if (params.contains("control_strength")) {
  1842. if (!params["control_strength"].is_number()) {
  1843. return {false, "Invalid 'control_strength' field, must be number"};
  1844. }
  1845. float controlStrength = params["control_strength"];
  1846. if (controlStrength < 0.0f || controlStrength > 1.0f) {
  1847. return {false, "Control strength must be between 0.0 and 1.0"};
  1848. }
  1849. }
  1850. // Validate clip skip
  1851. if (params.contains("clip_skip")) {
  1852. if (!params["clip_skip"].is_number_integer()) {
  1853. return {false, "Invalid 'clip_skip' field, must be integer"};
  1854. }
  1855. int clipSkip = params["clip_skip"];
  1856. if (clipSkip < -1 || clipSkip > 12) {
  1857. return {false, "Clip skip must be between -1 and 12"};
  1858. }
  1859. }
  1860. // Validate threads
  1861. if (params.contains("threads")) {
  1862. if (!params["threads"].is_number_integer()) {
  1863. return {false, "Invalid 'threads' field, must be integer"};
  1864. }
  1865. int threads = params["threads"];
  1866. if (threads < -1 || threads > 32) {
  1867. return {false, "Threads must be between -1 (auto) and 32"};
  1868. }
  1869. }
  1870. return {true, ""};
  1871. }
  1872. SamplingMethod Server::parseSamplingMethod(const std::string& method) {
  1873. if (method == "euler") return SamplingMethod::EULER;
  1874. else if (method == "euler_a") return SamplingMethod::EULER_A;
  1875. else if (method == "heun") return SamplingMethod::HEUN;
  1876. else if (method == "dpm2") return SamplingMethod::DPM2;
  1877. else if (method == "dpm++2s_a") return SamplingMethod::DPMPP2S_A;
  1878. else if (method == "dpm++2m") return SamplingMethod::DPMPP2M;
  1879. else if (method == "dpm++2mv2") return SamplingMethod::DPMPP2MV2;
  1880. else if (method == "ipndm") return SamplingMethod::IPNDM;
  1881. else if (method == "ipndm_v") return SamplingMethod::IPNDM_V;
  1882. else if (method == "lcm") return SamplingMethod::LCM;
  1883. else if (method == "ddim_trailing") return SamplingMethod::DDIM_TRAILING;
  1884. else if (method == "tcd") return SamplingMethod::TCD;
  1885. else return SamplingMethod::DEFAULT;
  1886. }
  1887. Scheduler Server::parseScheduler(const std::string& scheduler) {
  1888. if (scheduler == "discrete") return Scheduler::DISCRETE;
  1889. else if (scheduler == "karras") return Scheduler::KARRAS;
  1890. else if (scheduler == "exponential") return Scheduler::EXPONENTIAL;
  1891. else if (scheduler == "ays") return Scheduler::AYS;
  1892. else if (scheduler == "gits") return Scheduler::GITS;
  1893. else if (scheduler == "smoothstep") return Scheduler::SMOOTHSTEP;
  1894. else if (scheduler == "sgm_uniform") return Scheduler::SGM_UNIFORM;
  1895. else if (scheduler == "simple") return Scheduler::SIMPLE;
  1896. else return Scheduler::DEFAULT;
  1897. }
  1898. std::string Server::generateRequestId() {
  1899. std::random_device rd;
  1900. std::mt19937 gen(rd());
  1901. std::uniform_int_distribution<> dis(100000, 999999);
  1902. return "req_" + std::to_string(dis(gen));
  1903. }
  1904. std::tuple<std::vector<uint8_t>, int, int, int, bool, std::string>
  1905. Server::loadImageFromInput(const std::string& input) {
  1906. std::vector<uint8_t> imageData;
  1907. int width = 0, height = 0, channels = 0;
  1908. // Auto-detect input source type
  1909. // 1. Check if input is a URL (starts with http:// or https://)
  1910. if (Utils::startsWith(input, "http://") || Utils::startsWith(input, "https://")) {
  1911. // Parse URL to extract host and path
  1912. std::string url = input;
  1913. std::string scheme, host, path;
  1914. int port = 80;
  1915. // Determine scheme and port
  1916. if (Utils::startsWith(url, "https://")) {
  1917. scheme = "https";
  1918. port = 443;
  1919. url = url.substr(8); // Remove "https://"
  1920. } else {
  1921. scheme = "http";
  1922. port = 80;
  1923. url = url.substr(7); // Remove "http://"
  1924. }
  1925. // Extract host and path
  1926. size_t slashPos = url.find('/');
  1927. if (slashPos != std::string::npos) {
  1928. host = url.substr(0, slashPos);
  1929. path = url.substr(slashPos);
  1930. } else {
  1931. host = url;
  1932. path = "/";
  1933. }
  1934. // Check for custom port
  1935. size_t colonPos = host.find(':');
  1936. if (colonPos != std::string::npos) {
  1937. try {
  1938. port = std::stoi(host.substr(colonPos + 1));
  1939. host = host.substr(0, colonPos);
  1940. } catch (...) {
  1941. return {imageData, 0, 0, 0, false, "Invalid port in URL"};
  1942. }
  1943. }
  1944. // Download image using httplib
  1945. try {
  1946. httplib::Result res;
  1947. if (scheme == "https") {
  1948. #ifdef CPPHTTPLIB_OPENSSL_SUPPORT
  1949. httplib::SSLClient client(host, port);
  1950. client.set_follow_location(true);
  1951. client.set_connection_timeout(30, 0); // 30 seconds
  1952. client.set_read_timeout(60, 0); // 60 seconds
  1953. res = client.Get(path.c_str());
  1954. #else
  1955. return {imageData, 0, 0, 0, false, "HTTPS not supported (OpenSSL not available)"};
  1956. #endif
  1957. } else {
  1958. httplib::Client client(host, port);
  1959. client.set_follow_location(true);
  1960. client.set_connection_timeout(30, 0); // 30 seconds
  1961. client.set_read_timeout(60, 0); // 60 seconds
  1962. res = client.Get(path.c_str());
  1963. }
  1964. if (!res) {
  1965. return {imageData, 0, 0, 0, false, "Failed to download image from URL: Connection error"};
  1966. }
  1967. if (res->status != 200) {
  1968. return {imageData, 0, 0, 0, false, "Failed to download image from URL: HTTP " + std::to_string(res->status)};
  1969. }
  1970. // Convert response body to vector
  1971. std::vector<uint8_t> downloadedData(res->body.begin(), res->body.end());
  1972. // Load image from memory
  1973. int w, h, c;
  1974. unsigned char* pixels = stbi_load_from_memory(
  1975. downloadedData.data(),
  1976. downloadedData.size(),
  1977. &w, &h, &c,
  1978. 3 // Force RGB
  1979. );
  1980. if (!pixels) {
  1981. return {imageData, 0, 0, 0, false, "Failed to decode image from URL"};
  1982. }
  1983. width = w;
  1984. height = h;
  1985. channels = 3;
  1986. size_t dataSize = width * height * channels;
  1987. imageData.resize(dataSize);
  1988. std::memcpy(imageData.data(), pixels, dataSize);
  1989. stbi_image_free(pixels);
  1990. } catch (const std::exception& e) {
  1991. return {imageData, 0, 0, 0, false, "Failed to download image from URL: " + std::string(e.what())};
  1992. }
  1993. }
  1994. // 2. Check if input is base64 encoded data URI (starts with "data:image")
  1995. else if (Utils::startsWith(input, "data:image")) {
  1996. // Extract base64 data after the comma
  1997. size_t commaPos = input.find(',');
  1998. if (commaPos == std::string::npos) {
  1999. return {imageData, 0, 0, 0, false, "Invalid data URI format"};
  2000. }
  2001. std::string base64Data = input.substr(commaPos + 1);
  2002. std::vector<uint8_t> decodedData = Utils::base64Decode(base64Data);
  2003. // Load image from memory using stb_image
  2004. int w, h, c;
  2005. unsigned char* pixels = stbi_load_from_memory(
  2006. decodedData.data(),
  2007. decodedData.size(),
  2008. &w, &h, &c,
  2009. 3 // Force RGB
  2010. );
  2011. if (!pixels) {
  2012. return {imageData, 0, 0, 0, false, "Failed to decode image from base64 data URI"};
  2013. }
  2014. width = w;
  2015. height = h;
  2016. channels = 3; // We forced RGB
  2017. // Copy pixel data
  2018. size_t dataSize = width * height * channels;
  2019. imageData.resize(dataSize);
  2020. std::memcpy(imageData.data(), pixels, dataSize);
  2021. stbi_image_free(pixels);
  2022. }
  2023. // 3. Check if input is raw base64 (long string without slashes, likely base64)
  2024. else if (input.length() > 100 && input.find('/') == std::string::npos && input.find('.') == std::string::npos) {
  2025. // Likely raw base64 without data URI prefix
  2026. std::vector<uint8_t> decodedData = Utils::base64Decode(input);
  2027. int w, h, c;
  2028. unsigned char* pixels = stbi_load_from_memory(
  2029. decodedData.data(),
  2030. decodedData.size(),
  2031. &w, &h, &c,
  2032. 3 // Force RGB
  2033. );
  2034. if (!pixels) {
  2035. return {imageData, 0, 0, 0, false, "Failed to decode image from base64"};
  2036. }
  2037. width = w;
  2038. height = h;
  2039. channels = 3;
  2040. size_t dataSize = width * height * channels;
  2041. imageData.resize(dataSize);
  2042. std::memcpy(imageData.data(), pixels, dataSize);
  2043. stbi_image_free(pixels);
  2044. }
  2045. // 4. Treat as local file path
  2046. else {
  2047. int w, h, c;
  2048. unsigned char* pixels = stbi_load(input.c_str(), &w, &h, &c, 3);
  2049. if (!pixels) {
  2050. return {imageData, 0, 0, 0, false, "Failed to load image from file: " + input};
  2051. }
  2052. width = w;
  2053. height = h;
  2054. channels = 3;
  2055. size_t dataSize = width * height * channels;
  2056. imageData.resize(dataSize);
  2057. std::memcpy(imageData.data(), pixels, dataSize);
  2058. stbi_image_free(pixels);
  2059. }
  2060. return {imageData, width, height, channels, true, ""};
  2061. }
  2062. std::string Server::samplingMethodToString(SamplingMethod method) {
  2063. switch (method) {
  2064. case SamplingMethod::EULER: return "euler";
  2065. case SamplingMethod::EULER_A: return "euler_a";
  2066. case SamplingMethod::HEUN: return "heun";
  2067. case SamplingMethod::DPM2: return "dpm2";
  2068. case SamplingMethod::DPMPP2S_A: return "dpm++2s_a";
  2069. case SamplingMethod::DPMPP2M: return "dpm++2m";
  2070. case SamplingMethod::DPMPP2MV2: return "dpm++2mv2";
  2071. case SamplingMethod::IPNDM: return "ipndm";
  2072. case SamplingMethod::IPNDM_V: return "ipndm_v";
  2073. case SamplingMethod::LCM: return "lcm";
  2074. case SamplingMethod::DDIM_TRAILING: return "ddim_trailing";
  2075. case SamplingMethod::TCD: return "tcd";
  2076. default: return "default";
  2077. }
  2078. }
  2079. std::string Server::schedulerToString(Scheduler scheduler) {
  2080. switch (scheduler) {
  2081. case Scheduler::DISCRETE: return "discrete";
  2082. case Scheduler::KARRAS: return "karras";
  2083. case Scheduler::EXPONENTIAL: return "exponential";
  2084. case Scheduler::AYS: return "ays";
  2085. case Scheduler::GITS: return "gits";
  2086. case Scheduler::SMOOTHSTEP: return "smoothstep";
  2087. case Scheduler::SGM_UNIFORM: return "sgm_uniform";
  2088. case Scheduler::SIMPLE: return "simple";
  2089. default: return "default";
  2090. }
  2091. }
  2092. uint64_t Server::estimateGenerationTime(const GenerationRequest& request) {
  2093. // Basic estimation based on parameters
  2094. uint64_t baseTime = 1000; // 1 second base time
  2095. // Factor in steps
  2096. baseTime *= request.steps;
  2097. // Factor in resolution
  2098. double resolutionFactor = (request.width * request.height) / (512.0 * 512.0);
  2099. baseTime = static_cast<uint64_t>(baseTime * resolutionFactor);
  2100. // Factor in batch count
  2101. baseTime *= request.batchCount;
  2102. // Adjust for sampling method (some are faster than others)
  2103. switch (request.samplingMethod) {
  2104. case SamplingMethod::LCM:
  2105. baseTime /= 4; // LCM is much faster
  2106. break;
  2107. case SamplingMethod::EULER:
  2108. case SamplingMethod::EULER_A:
  2109. baseTime *= 0.8; // Euler methods are faster
  2110. break;
  2111. case SamplingMethod::DPM2:
  2112. case SamplingMethod::DPMPP2S_A:
  2113. baseTime *= 1.2; // DPM methods are slower
  2114. break;
  2115. default:
  2116. break;
  2117. }
  2118. return baseTime;
  2119. }
  2120. size_t Server::estimateMemoryUsage(const GenerationRequest& request) {
  2121. // Basic memory estimation in bytes
  2122. size_t baseMemory = 1024 * 1024 * 1024; // 1GB base
  2123. // Factor in resolution
  2124. double resolutionFactor = (request.width * request.height) / (512.0 * 512.0);
  2125. baseMemory = static_cast<size_t>(baseMemory * resolutionFactor);
  2126. // Factor in batch count
  2127. baseMemory *= request.batchCount;
  2128. // Additional memory for certain features
  2129. if (request.diffusionFlashAttn) {
  2130. baseMemory += 512 * 1024 * 1024; // Extra 512MB for flash attention
  2131. }
  2132. if (!request.controlNetPath.empty()) {
  2133. baseMemory += 1024 * 1024 * 1024; // Extra 1GB for ControlNet
  2134. }
  2135. return baseMemory;
  2136. }
  2137. // Specialized generation endpoints
  2138. void Server::handleText2Img(const httplib::Request& req, httplib::Response& res) {
  2139. std::string requestId = generateRequestId();
  2140. try {
  2141. if (!m_generationQueue) {
  2142. sendErrorResponse(res, "Generation queue not available", 500, "QUEUE_UNAVAILABLE", requestId);
  2143. return;
  2144. }
  2145. nlohmann::json requestJson = nlohmann::json::parse(req.body);
  2146. // Validate required fields for text2img
  2147. if (!requestJson.contains("prompt") || !requestJson["prompt"].is_string()) {
  2148. sendErrorResponse(res, "Missing or invalid 'prompt' field", 400, "INVALID_PARAMETERS", requestId);
  2149. return;
  2150. }
  2151. // Validate all parameters
  2152. auto [isValid, errorMessage] = validateGenerationParameters(requestJson);
  2153. if (!isValid) {
  2154. sendErrorResponse(res, errorMessage, 400, "INVALID_PARAMETERS", requestId);
  2155. return;
  2156. }
  2157. // Check if any model is loaded
  2158. if (!m_modelManager) {
  2159. sendErrorResponse(res, "Model manager not available", 500, "MODEL_MANAGER_UNAVAILABLE", requestId);
  2160. return;
  2161. }
  2162. // Get currently loaded checkpoint model
  2163. auto allModels = m_modelManager->getAllModels();
  2164. std::string loadedModelName;
  2165. for (const auto& [modelName, modelInfo] : allModels) {
  2166. if (modelInfo.type == ModelType::CHECKPOINT && modelInfo.isLoaded) {
  2167. loadedModelName = modelName;
  2168. break;
  2169. }
  2170. }
  2171. if (loadedModelName.empty()) {
  2172. sendErrorResponse(res, "No checkpoint model loaded. Please load a checkpoint model first using POST /api/models/{hash}/load", 400, "NO_CHECKPOINT_LOADED", requestId);
  2173. return;
  2174. }
  2175. // Create generation request specifically for text2img
  2176. GenerationRequest genRequest;
  2177. genRequest.id = requestId;
  2178. genRequest.modelName = loadedModelName; // Use the currently loaded model
  2179. genRequest.prompt = requestJson["prompt"];
  2180. genRequest.negativePrompt = requestJson.value("negative_prompt", "");
  2181. genRequest.width = requestJson.value("width", 512);
  2182. genRequest.height = requestJson.value("height", 512);
  2183. genRequest.batchCount = requestJson.value("batch_count", 1);
  2184. genRequest.steps = requestJson.value("steps", 20);
  2185. genRequest.cfgScale = requestJson.value("cfg_scale", 7.5f);
  2186. genRequest.seed = requestJson.value("seed", "random");
  2187. // Parse optional parameters
  2188. if (requestJson.contains("sampling_method")) {
  2189. genRequest.samplingMethod = parseSamplingMethod(requestJson["sampling_method"]);
  2190. }
  2191. if (requestJson.contains("scheduler")) {
  2192. genRequest.scheduler = parseScheduler(requestJson["scheduler"]);
  2193. }
  2194. // Set text2img specific defaults
  2195. genRequest.strength = 1.0f; // Full strength for text2img
  2196. // Optional VAE model
  2197. if (requestJson.contains("vae_model") && requestJson["vae_model"].is_string()) {
  2198. std::string vaeModelId = requestJson["vae_model"];
  2199. if (!vaeModelId.empty()) {
  2200. auto vaeInfo = m_modelManager->getModelInfo(vaeModelId);
  2201. if (!vaeInfo.name.empty() && vaeInfo.type == ModelType::VAE) {
  2202. genRequest.vaePath = vaeInfo.path;
  2203. } else {
  2204. sendErrorResponse(res, "VAE model not found or invalid: " + vaeModelId, 400, "INVALID_VAE_MODEL", requestId);
  2205. return;
  2206. }
  2207. }
  2208. }
  2209. // Optional TAESD model
  2210. if (requestJson.contains("taesd_model") && requestJson["taesd_model"].is_string()) {
  2211. std::string taesdModelId = requestJson["taesd_model"];
  2212. if (!taesdModelId.empty()) {
  2213. auto taesdInfo = m_modelManager->getModelInfo(taesdModelId);
  2214. if (!taesdInfo.name.empty() && taesdInfo.type == ModelType::TAESD) {
  2215. genRequest.taesdPath = taesdInfo.path;
  2216. } else {
  2217. sendErrorResponse(res, "TAESD model not found or invalid: " + taesdModelId, 400, "INVALID_TAESD_MODEL", requestId);
  2218. return;
  2219. }
  2220. }
  2221. }
  2222. // Enqueue request
  2223. auto future = m_generationQueue->enqueueRequest(genRequest);
  2224. nlohmann::json params = {
  2225. {"prompt", genRequest.prompt},
  2226. {"negative_prompt", genRequest.negativePrompt},
  2227. {"model", genRequest.modelName},
  2228. {"width", genRequest.width},
  2229. {"height", genRequest.height},
  2230. {"batch_count", genRequest.batchCount},
  2231. {"steps", genRequest.steps},
  2232. {"cfg_scale", genRequest.cfgScale},
  2233. {"seed", genRequest.seed},
  2234. {"sampling_method", samplingMethodToString(genRequest.samplingMethod)},
  2235. {"scheduler", schedulerToString(genRequest.scheduler)}
  2236. };
  2237. // Add VAE/TAESD if specified
  2238. if (!genRequest.vaePath.empty()) {
  2239. params["vae_model"] = requestJson.value("vae_model", "");
  2240. }
  2241. if (!genRequest.taesdPath.empty()) {
  2242. params["taesd_model"] = requestJson.value("taesd_model", "");
  2243. }
  2244. nlohmann::json response = {
  2245. {"request_id", requestId},
  2246. {"status", "queued"},
  2247. {"message", "Text-to-image generation request queued successfully"},
  2248. {"queue_position", m_generationQueue->getQueueSize()},
  2249. {"estimated_time_seconds", estimateGenerationTime(genRequest) / 1000},
  2250. {"estimated_memory_mb", estimateMemoryUsage(genRequest) / (1024 * 1024)},
  2251. {"type", "text2img"},
  2252. {"parameters", params}
  2253. };
  2254. sendJsonResponse(res, response, 202);
  2255. } catch (const nlohmann::json::parse_error& e) {
  2256. sendErrorResponse(res, std::string("Invalid JSON: ") + e.what(), 400, "JSON_PARSE_ERROR", requestId);
  2257. } catch (const std::exception& e) {
  2258. sendErrorResponse(res, std::string("Text-to-image request failed: ") + e.what(), 500, "INTERNAL_ERROR", requestId);
  2259. }
  2260. }
  2261. void Server::handleImg2Img(const httplib::Request& req, httplib::Response& res) {
  2262. std::string requestId = generateRequestId();
  2263. try {
  2264. if (!m_generationQueue) {
  2265. sendErrorResponse(res, "Generation queue not available", 500, "QUEUE_UNAVAILABLE", requestId);
  2266. return;
  2267. }
  2268. nlohmann::json requestJson = nlohmann::json::parse(req.body);
  2269. // Validate required fields for img2img
  2270. if (!requestJson.contains("prompt") || !requestJson["prompt"].is_string()) {
  2271. sendErrorResponse(res, "Missing or invalid 'prompt' field", 400, "INVALID_PARAMETERS", requestId);
  2272. return;
  2273. }
  2274. if (!requestJson.contains("init_image") || !requestJson["init_image"].is_string()) {
  2275. sendErrorResponse(res, "Missing or invalid 'init_image' field", 400, "INVALID_PARAMETERS", requestId);
  2276. return;
  2277. }
  2278. // Validate all parameters
  2279. auto [isValid, errorMessage] = validateGenerationParameters(requestJson);
  2280. if (!isValid) {
  2281. sendErrorResponse(res, errorMessage, 400, "INVALID_PARAMETERS", requestId);
  2282. return;
  2283. }
  2284. // Check if any model is loaded
  2285. if (!m_modelManager) {
  2286. sendErrorResponse(res, "Model manager not available", 500, "MODEL_MANAGER_UNAVAILABLE", requestId);
  2287. return;
  2288. }
  2289. // Get currently loaded checkpoint model
  2290. auto allModels = m_modelManager->getAllModels();
  2291. std::string loadedModelName;
  2292. for (const auto& [modelName, modelInfo] : allModels) {
  2293. if (modelInfo.type == ModelType::CHECKPOINT && modelInfo.isLoaded) {
  2294. loadedModelName = modelName;
  2295. break;
  2296. }
  2297. }
  2298. if (loadedModelName.empty()) {
  2299. sendErrorResponse(res, "No checkpoint model loaded. Please load a checkpoint model first using POST /api/models/{hash}/load", 400, "NO_CHECKPOINT_LOADED", requestId);
  2300. return;
  2301. }
  2302. // Load the init image
  2303. std::string initImageInput = requestJson["init_image"];
  2304. auto [imageData, imgWidth, imgHeight, imgChannels, success, loadError] = loadImageFromInput(initImageInput);
  2305. if (!success) {
  2306. sendErrorResponse(res, "Failed to load init image: " + loadError, 400, "IMAGE_LOAD_ERROR", requestId);
  2307. return;
  2308. }
  2309. // Create generation request specifically for img2img
  2310. GenerationRequest genRequest;
  2311. genRequest.id = requestId;
  2312. genRequest.requestType = GenerationRequest::RequestType::IMG2IMG;
  2313. genRequest.modelName = loadedModelName; // Use the currently loaded model
  2314. genRequest.prompt = requestJson["prompt"];
  2315. genRequest.negativePrompt = requestJson.value("negative_prompt", "");
  2316. genRequest.width = requestJson.value("width", imgWidth); // Default to input image dimensions
  2317. genRequest.height = requestJson.value("height", imgHeight);
  2318. genRequest.batchCount = requestJson.value("batch_count", 1);
  2319. genRequest.steps = requestJson.value("steps", 20);
  2320. genRequest.cfgScale = requestJson.value("cfg_scale", 7.5f);
  2321. genRequest.seed = requestJson.value("seed", "random");
  2322. genRequest.strength = requestJson.value("strength", 0.75f);
  2323. // Set init image data
  2324. genRequest.initImageData = imageData;
  2325. genRequest.initImageWidth = imgWidth;
  2326. genRequest.initImageHeight = imgHeight;
  2327. genRequest.initImageChannels = imgChannels;
  2328. // Parse optional parameters
  2329. if (requestJson.contains("sampling_method")) {
  2330. genRequest.samplingMethod = parseSamplingMethod(requestJson["sampling_method"]);
  2331. }
  2332. if (requestJson.contains("scheduler")) {
  2333. genRequest.scheduler = parseScheduler(requestJson["scheduler"]);
  2334. }
  2335. // Optional VAE model
  2336. if (requestJson.contains("vae_model") && requestJson["vae_model"].is_string()) {
  2337. std::string vaeModelId = requestJson["vae_model"];
  2338. if (!vaeModelId.empty()) {
  2339. auto vaeInfo = m_modelManager->getModelInfo(vaeModelId);
  2340. if (!vaeInfo.name.empty() && vaeInfo.type == ModelType::VAE) {
  2341. genRequest.vaePath = vaeInfo.path;
  2342. } else {
  2343. sendErrorResponse(res, "VAE model not found or invalid: " + vaeModelId, 400, "INVALID_VAE_MODEL", requestId);
  2344. return;
  2345. }
  2346. }
  2347. }
  2348. // Optional TAESD model
  2349. if (requestJson.contains("taesd_model") && requestJson["taesd_model"].is_string()) {
  2350. std::string taesdModelId = requestJson["taesd_model"];
  2351. if (!taesdModelId.empty()) {
  2352. auto taesdInfo = m_modelManager->getModelInfo(taesdModelId);
  2353. if (!taesdInfo.name.empty() && taesdInfo.type == ModelType::TAESD) {
  2354. genRequest.taesdPath = taesdInfo.path;
  2355. } else {
  2356. sendErrorResponse(res, "TAESD model not found or invalid: " + taesdModelId, 400, "INVALID_TAESD_MODEL", requestId);
  2357. return;
  2358. }
  2359. }
  2360. }
  2361. // Enqueue request
  2362. auto future = m_generationQueue->enqueueRequest(genRequest);
  2363. nlohmann::json params = {
  2364. {"prompt", genRequest.prompt},
  2365. {"negative_prompt", genRequest.negativePrompt},
  2366. {"init_image", requestJson["init_image"]},
  2367. {"model", genRequest.modelName},
  2368. {"width", genRequest.width},
  2369. {"height", genRequest.height},
  2370. {"batch_count", genRequest.batchCount},
  2371. {"steps", genRequest.steps},
  2372. {"cfg_scale", genRequest.cfgScale},
  2373. {"seed", genRequest.seed},
  2374. {"strength", genRequest.strength},
  2375. {"sampling_method", samplingMethodToString(genRequest.samplingMethod)},
  2376. {"scheduler", schedulerToString(genRequest.scheduler)}
  2377. };
  2378. // Add VAE/TAESD if specified
  2379. if (!genRequest.vaePath.empty()) {
  2380. params["vae_model"] = requestJson.value("vae_model", "");
  2381. }
  2382. if (!genRequest.taesdPath.empty()) {
  2383. params["taesd_model"] = requestJson.value("taesd_model", "");
  2384. }
  2385. nlohmann::json response = {
  2386. {"request_id", requestId},
  2387. {"status", "queued"},
  2388. {"message", "Image-to-image generation request queued successfully"},
  2389. {"queue_position", m_generationQueue->getQueueSize()},
  2390. {"estimated_time_seconds", estimateGenerationTime(genRequest) / 1000},
  2391. {"estimated_memory_mb", estimateMemoryUsage(genRequest) / (1024 * 1024)},
  2392. {"type", "img2img"},
  2393. {"parameters", params}
  2394. };
  2395. sendJsonResponse(res, response, 202);
  2396. } catch (const nlohmann::json::parse_error& e) {
  2397. sendErrorResponse(res, std::string("Invalid JSON: ") + e.what(), 400, "JSON_PARSE_ERROR", requestId);
  2398. } catch (const std::exception& e) {
  2399. sendErrorResponse(res, std::string("Image-to-image request failed: ") + e.what(), 500, "INTERNAL_ERROR", requestId);
  2400. }
  2401. }
  2402. void Server::handleControlNet(const httplib::Request& req, httplib::Response& res) {
  2403. std::string requestId = generateRequestId();
  2404. try {
  2405. if (!m_generationQueue) {
  2406. sendErrorResponse(res, "Generation queue not available", 500, "QUEUE_UNAVAILABLE", requestId);
  2407. return;
  2408. }
  2409. nlohmann::json requestJson = nlohmann::json::parse(req.body);
  2410. // Validate required fields for ControlNet
  2411. if (!requestJson.contains("prompt") || !requestJson["prompt"].is_string()) {
  2412. sendErrorResponse(res, "Missing or invalid 'prompt' field", 400, "INVALID_PARAMETERS", requestId);
  2413. return;
  2414. }
  2415. if (!requestJson.contains("control_image") || !requestJson["control_image"].is_string()) {
  2416. sendErrorResponse(res, "Missing or invalid 'control_image' field", 400, "INVALID_PARAMETERS", requestId);
  2417. return;
  2418. }
  2419. // Validate all parameters
  2420. auto [isValid, errorMessage] = validateGenerationParameters(requestJson);
  2421. if (!isValid) {
  2422. sendErrorResponse(res, errorMessage, 400, "INVALID_PARAMETERS", requestId);
  2423. return;
  2424. }
  2425. // Check if any model is loaded
  2426. if (!m_modelManager) {
  2427. sendErrorResponse(res, "Model manager not available", 500, "MODEL_MANAGER_UNAVAILABLE", requestId);
  2428. return;
  2429. }
  2430. // Get currently loaded checkpoint model
  2431. auto allModels = m_modelManager->getAllModels();
  2432. std::string loadedModelName;
  2433. for (const auto& [modelName, modelInfo] : allModels) {
  2434. if (modelInfo.type == ModelType::CHECKPOINT && modelInfo.isLoaded) {
  2435. loadedModelName = modelName;
  2436. break;
  2437. }
  2438. }
  2439. if (loadedModelName.empty()) {
  2440. sendErrorResponse(res, "No checkpoint model loaded. Please load a checkpoint model first using POST /api/models/{hash}/load", 400, "NO_CHECKPOINT_LOADED", requestId);
  2441. return;
  2442. }
  2443. // Create generation request specifically for ControlNet
  2444. GenerationRequest genRequest;
  2445. genRequest.id = requestId;
  2446. genRequest.modelName = loadedModelName; // Use the currently loaded model
  2447. genRequest.prompt = requestJson["prompt"];
  2448. genRequest.negativePrompt = requestJson.value("negative_prompt", "");
  2449. genRequest.width = requestJson.value("width", 512);
  2450. genRequest.height = requestJson.value("height", 512);
  2451. genRequest.batchCount = requestJson.value("batch_count", 1);
  2452. genRequest.steps = requestJson.value("steps", 20);
  2453. genRequest.cfgScale = requestJson.value("cfg_scale", 7.5f);
  2454. genRequest.seed = requestJson.value("seed", "random");
  2455. genRequest.controlStrength = requestJson.value("control_strength", 0.9f);
  2456. genRequest.controlNetPath = requestJson.value("control_net_model", "");
  2457. // Parse optional parameters
  2458. if (requestJson.contains("sampling_method")) {
  2459. genRequest.samplingMethod = parseSamplingMethod(requestJson["sampling_method"]);
  2460. }
  2461. if (requestJson.contains("scheduler")) {
  2462. genRequest.scheduler = parseScheduler(requestJson["scheduler"]);
  2463. }
  2464. // Optional VAE model
  2465. if (requestJson.contains("vae_model") && requestJson["vae_model"].is_string()) {
  2466. std::string vaeModelId = requestJson["vae_model"];
  2467. if (!vaeModelId.empty()) {
  2468. auto vaeInfo = m_modelManager->getModelInfo(vaeModelId);
  2469. if (!vaeInfo.name.empty() && vaeInfo.type == ModelType::VAE) {
  2470. genRequest.vaePath = vaeInfo.path;
  2471. } else {
  2472. sendErrorResponse(res, "VAE model not found or invalid: " + vaeModelId, 400, "INVALID_VAE_MODEL", requestId);
  2473. return;
  2474. }
  2475. }
  2476. }
  2477. // Optional TAESD model
  2478. if (requestJson.contains("taesd_model") && requestJson["taesd_model"].is_string()) {
  2479. std::string taesdModelId = requestJson["taesd_model"];
  2480. if (!taesdModelId.empty()) {
  2481. auto taesdInfo = m_modelManager->getModelInfo(taesdModelId);
  2482. if (!taesdInfo.name.empty() && taesdInfo.type == ModelType::TAESD) {
  2483. genRequest.taesdPath = taesdInfo.path;
  2484. } else {
  2485. sendErrorResponse(res, "TAESD model not found or invalid: " + taesdModelId, 400, "INVALID_TAESD_MODEL", requestId);
  2486. return;
  2487. }
  2488. }
  2489. }
  2490. // Store control image path (would be handled in actual implementation)
  2491. genRequest.outputPath = requestJson.value("control_image", "");
  2492. // Enqueue request
  2493. auto future = m_generationQueue->enqueueRequest(genRequest);
  2494. nlohmann::json params = {
  2495. {"prompt", genRequest.prompt},
  2496. {"negative_prompt", genRequest.negativePrompt},
  2497. {"control_image", requestJson["control_image"]},
  2498. {"control_net_model", genRequest.controlNetPath},
  2499. {"model", genRequest.modelName},
  2500. {"width", genRequest.width},
  2501. {"height", genRequest.height},
  2502. {"batch_count", genRequest.batchCount},
  2503. {"steps", genRequest.steps},
  2504. {"cfg_scale", genRequest.cfgScale},
  2505. {"seed", genRequest.seed},
  2506. {"control_strength", genRequest.controlStrength},
  2507. {"sampling_method", samplingMethodToString(genRequest.samplingMethod)},
  2508. {"scheduler", schedulerToString(genRequest.scheduler)}
  2509. };
  2510. // Add VAE/TAESD if specified
  2511. if (!genRequest.vaePath.empty()) {
  2512. params["vae_model"] = requestJson.value("vae_model", "");
  2513. }
  2514. if (!genRequest.taesdPath.empty()) {
  2515. params["taesd_model"] = requestJson.value("taesd_model", "");
  2516. }
  2517. nlohmann::json response = {
  2518. {"request_id", requestId},
  2519. {"status", "queued"},
  2520. {"message", "ControlNet generation request queued successfully"},
  2521. {"queue_position", m_generationQueue->getQueueSize()},
  2522. {"estimated_time_seconds", estimateGenerationTime(genRequest) / 1000},
  2523. {"estimated_memory_mb", estimateMemoryUsage(genRequest) / (1024 * 1024)},
  2524. {"type", "controlnet"},
  2525. {"parameters", params}
  2526. };
  2527. sendJsonResponse(res, response, 202);
  2528. } catch (const nlohmann::json::parse_error& e) {
  2529. sendErrorResponse(res, std::string("Invalid JSON: ") + e.what(), 400, "JSON_PARSE_ERROR", requestId);
  2530. } catch (const std::exception& e) {
  2531. sendErrorResponse(res, std::string("ControlNet request failed: ") + e.what(), 500, "INTERNAL_ERROR", requestId);
  2532. }
  2533. }
  2534. void Server::handleUpscale(const httplib::Request& req, httplib::Response& res) {
  2535. std::string requestId = generateRequestId();
  2536. try {
  2537. if (!m_generationQueue) {
  2538. sendErrorResponse(res, "Generation queue not available", 500, "QUEUE_UNAVAILABLE", requestId);
  2539. return;
  2540. }
  2541. nlohmann::json requestJson = nlohmann::json::parse(req.body);
  2542. // Validate required fields for upscaler
  2543. if (!requestJson.contains("image") || !requestJson["image"].is_string()) {
  2544. sendErrorResponse(res, "Missing or invalid 'image' field", 400, "INVALID_PARAMETERS", requestId);
  2545. return;
  2546. }
  2547. if (!requestJson.contains("esrgan_model") || !requestJson["esrgan_model"].is_string()) {
  2548. sendErrorResponse(res, "Missing or invalid 'esrgan_model' field (model hash or name)", 400, "INVALID_PARAMETERS", requestId);
  2549. return;
  2550. }
  2551. // Check if model manager is available
  2552. if (!m_modelManager) {
  2553. sendErrorResponse(res, "Model manager not available", 500, "MODEL_MANAGER_UNAVAILABLE", requestId);
  2554. return;
  2555. }
  2556. // Get the ESRGAN/upscaler model
  2557. std::string esrganModelId = requestJson["esrgan_model"];
  2558. auto modelInfo = m_modelManager->getModelInfo(esrganModelId);
  2559. if (modelInfo.name.empty()) {
  2560. sendErrorResponse(res, "ESRGAN model not found: " + esrganModelId, 404, "MODEL_NOT_FOUND", requestId);
  2561. return;
  2562. }
  2563. if (modelInfo.type != ModelType::ESRGAN && modelInfo.type != ModelType::UPSCALER) {
  2564. sendErrorResponse(res, "Model is not an ESRGAN/upscaler model", 400, "INVALID_MODEL_TYPE", requestId);
  2565. return;
  2566. }
  2567. // Load the input image
  2568. std::string imageInput = requestJson["image"];
  2569. auto [imageData, imgWidth, imgHeight, imgChannels, success, loadError] = loadImageFromInput(imageInput);
  2570. if (!success) {
  2571. sendErrorResponse(res, "Failed to load image: " + loadError, 400, "IMAGE_LOAD_ERROR", requestId);
  2572. return;
  2573. }
  2574. // Create upscaler request
  2575. GenerationRequest genRequest;
  2576. genRequest.id = requestId;
  2577. genRequest.requestType = GenerationRequest::RequestType::UPSCALER;
  2578. genRequest.esrganPath = modelInfo.path;
  2579. genRequest.upscaleFactor = requestJson.value("upscale_factor", 4);
  2580. genRequest.nThreads = requestJson.value("threads", -1);
  2581. genRequest.offloadParamsToCpu = requestJson.value("offload_to_cpu", false);
  2582. genRequest.diffusionConvDirect = requestJson.value("direct", false);
  2583. // Set input image data
  2584. genRequest.initImageData = imageData;
  2585. genRequest.initImageWidth = imgWidth;
  2586. genRequest.initImageHeight = imgHeight;
  2587. genRequest.initImageChannels = imgChannels;
  2588. // Enqueue request
  2589. auto future = m_generationQueue->enqueueRequest(genRequest);
  2590. nlohmann::json response = {
  2591. {"request_id", requestId},
  2592. {"status", "queued"},
  2593. {"message", "Upscale request queued successfully"},
  2594. {"queue_position", m_generationQueue->getQueueSize()},
  2595. {"type", "upscale"},
  2596. {"parameters", {
  2597. {"esrgan_model", esrganModelId},
  2598. {"upscale_factor", genRequest.upscaleFactor},
  2599. {"input_width", imgWidth},
  2600. {"input_height", imgHeight},
  2601. {"output_width", imgWidth * genRequest.upscaleFactor},
  2602. {"output_height", imgHeight * genRequest.upscaleFactor}
  2603. }}
  2604. };
  2605. sendJsonResponse(res, response, 202);
  2606. } catch (const nlohmann::json::parse_error& e) {
  2607. sendErrorResponse(res, std::string("Invalid JSON: ") + e.what(), 400, "JSON_PARSE_ERROR", requestId);
  2608. } catch (const std::exception& e) {
  2609. sendErrorResponse(res, std::string("Upscale request failed: ") + e.what(), 500, "INTERNAL_ERROR", requestId);
  2610. }
  2611. }
  2612. void Server::handleInpainting(const httplib::Request& req, httplib::Response& res) {
  2613. std::string requestId = generateRequestId();
  2614. try {
  2615. if (!m_generationQueue) {
  2616. sendErrorResponse(res, "Generation queue not available", 500, "QUEUE_UNAVAILABLE", requestId);
  2617. return;
  2618. }
  2619. nlohmann::json requestJson = nlohmann::json::parse(req.body);
  2620. // Validate required fields for inpainting
  2621. if (!requestJson.contains("prompt") || !requestJson["prompt"].is_string()) {
  2622. sendErrorResponse(res, "Missing or invalid 'prompt' field", 400, "INVALID_PARAMETERS", requestId);
  2623. return;
  2624. }
  2625. if (!requestJson.contains("source_image") || !requestJson["source_image"].is_string()) {
  2626. sendErrorResponse(res, "Missing or invalid 'source_image' field", 400, "INVALID_PARAMETERS", requestId);
  2627. return;
  2628. }
  2629. if (!requestJson.contains("mask_image") || !requestJson["mask_image"].is_string()) {
  2630. sendErrorResponse(res, "Missing or invalid 'mask_image' field", 400, "INVALID_PARAMETERS", requestId);
  2631. return;
  2632. }
  2633. // Validate all parameters
  2634. auto [isValid, errorMessage] = validateGenerationParameters(requestJson);
  2635. if (!isValid) {
  2636. sendErrorResponse(res, errorMessage, 400, "INVALID_PARAMETERS", requestId);
  2637. return;
  2638. }
  2639. // Check if any model is loaded
  2640. if (!m_modelManager) {
  2641. sendErrorResponse(res, "Model manager not available", 500, "MODEL_MANAGER_UNAVAILABLE", requestId);
  2642. return;
  2643. }
  2644. // Get currently loaded checkpoint model
  2645. auto allModels = m_modelManager->getAllModels();
  2646. std::string loadedModelName;
  2647. for (const auto& [modelName, modelInfo] : allModels) {
  2648. if (modelInfo.type == ModelType::CHECKPOINT && modelInfo.isLoaded) {
  2649. loadedModelName = modelName;
  2650. break;
  2651. }
  2652. }
  2653. if (loadedModelName.empty()) {
  2654. sendErrorResponse(res, "No checkpoint model loaded. Please load a checkpoint model first using POST /api/models/{hash}/load", 400, "NO_CHECKPOINT_LOADED", requestId);
  2655. return;
  2656. }
  2657. // Load the source image
  2658. std::string sourceImageInput = requestJson["source_image"];
  2659. auto [sourceImageData, sourceImgWidth, sourceImgHeight, sourceImgChannels, sourceSuccess, sourceLoadError] = loadImageFromInput(sourceImageInput);
  2660. if (!sourceSuccess) {
  2661. sendErrorResponse(res, "Failed to load source image: " + sourceLoadError, 400, "IMAGE_LOAD_ERROR", requestId);
  2662. return;
  2663. }
  2664. // Load the mask image
  2665. std::string maskImageInput = requestJson["mask_image"];
  2666. auto [maskImageData, maskImgWidth, maskImgHeight, maskImgChannels, maskSuccess, maskLoadError] = loadImageFromInput(maskImageInput);
  2667. if (!maskSuccess) {
  2668. sendErrorResponse(res, "Failed to load mask image: " + maskLoadError, 400, "MASK_LOAD_ERROR", requestId);
  2669. return;
  2670. }
  2671. // Validate that source and mask images have compatible dimensions
  2672. if (sourceImgWidth != maskImgWidth || sourceImgHeight != maskImgHeight) {
  2673. sendErrorResponse(res, "Source and mask images must have the same dimensions", 400, "DIMENSION_MISMATCH", requestId);
  2674. return;
  2675. }
  2676. // Create generation request specifically for inpainting
  2677. GenerationRequest genRequest;
  2678. genRequest.id = requestId;
  2679. genRequest.requestType = GenerationRequest::RequestType::INPAINTING;
  2680. genRequest.modelName = loadedModelName; // Use the currently loaded model
  2681. genRequest.prompt = requestJson["prompt"];
  2682. genRequest.negativePrompt = requestJson.value("negative_prompt", "");
  2683. genRequest.width = requestJson.value("width", sourceImgWidth); // Default to input image dimensions
  2684. genRequest.height = requestJson.value("height", sourceImgHeight);
  2685. genRequest.batchCount = requestJson.value("batch_count", 1);
  2686. genRequest.steps = requestJson.value("steps", 20);
  2687. genRequest.cfgScale = requestJson.value("cfg_scale", 7.5f);
  2688. genRequest.seed = requestJson.value("seed", "random");
  2689. genRequest.strength = requestJson.value("strength", 0.75f);
  2690. // Set source image data
  2691. genRequest.initImageData = sourceImageData;
  2692. genRequest.initImageWidth = sourceImgWidth;
  2693. genRequest.initImageHeight = sourceImgHeight;
  2694. genRequest.initImageChannels = sourceImgChannels;
  2695. // Set mask image data
  2696. genRequest.maskImageData = maskImageData;
  2697. genRequest.maskImageWidth = maskImgWidth;
  2698. genRequest.maskImageHeight = maskImgHeight;
  2699. genRequest.maskImageChannels = maskImgChannels;
  2700. // Parse optional parameters
  2701. if (requestJson.contains("sampling_method")) {
  2702. genRequest.samplingMethod = parseSamplingMethod(requestJson["sampling_method"]);
  2703. }
  2704. if (requestJson.contains("scheduler")) {
  2705. genRequest.scheduler = parseScheduler(requestJson["scheduler"]);
  2706. }
  2707. // Optional VAE model
  2708. if (requestJson.contains("vae_model") && requestJson["vae_model"].is_string()) {
  2709. std::string vaeModelId = requestJson["vae_model"];
  2710. if (!vaeModelId.empty()) {
  2711. auto vaeInfo = m_modelManager->getModelInfo(vaeModelId);
  2712. if (!vaeInfo.name.empty() && vaeInfo.type == ModelType::VAE) {
  2713. genRequest.vaePath = vaeInfo.path;
  2714. } else {
  2715. sendErrorResponse(res, "VAE model not found or invalid: " + vaeModelId, 400, "INVALID_VAE_MODEL", requestId);
  2716. return;
  2717. }
  2718. }
  2719. }
  2720. // Optional TAESD model
  2721. if (requestJson.contains("taesd_model") && requestJson["taesd_model"].is_string()) {
  2722. std::string taesdModelId = requestJson["taesd_model"];
  2723. if (!taesdModelId.empty()) {
  2724. auto taesdInfo = m_modelManager->getModelInfo(taesdModelId);
  2725. if (!taesdInfo.name.empty() && taesdInfo.type == ModelType::TAESD) {
  2726. genRequest.taesdPath = taesdInfo.path;
  2727. } else {
  2728. sendErrorResponse(res, "TAESD model not found or invalid: " + taesdModelId, 400, "INVALID_TAESD_MODEL", requestId);
  2729. return;
  2730. }
  2731. }
  2732. }
  2733. // Enqueue request
  2734. auto future = m_generationQueue->enqueueRequest(genRequest);
  2735. nlohmann::json params = {
  2736. {"prompt", genRequest.prompt},
  2737. {"negative_prompt", genRequest.negativePrompt},
  2738. {"source_image", requestJson["source_image"]},
  2739. {"mask_image", requestJson["mask_image"]},
  2740. {"model", genRequest.modelName},
  2741. {"width", genRequest.width},
  2742. {"height", genRequest.height},
  2743. {"batch_count", genRequest.batchCount},
  2744. {"steps", genRequest.steps},
  2745. {"cfg_scale", genRequest.cfgScale},
  2746. {"seed", genRequest.seed},
  2747. {"strength", genRequest.strength},
  2748. {"sampling_method", samplingMethodToString(genRequest.samplingMethod)},
  2749. {"scheduler", schedulerToString(genRequest.scheduler)}
  2750. };
  2751. // Add VAE/TAESD if specified
  2752. if (!genRequest.vaePath.empty()) {
  2753. params["vae_model"] = requestJson.value("vae_model", "");
  2754. }
  2755. if (!genRequest.taesdPath.empty()) {
  2756. params["taesd_model"] = requestJson.value("taesd_model", "");
  2757. }
  2758. nlohmann::json response = {
  2759. {"request_id", requestId},
  2760. {"status", "queued"},
  2761. {"message", "Inpainting generation request queued successfully"},
  2762. {"queue_position", m_generationQueue->getQueueSize()},
  2763. {"estimated_time_seconds", estimateGenerationTime(genRequest) / 1000},
  2764. {"estimated_memory_mb", estimateMemoryUsage(genRequest) / (1024 * 1024)},
  2765. {"type", "inpainting"},
  2766. {"parameters", params}
  2767. };
  2768. sendJsonResponse(res, response, 202);
  2769. } catch (const nlohmann::json::parse_error& e) {
  2770. sendErrorResponse(res, std::string("Invalid JSON: ") + e.what(), 400, "JSON_PARSE_ERROR", requestId);
  2771. } catch (const std::exception& e) {
  2772. sendErrorResponse(res, std::string("Inpainting request failed: ") + e.what(), 500, "INTERNAL_ERROR", requestId);
  2773. }
  2774. }
  2775. // Utility endpoints
  2776. void Server::handleSamplers(const httplib::Request& /*req*/, httplib::Response& res) {
  2777. try {
  2778. nlohmann::json samplers = {
  2779. {"samplers", {
  2780. {
  2781. {"name", "euler"},
  2782. {"description", "Euler sampler - fast and simple"},
  2783. {"recommended_steps", 20}
  2784. },
  2785. {
  2786. {"name", "euler_a"},
  2787. {"description", "Euler ancestral sampler - adds randomness"},
  2788. {"recommended_steps", 20}
  2789. },
  2790. {
  2791. {"name", "heun"},
  2792. {"description", "Heun sampler - more accurate but slower"},
  2793. {"recommended_steps", 20}
  2794. },
  2795. {
  2796. {"name", "dpm2"},
  2797. {"description", "DPM2 sampler - second-order DPM"},
  2798. {"recommended_steps", 20}
  2799. },
  2800. {
  2801. {"name", "dpm++2s_a"},
  2802. {"description", "DPM++ 2s ancestral sampler"},
  2803. {"recommended_steps", 20}
  2804. },
  2805. {
  2806. {"name", "dpm++2m"},
  2807. {"description", "DPM++ 2m sampler - multistep"},
  2808. {"recommended_steps", 20}
  2809. },
  2810. {
  2811. {"name", "dpm++2mv2"},
  2812. {"description", "DPM++ 2m v2 sampler - improved multistep"},
  2813. {"recommended_steps", 20}
  2814. },
  2815. {
  2816. {"name", "ipndm"},
  2817. {"description", "IPNDM sampler - improved noise prediction"},
  2818. {"recommended_steps", 20}
  2819. },
  2820. {
  2821. {"name", "ipndm_v"},
  2822. {"description", "IPNDM v sampler - variant of IPNDM"},
  2823. {"recommended_steps", 20}
  2824. },
  2825. {
  2826. {"name", "lcm"},
  2827. {"description", "LCM sampler - Latent Consistency Model, very fast"},
  2828. {"recommended_steps", 4}
  2829. },
  2830. {
  2831. {"name", "ddim_trailing"},
  2832. {"description", "DDIM trailing sampler - deterministic"},
  2833. {"recommended_steps", 20}
  2834. },
  2835. {
  2836. {"name", "tcd"},
  2837. {"description", "TCD sampler - Trajectory Consistency Distillation"},
  2838. {"recommended_steps", 8}
  2839. },
  2840. {
  2841. {"name", "default"},
  2842. {"description", "Use model's default sampler"},
  2843. {"recommended_steps", 20}
  2844. }
  2845. }}
  2846. };
  2847. sendJsonResponse(res, samplers);
  2848. } catch (const std::exception& e) {
  2849. sendErrorResponse(res, std::string("Failed to get samplers: ") + e.what(), 500);
  2850. }
  2851. }
  2852. void Server::handleSchedulers(const httplib::Request& /*req*/, httplib::Response& res) {
  2853. try {
  2854. nlohmann::json schedulers = {
  2855. {"schedulers", {
  2856. {
  2857. {"name", "discrete"},
  2858. {"description", "Discrete scheduler - standard noise schedule"}
  2859. },
  2860. {
  2861. {"name", "karras"},
  2862. {"description", "Karras scheduler - improved noise schedule"}
  2863. },
  2864. {
  2865. {"name", "exponential"},
  2866. {"description", "Exponential scheduler - exponential noise decay"}
  2867. },
  2868. {
  2869. {"name", "ays"},
  2870. {"description", "AYS scheduler - Adaptive Your Scheduler"}
  2871. },
  2872. {
  2873. {"name", "gits"},
  2874. {"description", "GITS scheduler - Generalized Iterative Time Steps"}
  2875. },
  2876. {
  2877. {"name", "smoothstep"},
  2878. {"description", "Smoothstep scheduler - smooth transition function"}
  2879. },
  2880. {
  2881. {"name", "sgm_uniform"},
  2882. {"description", "SGM uniform scheduler - uniform noise schedule"}
  2883. },
  2884. {
  2885. {"name", "simple"},
  2886. {"description", "Simple scheduler - basic linear schedule"}
  2887. },
  2888. {
  2889. {"name", "default"},
  2890. {"description", "Use model's default scheduler"}
  2891. }
  2892. }}
  2893. };
  2894. sendJsonResponse(res, schedulers);
  2895. } catch (const std::exception& e) {
  2896. sendErrorResponse(res, std::string("Failed to get schedulers: ") + e.what(), 500);
  2897. }
  2898. }
  2899. void Server::handleParameters(const httplib::Request& /*req*/, httplib::Response& res) {
  2900. try {
  2901. nlohmann::json parameters = {
  2902. {"parameters", {
  2903. {
  2904. {"name", "prompt"},
  2905. {"type", "string"},
  2906. {"required", true},
  2907. {"description", "Text prompt for image generation"},
  2908. {"min_length", 1},
  2909. {"max_length", 10000},
  2910. {"example", "a beautiful landscape with mountains"}
  2911. },
  2912. {
  2913. {"name", "negative_prompt"},
  2914. {"type", "string"},
  2915. {"required", false},
  2916. {"description", "Negative prompt to guide generation away from"},
  2917. {"min_length", 0},
  2918. {"max_length", 10000},
  2919. {"example", "blurry, low quality, distorted"}
  2920. },
  2921. {
  2922. {"name", "width"},
  2923. {"type", "integer"},
  2924. {"required", false},
  2925. {"description", "Image width in pixels"},
  2926. {"min", 64},
  2927. {"max", 2048},
  2928. {"multiple_of", 64},
  2929. {"default", 512}
  2930. },
  2931. {
  2932. {"name", "height"},
  2933. {"type", "integer"},
  2934. {"required", false},
  2935. {"description", "Image height in pixels"},
  2936. {"min", 64},
  2937. {"max", 2048},
  2938. {"multiple_of", 64},
  2939. {"default", 512}
  2940. },
  2941. {
  2942. {"name", "steps"},
  2943. {"type", "integer"},
  2944. {"required", false},
  2945. {"description", "Number of diffusion steps"},
  2946. {"min", 1},
  2947. {"max", 150},
  2948. {"default", 20}
  2949. },
  2950. {
  2951. {"name", "cfg_scale"},
  2952. {"type", "number"},
  2953. {"required", false},
  2954. {"description", "Classifier-Free Guidance scale"},
  2955. {"min", 1.0},
  2956. {"max", 30.0},
  2957. {"default", 7.5}
  2958. },
  2959. {
  2960. {"name", "seed"},
  2961. {"type", "string|integer"},
  2962. {"required", false},
  2963. {"description", "Seed for generation (use 'random' for random seed)"},
  2964. {"example", "42"}
  2965. },
  2966. {
  2967. {"name", "sampling_method"},
  2968. {"type", "string"},
  2969. {"required", false},
  2970. {"description", "Sampling method to use"},
  2971. {"enum", {"euler", "euler_a", "heun", "dpm2", "dpm++2s_a", "dpm++2m", "dpm++2mv2", "ipndm", "ipndm_v", "lcm", "ddim_trailing", "tcd", "default"}},
  2972. {"default", "default"}
  2973. },
  2974. {
  2975. {"name", "scheduler"},
  2976. {"type", "string"},
  2977. {"required", false},
  2978. {"description", "Scheduler to use"},
  2979. {"enum", {"discrete", "karras", "exponential", "ays", "gits", "smoothstep", "sgm_uniform", "simple", "default"}},
  2980. {"default", "default"}
  2981. },
  2982. {
  2983. {"name", "batch_count"},
  2984. {"type", "integer"},
  2985. {"required", false},
  2986. {"description", "Number of images to generate"},
  2987. {"min", 1},
  2988. {"max", 100},
  2989. {"default", 1}
  2990. },
  2991. {
  2992. {"name", "strength"},
  2993. {"type", "number"},
  2994. {"required", false},
  2995. {"description", "Strength for img2img (0.0-1.0)"},
  2996. {"min", 0.0},
  2997. {"max", 1.0},
  2998. {"default", 0.75}
  2999. },
  3000. {
  3001. {"name", "control_strength"},
  3002. {"type", "number"},
  3003. {"required", false},
  3004. {"description", "ControlNet strength (0.0-1.0)"},
  3005. {"min", 0.0},
  3006. {"max", 1.0},
  3007. {"default", 0.9}
  3008. }
  3009. }},
  3010. {"openapi", {
  3011. {"version", "3.0.0"},
  3012. {"info", {
  3013. {"title", "Stable Diffusion REST API"},
  3014. {"version", "1.0.0"},
  3015. {"description", "Comprehensive REST API for stable-diffusion.cpp functionality"}
  3016. }},
  3017. {"components", {
  3018. {"schemas", {
  3019. {"GenerationRequest", {
  3020. {"type", "object"},
  3021. {"required", {"prompt"}},
  3022. {"properties", {
  3023. {"prompt", {{"type", "string"}, {"description", "Text prompt for generation"}}},
  3024. {"negative_prompt", {{"type", "string"}, {"description", "Negative prompt"}}},
  3025. {"width", {{"type", "integer"}, {"minimum", 64}, {"maximum", 2048}, {"default", 512}}},
  3026. {"height", {{"type", "integer"}, {"minimum", 64}, {"maximum", 2048}, {"default", 512}}},
  3027. {"steps", {{"type", "integer"}, {"minimum", 1}, {"maximum", 150}, {"default", 20}}},
  3028. {"cfg_scale", {{"type", "number"}, {"minimum", 1.0}, {"maximum", 30.0}, {"default", 7.5}}}
  3029. }}
  3030. }}
  3031. }}
  3032. }}
  3033. }}
  3034. };
  3035. sendJsonResponse(res, parameters);
  3036. } catch (const std::exception& e) {
  3037. sendErrorResponse(res, std::string("Failed to get parameters: ") + e.what(), 500);
  3038. }
  3039. }
  3040. void Server::handleValidate(const httplib::Request& req, httplib::Response& res) {
  3041. std::string requestId = generateRequestId();
  3042. try {
  3043. nlohmann::json requestJson = nlohmann::json::parse(req.body);
  3044. // Validate parameters
  3045. auto [isValid, errorMessage] = validateGenerationParameters(requestJson);
  3046. nlohmann::json response = {
  3047. {"request_id", requestId},
  3048. {"valid", isValid},
  3049. {"message", isValid ? "Parameters are valid" : errorMessage},
  3050. {"errors", isValid ? nlohmann::json::array() : nlohmann::json::array({errorMessage})}
  3051. };
  3052. sendJsonResponse(res, response, isValid ? 200 : 400);
  3053. } catch (const nlohmann::json::parse_error& e) {
  3054. sendErrorResponse(res, std::string("Invalid JSON: ") + e.what(), 400, "JSON_PARSE_ERROR", requestId);
  3055. } catch (const std::exception& e) {
  3056. sendErrorResponse(res, std::string("Validation failed: ") + e.what(), 500, "INTERNAL_ERROR", requestId);
  3057. }
  3058. }
  3059. void Server::handleEstimate(const httplib::Request& req, httplib::Response& res) {
  3060. std::string requestId = generateRequestId();
  3061. try {
  3062. nlohmann::json requestJson = nlohmann::json::parse(req.body);
  3063. // Validate parameters first
  3064. auto [isValid, errorMessage] = validateGenerationParameters(requestJson);
  3065. if (!isValid) {
  3066. sendErrorResponse(res, errorMessage, 400, "INVALID_PARAMETERS", requestId);
  3067. return;
  3068. }
  3069. // Create a temporary request to estimate
  3070. GenerationRequest genRequest;
  3071. genRequest.prompt = requestJson["prompt"];
  3072. genRequest.width = requestJson.value("width", 512);
  3073. genRequest.height = requestJson.value("height", 512);
  3074. genRequest.batchCount = requestJson.value("batch_count", 1);
  3075. genRequest.steps = requestJson.value("steps", 20);
  3076. genRequest.diffusionFlashAttn = requestJson.value("diffusion_flash_attn", false);
  3077. genRequest.controlNetPath = requestJson.value("control_net_path", "");
  3078. if (requestJson.contains("sampling_method")) {
  3079. genRequest.samplingMethod = parseSamplingMethod(requestJson["sampling_method"]);
  3080. }
  3081. // Calculate estimates
  3082. uint64_t estimatedTime = estimateGenerationTime(genRequest);
  3083. size_t estimatedMemory = estimateMemoryUsage(genRequest);
  3084. nlohmann::json response = {
  3085. {"request_id", requestId},
  3086. {"estimated_time_seconds", estimatedTime / 1000},
  3087. {"estimated_memory_mb", estimatedMemory / (1024 * 1024)},
  3088. {"parameters", {
  3089. {"resolution", std::to_string(genRequest.width) + "x" + std::to_string(genRequest.height)},
  3090. {"steps", genRequest.steps},
  3091. {"batch_count", genRequest.batchCount},
  3092. {"sampling_method", samplingMethodToString(genRequest.samplingMethod)}
  3093. }}
  3094. };
  3095. sendJsonResponse(res, response);
  3096. } catch (const nlohmann::json::parse_error& e) {
  3097. sendErrorResponse(res, std::string("Invalid JSON: ") + e.what(), 400, "JSON_PARSE_ERROR", requestId);
  3098. } catch (const std::exception& e) {
  3099. sendErrorResponse(res, std::string("Estimation failed: ") + e.what(), 500, "INTERNAL_ERROR", requestId);
  3100. }
  3101. }
  3102. void Server::handleConfig(const httplib::Request& /*req*/, httplib::Response& res) {
  3103. std::string requestId = generateRequestId();
  3104. try {
  3105. // Get current configuration
  3106. nlohmann::json config = {
  3107. {"request_id", requestId},
  3108. {"config", {
  3109. {"server", {
  3110. {"host", m_host},
  3111. {"port", m_port},
  3112. {"max_concurrent_generations", 1}
  3113. }},
  3114. {"generation", {
  3115. {"default_width", 512},
  3116. {"default_height", 512},
  3117. {"default_steps", 20},
  3118. {"default_cfg_scale", 7.5},
  3119. {"max_batch_count", 100},
  3120. {"max_steps", 150},
  3121. {"max_resolution", 2048}
  3122. }},
  3123. {"rate_limiting", {
  3124. {"requests_per_minute", 60},
  3125. {"enabled", true}
  3126. }}
  3127. }}
  3128. };
  3129. sendJsonResponse(res, config);
  3130. } catch (const std::exception& e) {
  3131. sendErrorResponse(res, std::string("Config operation failed: ") + e.what(), 500, "INTERNAL_ERROR", requestId);
  3132. }
  3133. }
  3134. void Server::handleSystem(const httplib::Request& /*req*/, httplib::Response& res) {
  3135. try {
  3136. nlohmann::json system = {
  3137. {"system", {
  3138. {"version", "1.0.0"},
  3139. {"build", "stable-diffusion.cpp-rest"},
  3140. {"uptime", std::chrono::duration_cast<std::chrono::seconds>(
  3141. std::chrono::steady_clock::now().time_since_epoch()).count()},
  3142. {"capabilities", {
  3143. {"text2img", true},
  3144. {"img2img", true},
  3145. {"controlnet", true},
  3146. {"batch_generation", true},
  3147. {"parameter_validation", true},
  3148. {"estimation", true}
  3149. }},
  3150. {"supported_formats", {
  3151. {"input", {"png", "jpg", "jpeg", "webp"}},
  3152. {"output", {"png", "jpg", "jpeg", "webp"}}
  3153. }},
  3154. {"limits", {
  3155. {"max_resolution", 2048},
  3156. {"max_steps", 150},
  3157. {"max_batch_count", 100},
  3158. {"max_prompt_length", 10000}
  3159. }}
  3160. }},
  3161. {"hardware", {
  3162. {"cpu_threads", std::thread::hardware_concurrency()}
  3163. }}
  3164. };
  3165. sendJsonResponse(res, system);
  3166. } catch (const std::exception& e) {
  3167. sendErrorResponse(res, std::string("System info failed: ") + e.what(), 500);
  3168. }
  3169. }
  3170. void Server::handleSystemRestart(const httplib::Request& /*req*/, httplib::Response& res) {
  3171. try {
  3172. nlohmann::json response = {
  3173. {"message", "Server restart initiated. The server will shut down gracefully and exit. Please use a process manager to automatically restart it."},
  3174. {"status", "restarting"}
  3175. };
  3176. sendJsonResponse(res, response);
  3177. // Schedule server stop after response is sent
  3178. // Using a separate thread to allow the response to be sent first
  3179. std::thread([this]() {
  3180. std::this_thread::sleep_for(std::chrono::seconds(1));
  3181. this->stop();
  3182. // Exit with code 42 to signal restart intent to process manager
  3183. std::exit(42);
  3184. }).detach();
  3185. } catch (const std::exception& e) {
  3186. sendErrorResponse(res, std::string("Restart failed: ") + e.what(), 500);
  3187. }
  3188. }
  3189. // Helper methods for model management
  3190. nlohmann::json Server::getModelCapabilities(ModelType type) {
  3191. nlohmann::json capabilities = nlohmann::json::object();
  3192. switch (type) {
  3193. case ModelType::CHECKPOINT:
  3194. capabilities = {
  3195. {"text2img", true},
  3196. {"img2img", true},
  3197. {"inpainting", true},
  3198. {"outpainting", true},
  3199. {"controlnet", true},
  3200. {"lora", true},
  3201. {"vae", true},
  3202. {"sampling_methods", {"euler", "euler_a", "heun", "dpm2", "dpm++2s_a", "dpm++2m", "dpm++2mv2", "ipndm", "ipndm_v", "lcm", "ddim_trailing", "tcd"}},
  3203. {"schedulers", {"discrete", "karras", "exponential", "ays", "gits", "smoothstep", "sgm_uniform", "simple"}},
  3204. {"recommended_resolution", "512x512"},
  3205. {"max_resolution", "2048x2048"},
  3206. {"supports_batch", true}
  3207. };
  3208. break;
  3209. case ModelType::LORA:
  3210. capabilities = {
  3211. {"text2img", true},
  3212. {"img2img", true},
  3213. {"inpainting", true},
  3214. {"controlnet", false},
  3215. {"lora", true},
  3216. {"vae", false},
  3217. {"requires_checkpoint", true},
  3218. {"strength_range", {0.0, 2.0}},
  3219. {"recommended_strength", 1.0}
  3220. };
  3221. break;
  3222. case ModelType::CONTROLNET:
  3223. capabilities = {
  3224. {"text2img", false},
  3225. {"img2img", true},
  3226. {"inpainting", true},
  3227. {"controlnet", true},
  3228. {"requires_checkpoint", true},
  3229. {"control_modes", {"canny", "depth", "pose", "scribble", "hed", "mlsd", "normal", "seg"}},
  3230. {"strength_range", {0.0, 1.0}},
  3231. {"recommended_strength", 0.9}
  3232. };
  3233. break;
  3234. case ModelType::VAE:
  3235. capabilities = {
  3236. {"text2img", false},
  3237. {"img2img", false},
  3238. {"inpainting", false},
  3239. {"vae", true},
  3240. {"requires_checkpoint", true},
  3241. {"encoding", true},
  3242. {"decoding", true},
  3243. {"precision", {"fp16", "fp32"}}
  3244. };
  3245. break;
  3246. case ModelType::EMBEDDING:
  3247. capabilities = {
  3248. {"text2img", true},
  3249. {"img2img", true},
  3250. {"inpainting", true},
  3251. {"embedding", true},
  3252. {"requires_checkpoint", true},
  3253. {"token_count", 1},
  3254. {"compatible_with", {"checkpoint", "lora"}}
  3255. };
  3256. break;
  3257. case ModelType::TAESD:
  3258. capabilities = {
  3259. {"text2img", false},
  3260. {"img2img", false},
  3261. {"inpainting", false},
  3262. {"vae", true},
  3263. {"requires_checkpoint", true},
  3264. {"fast_decoding", true},
  3265. {"real_time", true},
  3266. {"precision", {"fp16", "fp32"}}
  3267. };
  3268. break;
  3269. case ModelType::ESRGAN:
  3270. capabilities = {
  3271. {"text2img", false},
  3272. {"img2img", false},
  3273. {"inpainting", false},
  3274. {"upscaling", true},
  3275. {"scale_factors", {2, 4}},
  3276. {"models", {"ESRGAN", "RealESRGAN", "SwinIR"}},
  3277. {"supports_alpha", false}
  3278. };
  3279. break;
  3280. default:
  3281. capabilities = {
  3282. {"text2img", false},
  3283. {"img2img", false},
  3284. {"inpainting", false},
  3285. {"capabilities", {}}
  3286. };
  3287. break;
  3288. }
  3289. return capabilities;
  3290. }
  3291. nlohmann::json Server::getModelTypeStatistics() {
  3292. if (!m_modelManager) return nlohmann::json::object();
  3293. nlohmann::json stats = nlohmann::json::object();
  3294. auto allModels = m_modelManager->getAllModels();
  3295. // Initialize counters for each type
  3296. std::map<ModelType, int> typeCounts;
  3297. std::map<ModelType, int> loadedCounts;
  3298. std::map<ModelType, size_t> sizeByType;
  3299. for (const auto& pair : allModels) {
  3300. ModelType type = pair.second.type;
  3301. typeCounts[type]++;
  3302. if (pair.second.isLoaded) {
  3303. loadedCounts[type]++;
  3304. }
  3305. sizeByType[type] += pair.second.fileSize;
  3306. }
  3307. // Build statistics JSON
  3308. for (const auto& count : typeCounts) {
  3309. std::string typeName = ModelManager::modelTypeToString(count.first);
  3310. stats[typeName] = {
  3311. {"total_count", count.second},
  3312. {"loaded_count", loadedCounts[count.first]},
  3313. {"total_size_bytes", sizeByType[count.first]},
  3314. {"total_size_mb", sizeByType[count.first] / (1024.0 * 1024.0)},
  3315. {"average_size_mb", count.second > 0 ? (sizeByType[count.first] / (1024.0 * 1024.0)) / count.second : 0.0}
  3316. };
  3317. }
  3318. return stats;
  3319. }
  3320. // Additional helper methods for model management
  3321. nlohmann::json Server::getModelCompatibility(const ModelManager::ModelInfo& modelInfo) {
  3322. nlohmann::json compatibility = {
  3323. {"is_compatible", true},
  3324. {"compatibility_score", 100},
  3325. {"issues", nlohmann::json::array()},
  3326. {"warnings", nlohmann::json::array()},
  3327. {"requirements", {
  3328. {"min_memory_mb", 1024},
  3329. {"recommended_memory_mb", 2048},
  3330. {"supported_formats", {"safetensors", "ckpt", "gguf"}},
  3331. {"required_dependencies", {}}
  3332. }}
  3333. };
  3334. // Check for specific compatibility issues based on model type
  3335. if (modelInfo.type == ModelType::LORA) {
  3336. compatibility["requirements"]["required_dependencies"] = {"checkpoint"};
  3337. } else if (modelInfo.type == ModelType::CONTROLNET) {
  3338. compatibility["requirements"]["required_dependencies"] = {"checkpoint"};
  3339. } else if (modelInfo.type == ModelType::VAE) {
  3340. compatibility["requirements"]["required_dependencies"] = {"checkpoint"};
  3341. }
  3342. return compatibility;
  3343. }
  3344. nlohmann::json Server::getModelRequirements(ModelType type) {
  3345. nlohmann::json requirements = {
  3346. {"min_memory_mb", 1024},
  3347. {"recommended_memory_mb", 2048},
  3348. {"min_disk_space_mb", 1024},
  3349. {"supported_formats", {"safetensors", "ckpt", "gguf"}},
  3350. {"required_dependencies", nlohmann::json::array()},
  3351. {"optional_dependencies", nlohmann::json::array()},
  3352. {"system_requirements", {
  3353. {"cpu_cores", 4},
  3354. {"cpu_architecture", "x86_64"},
  3355. {"os", "Linux/Windows/macOS"},
  3356. {"gpu_memory_mb", 2048},
  3357. {"gpu_compute_capability", "3.5+"}
  3358. }}
  3359. };
  3360. switch (type) {
  3361. case ModelType::CHECKPOINT:
  3362. requirements["min_memory_mb"] = 2048;
  3363. requirements["recommended_memory_mb"] = 4096;
  3364. requirements["min_disk_space_mb"] = 2048;
  3365. requirements["supported_formats"] = {"safetensors", "ckpt", "gguf"};
  3366. break;
  3367. case ModelType::LORA:
  3368. requirements["min_memory_mb"] = 512;
  3369. requirements["recommended_memory_mb"] = 1024;
  3370. requirements["min_disk_space_mb"] = 100;
  3371. requirements["supported_formats"] = {"safetensors", "ckpt"};
  3372. requirements["required_dependencies"] = {"checkpoint"};
  3373. break;
  3374. case ModelType::CONTROLNET:
  3375. requirements["min_memory_mb"] = 1024;
  3376. requirements["recommended_memory_mb"] = 2048;
  3377. requirements["min_disk_space_mb"] = 500;
  3378. requirements["supported_formats"] = {"safetensors", "pth"};
  3379. requirements["required_dependencies"] = {"checkpoint"};
  3380. break;
  3381. case ModelType::VAE:
  3382. requirements["min_memory_mb"] = 512;
  3383. requirements["recommended_memory_mb"] = 1024;
  3384. requirements["min_disk_space_mb"] = 200;
  3385. requirements["supported_formats"] = {"safetensors", "pt", "ckpt", "gguf"};
  3386. requirements["required_dependencies"] = {"checkpoint"};
  3387. break;
  3388. case ModelType::EMBEDDING:
  3389. requirements["min_memory_mb"] = 64;
  3390. requirements["recommended_memory_mb"] = 256;
  3391. requirements["min_disk_space_mb"] = 10;
  3392. requirements["supported_formats"] = {"safetensors", "pt"};
  3393. requirements["required_dependencies"] = {"checkpoint"};
  3394. break;
  3395. case ModelType::TAESD:
  3396. requirements["min_memory_mb"] = 256;
  3397. requirements["recommended_memory_mb"] = 512;
  3398. requirements["min_disk_space_mb"] = 100;
  3399. requirements["supported_formats"] = {"safetensors", "pth", "gguf"};
  3400. requirements["required_dependencies"] = {"checkpoint"};
  3401. break;
  3402. case ModelType::ESRGAN:
  3403. requirements["min_memory_mb"] = 1024;
  3404. requirements["recommended_memory_mb"] = 2048;
  3405. requirements["min_disk_space_mb"] = 500;
  3406. requirements["supported_formats"] = {"pth", "pt"};
  3407. requirements["optional_dependencies"] = {"checkpoint"};
  3408. break;
  3409. default:
  3410. break;
  3411. }
  3412. return requirements;
  3413. }
  3414. nlohmann::json Server::getRecommendedUsage(ModelType type) {
  3415. nlohmann::json usage = {
  3416. {"text2img", false},
  3417. {"img2img", false},
  3418. {"inpainting", false},
  3419. {"controlnet", false},
  3420. {"lora", false},
  3421. {"vae", false},
  3422. {"recommended_resolution", "512x512"},
  3423. {"recommended_steps", 20},
  3424. {"recommended_cfg_scale", 7.5},
  3425. {"recommended_batch_size", 1}
  3426. };
  3427. switch (type) {
  3428. case ModelType::CHECKPOINT:
  3429. usage = {
  3430. {"text2img", true},
  3431. {"img2img", true},
  3432. {"inpainting", true},
  3433. {"controlnet", true},
  3434. {"lora", true},
  3435. {"vae", true},
  3436. {"recommended_resolution", "512x512"},
  3437. {"recommended_steps", 20},
  3438. {"recommended_cfg_scale", 7.5},
  3439. {"recommended_batch_size", 1}
  3440. };
  3441. break;
  3442. case ModelType::LORA:
  3443. usage = {
  3444. {"text2img", true},
  3445. {"img2img", true},
  3446. {"inpainting", true},
  3447. {"controlnet", false},
  3448. {"lora", true},
  3449. {"vae", false},
  3450. {"recommended_strength", 1.0},
  3451. {"recommended_usage", "Style transfer, character customization"}
  3452. };
  3453. break;
  3454. case ModelType::CONTROLNET:
  3455. usage = {
  3456. {"text2img", false},
  3457. {"img2img", true},
  3458. {"inpainting", true},
  3459. {"controlnet", true},
  3460. {"lora", false},
  3461. {"vae", false},
  3462. {"recommended_strength", 0.9},
  3463. {"recommended_usage", "Precise control over output"}
  3464. };
  3465. break;
  3466. case ModelType::VAE:
  3467. usage = {
  3468. {"text2img", false},
  3469. {"img2img", false},
  3470. {"inpainting", false},
  3471. {"controlnet", false},
  3472. {"lora", false},
  3473. {"vae", true},
  3474. {"recommended_usage", "Improved encoding/decoding quality"}
  3475. };
  3476. break;
  3477. case ModelType::EMBEDDING:
  3478. usage = {
  3479. {"text2img", true},
  3480. {"img2img", true},
  3481. {"inpainting", true},
  3482. {"controlnet", false},
  3483. {"lora", false},
  3484. {"vae", false},
  3485. {"embedding", true},
  3486. {"recommended_usage", "Concept control, style words"}
  3487. };
  3488. break;
  3489. case ModelType::TAESD:
  3490. usage = {
  3491. {"text2img", false},
  3492. {"img2img", false},
  3493. {"inpainting", false},
  3494. {"controlnet", false},
  3495. {"lora", false},
  3496. {"vae", true},
  3497. {"recommended_usage", "Real-time decoding"}
  3498. };
  3499. break;
  3500. case ModelType::ESRGAN:
  3501. usage = {
  3502. {"text2img", false},
  3503. {"img2img", false},
  3504. {"inpainting", false},
  3505. {"controlnet", false},
  3506. {"lora", false},
  3507. {"vae", false},
  3508. {"upscaling", true},
  3509. {"recommended_usage", "Image upscaling and quality enhancement"}
  3510. };
  3511. break;
  3512. default:
  3513. break;
  3514. }
  3515. return usage;
  3516. }
  3517. std::string Server::getModelTypeFromDirectoryName(const std::string& dirName) {
  3518. if (dirName == "stable-diffusion" || dirName == "checkpoints") {
  3519. return "checkpoint";
  3520. } else if (dirName == "lora") {
  3521. return "lora";
  3522. } else if (dirName == "controlnet") {
  3523. return "controlnet";
  3524. } else if (dirName == "vae") {
  3525. return "vae";
  3526. } else if (dirName == "taesd") {
  3527. return "taesd";
  3528. } else if (dirName == "esrgan" || dirName == "upscaler") {
  3529. return "esrgan";
  3530. } else if (dirName == "embeddings" || dirName == "textual-inversion") {
  3531. return "embedding";
  3532. } else {
  3533. return "unknown";
  3534. }
  3535. }
  3536. std::string Server::getDirectoryDescription(const std::string& dirName) {
  3537. if (dirName == "stable-diffusion" || dirName == "checkpoints") {
  3538. return "Main stable diffusion model files";
  3539. } else if (dirName == "lora") {
  3540. return "LoRA adapter models for style transfer";
  3541. } else if (dirName == "controlnet") {
  3542. return "ControlNet models for precise control";
  3543. } else if (dirName == "vae") {
  3544. return "VAE models for improved encoding/decoding";
  3545. } else if (dirName == "taesd") {
  3546. return "TAESD models for real-time decoding";
  3547. } else if (dirName == "esrgan" || dirName == "upscaler") {
  3548. return "ESRGAN models for image upscaling";
  3549. } else if (dirName == "embeddings" || dirName == "textual-inversion") {
  3550. return "Text embeddings for concept control";
  3551. } else {
  3552. return "Unknown model directory";
  3553. }
  3554. }
  3555. nlohmann::json Server::getDirectoryContents(const std::string& dirPath) {
  3556. nlohmann::json contents = nlohmann::json::array();
  3557. try {
  3558. if (std::filesystem::exists(dirPath) && std::filesystem::is_directory(dirPath)) {
  3559. for (const auto& entry : std::filesystem::directory_iterator(dirPath)) {
  3560. if (entry.is_regular_file()) {
  3561. nlohmann::json file = {
  3562. {"name", entry.path().filename().string()},
  3563. {"path", entry.path().string()},
  3564. {"size", std::filesystem::file_size(entry.path())},
  3565. {"size_mb", std::filesystem::file_size(entry.path()) / (1024.0 * 1024.0)},
  3566. {"last_modified", std::chrono::duration_cast<std::chrono::seconds>(
  3567. std::filesystem::last_write_time(entry.path()).time_since_epoch()).count()}
  3568. };
  3569. contents.push_back(file);
  3570. }
  3571. }
  3572. }
  3573. } catch (const std::exception& e) {
  3574. // Return empty array if directory access fails
  3575. }
  3576. return contents;
  3577. }
  3578. nlohmann::json Server::getLargestModel(const std::map<std::string, ModelManager::ModelInfo>& allModels) {
  3579. nlohmann::json largest = nlohmann::json::object();
  3580. size_t maxSize = 0;
  3581. std::string largestName;
  3582. for (const auto& pair : allModels) {
  3583. if (pair.second.fileSize > maxSize) {
  3584. maxSize = pair.second.fileSize;
  3585. largestName = pair.second.name;
  3586. }
  3587. }
  3588. if (!largestName.empty()) {
  3589. largest = {
  3590. {"name", largestName},
  3591. {"size", maxSize},
  3592. {"size_mb", maxSize / (1024.0 * 1024.0)},
  3593. {"type", ModelManager::modelTypeToString(allModels.at(largestName).type)}
  3594. };
  3595. }
  3596. return largest;
  3597. }
  3598. nlohmann::json Server::getSmallestModel(const std::map<std::string, ModelManager::ModelInfo>& allModels) {
  3599. nlohmann::json smallest = nlohmann::json::object();
  3600. size_t minSize = SIZE_MAX;
  3601. std::string smallestName;
  3602. for (const auto& pair : allModels) {
  3603. if (pair.second.fileSize < minSize) {
  3604. minSize = pair.second.fileSize;
  3605. smallestName = pair.second.name;
  3606. }
  3607. }
  3608. if (!smallestName.empty()) {
  3609. smallest = {
  3610. {"name", smallestName},
  3611. {"size", minSize},
  3612. {"size_mb", minSize / (1024.0 * 1024.0)},
  3613. {"type", ModelManager::modelTypeToString(allModels.at(smallestName).type)}
  3614. };
  3615. }
  3616. return smallest;
  3617. }
  3618. nlohmann::json Server::validateModelFile(const std::string& modelPath, const std::string& modelType) {
  3619. nlohmann::json validation = {
  3620. {"is_valid", false},
  3621. {"errors", nlohmann::json::array()},
  3622. {"warnings", nlohmann::json::array()},
  3623. {"file_info", nlohmann::json::object()},
  3624. {"compatibility", nlohmann::json::object()},
  3625. {"recommendations", nlohmann::json::array()}
  3626. };
  3627. try {
  3628. if (!std::filesystem::exists(modelPath)) {
  3629. validation["errors"].push_back("File does not exist");
  3630. return validation;
  3631. }
  3632. if (!std::filesystem::is_regular_file(modelPath)) {
  3633. validation["errors"].push_back("Path is not a regular file");
  3634. return validation;
  3635. }
  3636. // Check file extension
  3637. std::string extension = std::filesystem::path(modelPath).extension().string();
  3638. if (extension.empty()) {
  3639. validation["errors"].push_back("Missing file extension");
  3640. return validation;
  3641. }
  3642. // Remove dot and convert to lowercase
  3643. if (extension[0] == '.') {
  3644. extension = extension.substr(1);
  3645. }
  3646. std::transform(extension.begin(), extension.end(), extension.begin(), ::tolower);
  3647. // Validate extension based on model type
  3648. ModelType type = ModelManager::stringToModelType(modelType);
  3649. bool validExtension = false;
  3650. switch (type) {
  3651. case ModelType::CHECKPOINT:
  3652. validExtension = (extension == "safetensors" || extension == "ckpt" || extension == "gguf");
  3653. break;
  3654. case ModelType::LORA:
  3655. validExtension = (extension == "safetensors" || extension == "ckpt");
  3656. break;
  3657. case ModelType::CONTROLNET:
  3658. validExtension = (extension == "safetensors" || extension == "pth");
  3659. break;
  3660. case ModelType::VAE:
  3661. validExtension = (extension == "safetensors" || extension == "pt" || extension == "ckpt" || extension == "gguf");
  3662. break;
  3663. case ModelType::EMBEDDING:
  3664. validExtension = (extension == "safetensors" || extension == "pt");
  3665. break;
  3666. case ModelType::TAESD:
  3667. validExtension = (extension == "safetensors" || extension == "pth" || extension == "gguf");
  3668. break;
  3669. case ModelType::ESRGAN:
  3670. validExtension = (extension == "pth" || extension == "pt");
  3671. break;
  3672. default:
  3673. break;
  3674. }
  3675. if (!validExtension) {
  3676. validation["errors"].push_back("Invalid file extension for model type: " + extension);
  3677. }
  3678. // Check file size
  3679. size_t fileSize = std::filesystem::file_size(modelPath);
  3680. if (fileSize == 0) {
  3681. validation["errors"].push_back("File is empty");
  3682. } else if (fileSize > 8ULL * 1024 * 1024 * 1024) { // 8GB
  3683. validation["warnings"].push_back("Very large file may cause performance issues");
  3684. }
  3685. // Build file info
  3686. validation["file_info"] = {
  3687. {"path", modelPath},
  3688. {"size", fileSize},
  3689. {"size_mb", fileSize / (1024.0 * 1024.0)},
  3690. {"extension", extension},
  3691. {"last_modified", std::chrono::duration_cast<std::chrono::seconds>(
  3692. std::filesystem::last_write_time(modelPath).time_since_epoch()).count()}
  3693. };
  3694. // Check compatibility
  3695. validation["compatibility"] = {
  3696. {"extension_valid", validExtension},
  3697. {"size_appropriate", fileSize <= 4ULL * 1024 * 1024 * 1024}, // 4GB
  3698. {"recommended_format", "safetensors"}
  3699. };
  3700. // Add recommendations
  3701. if (!validExtension) {
  3702. validation["recommendations"].push_back("Convert to SafeTensors format for better security and performance");
  3703. }
  3704. if (fileSize > 2ULL * 1024 * 1024 * 1024) { // 2GB
  3705. validation["recommendations"].push_back("Consider using a smaller model for better performance");
  3706. }
  3707. // If no errors found, mark as valid
  3708. if (validation["errors"].empty()) {
  3709. validation["is_valid"] = true;
  3710. }
  3711. } catch (const std::exception& e) {
  3712. validation["errors"].push_back("Validation failed: " + std::string(e.what()));
  3713. }
  3714. return validation;
  3715. }
  3716. nlohmann::json Server::checkModelCompatibility(const ModelManager::ModelInfo& modelInfo, const std::string& systemInfo) {
  3717. nlohmann::json compatibility = {
  3718. {"is_compatible", true},
  3719. {"compatibility_score", 100},
  3720. {"issues", nlohmann::json::array()},
  3721. {"warnings", nlohmann::json::array()},
  3722. {"requirements", nlohmann::json::object()},
  3723. {"recommendations", nlohmann::json::array()},
  3724. {"system_info", nlohmann::json::object()}
  3725. };
  3726. // Check system compatibility
  3727. if (systemInfo == "auto") {
  3728. compatibility["system_info"] = {
  3729. {"cpu_cores", std::thread::hardware_concurrency()}
  3730. };
  3731. }
  3732. // Check model-specific compatibility issues
  3733. if (modelInfo.type == ModelType::CHECKPOINT) {
  3734. if (modelInfo.fileSize > 4ULL * 1024 * 1024 * 1024) { // 4GB
  3735. compatibility["warnings"].push_back("Large checkpoint model may require significant memory");
  3736. compatibility["compatibility_score"] = 80;
  3737. }
  3738. if (modelInfo.fileSize < 500 * 1024 * 1024) { // 500MB
  3739. compatibility["warnings"].push_back("Small checkpoint model may have limited capabilities");
  3740. compatibility["compatibility_score"] = 85;
  3741. }
  3742. } else if (modelInfo.type == ModelType::LORA) {
  3743. if (modelInfo.fileSize > 500 * 1024 * 1024) { // 500MB
  3744. compatibility["warnings"].push_back("Large LoRA may impact performance");
  3745. compatibility["compatibility_score"] = 75;
  3746. }
  3747. }
  3748. return compatibility;
  3749. }
  3750. nlohmann::json Server::calculateSpecificRequirements(const std::string& modelType, const std::string& resolution, const std::string& batchSize) {
  3751. (void)modelType; // Suppress unused parameter warning
  3752. nlohmann::json specific = {
  3753. {"memory_requirements", nlohmann::json::object()},
  3754. {"performance_impact", nlohmann::json::object()},
  3755. {"quality_expectations", nlohmann::json::object()}
  3756. };
  3757. // Parse resolution
  3758. int width = 512, height = 512;
  3759. try {
  3760. size_t xPos = resolution.find('x');
  3761. if (xPos != std::string::npos) {
  3762. width = std::stoi(resolution.substr(0, xPos));
  3763. height = std::stoi(resolution.substr(xPos + 1));
  3764. }
  3765. } catch (...) {
  3766. // Use defaults if parsing fails
  3767. }
  3768. // Parse batch size
  3769. int batch = 1;
  3770. try {
  3771. batch = std::stoi(batchSize);
  3772. } catch (...) {
  3773. // Use default if parsing fails
  3774. }
  3775. // Calculate memory requirements based on resolution and batch
  3776. size_t pixels = width * height;
  3777. size_t baseMemory = 1024 * 1024 * 1024; // 1GB base
  3778. size_t resolutionMemory = (pixels * 4) / (512 * 512); // Scale based on 512x512
  3779. size_t batchMemory = (batch - 1) * baseMemory * 0.5; // Additional memory for batch
  3780. specific["memory_requirements"] = {
  3781. {"base_memory_mb", baseMemory / (1024 * 1024)},
  3782. {"resolution_memory_mb", resolutionMemory / (1024 * 1024)},
  3783. {"batch_memory_mb", batchMemory / (1024 * 1024)},
  3784. {"total_memory_mb", (baseMemory + resolutionMemory + batchMemory) / (1024 * 1024)}
  3785. };
  3786. // Calculate performance impact
  3787. double performanceFactor = 1.0;
  3788. if (pixels > 512 * 512) {
  3789. performanceFactor = 1.5;
  3790. }
  3791. if (batch > 1) {
  3792. performanceFactor *= 1.2;
  3793. }
  3794. specific["performance_impact"] = {
  3795. {"resolution_factor", pixels > 512 * 512 ? 1.5 : 1.0},
  3796. {"batch_factor", batch > 1 ? 1.2 : 1.0},
  3797. {"overall_factor", performanceFactor}
  3798. };
  3799. return specific;
  3800. }
  3801. // Enhanced model management endpoint implementations
  3802. void Server::handleModelInfo(const httplib::Request& req, httplib::Response& res) {
  3803. std::string requestId = generateRequestId();
  3804. try {
  3805. if (!m_modelManager) {
  3806. sendErrorResponse(res, "Model manager not available", 500, "MODEL_MANAGER_UNAVAILABLE", requestId);
  3807. return;
  3808. }
  3809. // Extract model ID from URL path
  3810. std::string modelId = req.matches[1].str();
  3811. if (modelId.empty()) {
  3812. sendErrorResponse(res, "Missing model ID", 400, "MISSING_MODEL_ID", requestId);
  3813. return;
  3814. }
  3815. // Get model information
  3816. auto modelInfo = m_modelManager->getModelInfo(modelId);
  3817. if (modelInfo.name.empty()) {
  3818. sendErrorResponse(res, "Model not found", 404, "MODEL_NOT_FOUND", requestId);
  3819. return;
  3820. }
  3821. // Build comprehensive model information
  3822. nlohmann::json response = {
  3823. {"model", {
  3824. {"name", modelInfo.name},
  3825. {"path", modelInfo.path},
  3826. {"type", ModelManager::modelTypeToString(modelInfo.type)},
  3827. {"is_loaded", modelInfo.isLoaded},
  3828. {"file_size", modelInfo.fileSize},
  3829. {"file_size_mb", modelInfo.fileSize / (1024.0 * 1024.0)},
  3830. {"description", modelInfo.description},
  3831. {"metadata", modelInfo.metadata},
  3832. {"capabilities", getModelCapabilities(modelInfo.type)},
  3833. {"compatibility", getModelCompatibility(modelInfo)},
  3834. {"requirements", getModelRequirements(modelInfo.type)},
  3835. {"recommended_usage", getRecommendedUsage(modelInfo.type)},
  3836. {"last_modified", std::chrono::duration_cast<std::chrono::seconds>(
  3837. modelInfo.modifiedAt.time_since_epoch()).count()}
  3838. }},
  3839. {"request_id", requestId}
  3840. };
  3841. sendJsonResponse(res, response);
  3842. } catch (const std::exception& e) {
  3843. sendErrorResponse(res, std::string("Failed to get model info: ") + e.what(), 500, "MODEL_INFO_ERROR", requestId);
  3844. }
  3845. }
  3846. void Server::handleLoadModelById(const httplib::Request& req, httplib::Response& res) {
  3847. std::string requestId = generateRequestId();
  3848. try {
  3849. if (!m_modelManager) {
  3850. sendErrorResponse(res, "Model manager not available", 500, "MODEL_MANAGER_UNAVAILABLE", requestId);
  3851. return;
  3852. }
  3853. // Extract model ID from URL path (could be hash or name)
  3854. std::string modelIdentifier = req.matches[1].str();
  3855. if (modelIdentifier.empty()) {
  3856. sendErrorResponse(res, "Missing model identifier", 400, "MISSING_MODEL_ID", requestId);
  3857. return;
  3858. }
  3859. // Try to find by hash first (if it looks like a hash - 10+ hex chars)
  3860. std::string modelId = modelIdentifier;
  3861. if (modelIdentifier.length() >= 10 &&
  3862. std::all_of(modelIdentifier.begin(), modelIdentifier.end(),
  3863. [](char c) { return std::isxdigit(c); })) {
  3864. std::string foundName = m_modelManager->findModelByHash(modelIdentifier);
  3865. if (!foundName.empty()) {
  3866. modelId = foundName;
  3867. std::cout << "Resolved hash " << modelIdentifier << " to model: " << modelId << std::endl;
  3868. }
  3869. }
  3870. // Parse optional parameters from request body
  3871. nlohmann::json requestJson;
  3872. if (!req.body.empty()) {
  3873. try {
  3874. requestJson = nlohmann::json::parse(req.body);
  3875. } catch (const nlohmann::json::parse_error& e) {
  3876. sendErrorResponse(res, std::string("Invalid JSON: ") + e.what(), 400, "JSON_PARSE_ERROR", requestId);
  3877. return;
  3878. }
  3879. }
  3880. // Unload previous model if one is loaded
  3881. std::string previousModel;
  3882. {
  3883. std::lock_guard<std::mutex> lock(m_currentModelMutex);
  3884. previousModel = m_currentlyLoadedModel;
  3885. }
  3886. if (!previousModel.empty() && previousModel != modelId) {
  3887. std::cout << "Unloading previous model: " << previousModel << std::endl;
  3888. m_modelManager->unloadModel(previousModel);
  3889. }
  3890. // Load model
  3891. bool success = m_modelManager->loadModel(modelId);
  3892. if (success) {
  3893. // Update currently loaded model
  3894. {
  3895. std::lock_guard<std::mutex> lock(m_currentModelMutex);
  3896. m_currentlyLoadedModel = modelId;
  3897. }
  3898. auto modelInfo = m_modelManager->getModelInfo(modelId);
  3899. nlohmann::json response = {
  3900. {"status", "success"},
  3901. {"model", {
  3902. {"name", modelInfo.name},
  3903. {"path", modelInfo.path},
  3904. {"type", ModelManager::modelTypeToString(modelInfo.type)},
  3905. {"is_loaded", modelInfo.isLoaded}
  3906. }},
  3907. {"request_id", requestId}
  3908. };
  3909. sendJsonResponse(res, response);
  3910. } else {
  3911. sendErrorResponse(res, "Failed to load model", 400, "MODEL_LOAD_FAILED", requestId);
  3912. }
  3913. } catch (const std::exception& e) {
  3914. sendErrorResponse(res, std::string("Model load failed: ") + e.what(), 500, "MODEL_LOAD_ERROR", requestId);
  3915. }
  3916. }
  3917. void Server::handleUnloadModelById(const httplib::Request& req, httplib::Response& res) {
  3918. std::string requestId = generateRequestId();
  3919. try {
  3920. if (!m_modelManager) {
  3921. sendErrorResponse(res, "Model manager not available", 500, "MODEL_MANAGER_UNAVAILABLE", requestId);
  3922. return;
  3923. }
  3924. // Extract model ID from URL path
  3925. std::string modelId = req.matches[1].str();
  3926. if (modelId.empty()) {
  3927. sendErrorResponse(res, "Missing model ID", 400, "MISSING_MODEL_ID", requestId);
  3928. return;
  3929. }
  3930. // Unload model
  3931. bool success = m_modelManager->unloadModel(modelId);
  3932. if (success) {
  3933. // Clear currently loaded model if it matches
  3934. {
  3935. std::lock_guard<std::mutex> lock(m_currentModelMutex);
  3936. if (m_currentlyLoadedModel == modelId) {
  3937. m_currentlyLoadedModel = "";
  3938. }
  3939. }
  3940. nlohmann::json response = {
  3941. {"status", "success"},
  3942. {"model", {
  3943. {"name", modelId},
  3944. {"is_loaded", false}
  3945. }},
  3946. {"request_id", requestId}
  3947. };
  3948. sendJsonResponse(res, response);
  3949. } else {
  3950. sendErrorResponse(res, "Failed to unload model or model not found", 404, "MODEL_UNLOAD_FAILED", requestId);
  3951. }
  3952. } catch (const std::exception& e) {
  3953. sendErrorResponse(res, std::string("Model unload failed: ") + e.what(), 500, "MODEL_UNLOAD_ERROR", requestId);
  3954. }
  3955. }
  3956. void Server::handleModelTypes(const httplib::Request& /*req*/, httplib::Response& res) {
  3957. std::string requestId = generateRequestId();
  3958. try {
  3959. nlohmann::json types = {
  3960. {"model_types", {
  3961. {
  3962. {"type", "checkpoint"},
  3963. {"description", "Main stable diffusion model files for text-to-image, image-to-image, and inpainting"},
  3964. {"extensions", {"safetensors", "ckpt", "gguf"}},
  3965. {"capabilities", {"text2img", "img2img", "inpainting", "controlnet", "lora", "vae"}},
  3966. {"recommended_for", "General purpose image generation"}
  3967. },
  3968. {
  3969. {"type", "lora"},
  3970. {"description", "LoRA adapter models for style transfer and character customization"},
  3971. {"extensions", {"safetensors", "ckpt"}},
  3972. {"capabilities", {"style_transfer", "character_customization"}},
  3973. {"requires", {"checkpoint"}},
  3974. {"recommended_for", "Style modification and character-specific generation"}
  3975. },
  3976. {
  3977. {"type", "controlnet"},
  3978. {"description", "ControlNet models for precise control over output composition"},
  3979. {"extensions", {"safetensors", "pth"}},
  3980. {"capabilities", {"precise_control", "composition_control"}},
  3981. {"requires", {"checkpoint"}},
  3982. {"recommended_for", "Precise control over image generation"}
  3983. },
  3984. {
  3985. {"type", "vae"},
  3986. {"description", "VAE models for improved encoding and decoding quality"},
  3987. {"extensions", {"safetensors", "pt", "ckpt", "gguf"}},
  3988. {"capabilities", {"encoding", "decoding", "quality_improvement"}},
  3989. {"requires", {"checkpoint"}},
  3990. {"recommended_for", "Improved image quality and encoding"}
  3991. },
  3992. {
  3993. {"type", "embedding"},
  3994. {"description", "Text embeddings for concept control and style words"},
  3995. {"extensions", {"safetensors", "pt"}},
  3996. {"capabilities", {"concept_control", "style_words"}},
  3997. {"requires", {"checkpoint"}},
  3998. {"recommended_for", "Concept control and specific styles"}
  3999. },
  4000. {
  4001. {"type", "taesd"},
  4002. {"description", "TAESD models for real-time decoding"},
  4003. {"extensions", {"safetensors", "pth", "gguf"}},
  4004. {"capabilities", {"real_time_decoding", "fast_preview"}},
  4005. {"requires", {"checkpoint"}},
  4006. {"recommended_for", "Real-time applications and fast previews"}
  4007. },
  4008. {
  4009. {"type", "esrgan"},
  4010. {"description", "ESRGAN models for image upscaling and enhancement"},
  4011. {"extensions", {"pth", "pt"}},
  4012. {"capabilities", {"upscaling", "enhancement", "quality_improvement"}},
  4013. {"recommended_for", "Image upscaling and quality enhancement"}
  4014. }
  4015. }},
  4016. {"request_id", requestId}
  4017. };
  4018. sendJsonResponse(res, types);
  4019. } catch (const std::exception& e) {
  4020. sendErrorResponse(res, std::string("Failed to get model types: ") + e.what(), 500, "MODEL_TYPES_ERROR", requestId);
  4021. }
  4022. }
  4023. void Server::handleModelDirectories(const httplib::Request& /*req*/, httplib::Response& res) {
  4024. std::string requestId = generateRequestId();
  4025. try {
  4026. if (!m_modelManager) {
  4027. sendErrorResponse(res, "Model manager not available", 500, "MODEL_MANAGER_UNAVAILABLE", requestId);
  4028. return;
  4029. }
  4030. std::string modelsDir = m_modelManager->getModelsDirectory();
  4031. nlohmann::json directories = nlohmann::json::array();
  4032. // Define expected model directories
  4033. std::vector<std::string> modelDirs = {
  4034. "stable-diffusion", "checkpoints", "lora", "controlnet",
  4035. "vae", "taesd", "esrgan", "embeddings"
  4036. };
  4037. for (const auto& dirName : modelDirs) {
  4038. std::string dirPath = modelsDir + "/" + dirName;
  4039. std::string type = getModelTypeFromDirectoryName(dirName);
  4040. std::string description = getDirectoryDescription(dirName);
  4041. nlohmann::json dirInfo = {
  4042. {"name", dirName},
  4043. {"path", dirPath},
  4044. {"type", type},
  4045. {"description", description},
  4046. {"exists", std::filesystem::exists(dirPath) && std::filesystem::is_directory(dirPath)},
  4047. {"contents", getDirectoryContents(dirPath)}
  4048. };
  4049. directories.push_back(dirInfo);
  4050. }
  4051. nlohmann::json response = {
  4052. {"models_directory", modelsDir},
  4053. {"directories", directories},
  4054. {"request_id", requestId}
  4055. };
  4056. sendJsonResponse(res, response);
  4057. } catch (const std::exception& e) {
  4058. sendErrorResponse(res, std::string("Failed to get model directories: ") + e.what(), 500, "MODEL_DIRECTORIES_ERROR", requestId);
  4059. }
  4060. }
  4061. void Server::handleRefreshModels(const httplib::Request& /*req*/, httplib::Response& res) {
  4062. std::string requestId = generateRequestId();
  4063. try {
  4064. if (!m_modelManager) {
  4065. sendErrorResponse(res, "Model manager not available", 500, "MODEL_MANAGER_UNAVAILABLE", requestId);
  4066. return;
  4067. }
  4068. // Force refresh of model cache
  4069. bool success = m_modelManager->scanModelsDirectory();
  4070. if (success) {
  4071. nlohmann::json response = {
  4072. {"status", "success"},
  4073. {"message", "Model cache refreshed successfully"},
  4074. {"models_found", m_modelManager->getAvailableModelsCount()},
  4075. {"models_loaded", m_modelManager->getLoadedModelsCount()},
  4076. {"models_directory", m_modelManager->getModelsDirectory()},
  4077. {"request_id", requestId}
  4078. };
  4079. sendJsonResponse(res, response);
  4080. } else {
  4081. sendErrorResponse(res, "Failed to refresh model cache", 500, "MODEL_REFRESH_FAILED", requestId);
  4082. }
  4083. } catch (const std::exception& e) {
  4084. sendErrorResponse(res, std::string("Model refresh failed: ") + e.what(), 500, "MODEL_REFRESH_ERROR", requestId);
  4085. }
  4086. }
  4087. void Server::handleHashModels(const httplib::Request& req, httplib::Response& res) {
  4088. std::string requestId = generateRequestId();
  4089. try {
  4090. if (!m_generationQueue || !m_modelManager) {
  4091. sendErrorResponse(res, "Services not available", 500, "SERVICE_UNAVAILABLE", requestId);
  4092. return;
  4093. }
  4094. // Parse request body
  4095. nlohmann::json requestJson;
  4096. if (!req.body.empty()) {
  4097. requestJson = nlohmann::json::parse(req.body);
  4098. }
  4099. HashRequest hashReq;
  4100. hashReq.id = requestId;
  4101. hashReq.forceRehash = requestJson.value("force_rehash", false);
  4102. if (requestJson.contains("models") && requestJson["models"].is_array()) {
  4103. for (const auto& model : requestJson["models"]) {
  4104. hashReq.modelNames.push_back(model.get<std::string>());
  4105. }
  4106. }
  4107. // Enqueue hash request
  4108. auto future = m_generationQueue->enqueueHashRequest(hashReq);
  4109. nlohmann::json response = {
  4110. {"request_id", requestId},
  4111. {"status", "queued"},
  4112. {"message", "Hash job queued successfully"},
  4113. {"models_to_hash", hashReq.modelNames.empty() ? "all_unhashed" : std::to_string(hashReq.modelNames.size())}
  4114. };
  4115. sendJsonResponse(res, response, 202);
  4116. } catch (const nlohmann::json::parse_error& e) {
  4117. sendErrorResponse(res, std::string("Invalid JSON: ") + e.what(), 400, "JSON_PARSE_ERROR", requestId);
  4118. } catch (const std::exception& e) {
  4119. sendErrorResponse(res, std::string("Hash request failed: ") + e.what(), 500, "INTERNAL_ERROR", requestId);
  4120. }
  4121. }
  4122. void Server::handleConvertModel(const httplib::Request& req, httplib::Response& res) {
  4123. std::string requestId = generateRequestId();
  4124. try {
  4125. if (!m_generationQueue || !m_modelManager) {
  4126. sendErrorResponse(res, "Services not available", 500, "SERVICE_UNAVAILABLE", requestId);
  4127. return;
  4128. }
  4129. // Parse request body
  4130. nlohmann::json requestJson;
  4131. try {
  4132. requestJson = nlohmann::json::parse(req.body);
  4133. } catch (const nlohmann::json::parse_error& e) {
  4134. sendErrorResponse(res, std::string("Invalid JSON: ") + e.what(), 400, "JSON_PARSE_ERROR", requestId);
  4135. return;
  4136. }
  4137. // Validate required fields
  4138. if (!requestJson.contains("model_name")) {
  4139. sendErrorResponse(res, "Missing required field: model_name", 400, "MISSING_FIELD", requestId);
  4140. return;
  4141. }
  4142. if (!requestJson.contains("quantization_type")) {
  4143. sendErrorResponse(res, "Missing required field: quantization_type", 400, "MISSING_FIELD", requestId);
  4144. return;
  4145. }
  4146. std::string modelName = requestJson["model_name"].get<std::string>();
  4147. std::string quantizationType = requestJson["quantization_type"].get<std::string>();
  4148. // Validate quantization type
  4149. const std::vector<std::string> validTypes = {"f32", "f16", "q4_0", "q4_1", "q5_0", "q5_1", "q8_0", "q2_K", "q3_K", "q4_K"};
  4150. if (std::find(validTypes.begin(), validTypes.end(), quantizationType) == validTypes.end()) {
  4151. sendErrorResponse(res, "Invalid quantization_type. Valid types: f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0, q2_K, q3_K, q4_K",
  4152. 400, "INVALID_QUANTIZATION_TYPE", requestId);
  4153. return;
  4154. }
  4155. // Get model info to find the full path
  4156. auto modelInfo = m_modelManager->getModelInfo(modelName);
  4157. if (modelInfo.name.empty()) {
  4158. sendErrorResponse(res, "Model not found: " + modelName, 404, "MODEL_NOT_FOUND", requestId);
  4159. return;
  4160. }
  4161. // Check if model is already GGUF
  4162. if (modelInfo.fullPath.find(".gguf") != std::string::npos) {
  4163. sendErrorResponse(res, "Model is already in GGUF format. Cannot convert GGUF to GGUF.",
  4164. 400, "ALREADY_GGUF", requestId);
  4165. return;
  4166. }
  4167. // Build output path
  4168. std::string outputPath = requestJson.value("output_path", "");
  4169. if (outputPath.empty()) {
  4170. // Generate default output path: model_name_quantization.gguf
  4171. namespace fs = std::filesystem;
  4172. fs::path inputPath(modelInfo.fullPath);
  4173. std::string baseName = inputPath.stem().string();
  4174. std::string outputDir = inputPath.parent_path().string();
  4175. outputPath = outputDir + "/" + baseName + "_" + quantizationType + ".gguf";
  4176. }
  4177. // Create conversion request
  4178. ConversionRequest convReq;
  4179. convReq.id = requestId;
  4180. convReq.modelName = modelName;
  4181. convReq.modelPath = modelInfo.fullPath;
  4182. convReq.outputPath = outputPath;
  4183. convReq.quantizationType = quantizationType;
  4184. // Enqueue conversion request
  4185. auto future = m_generationQueue->enqueueConversionRequest(convReq);
  4186. nlohmann::json response = {
  4187. {"request_id", requestId},
  4188. {"status", "queued"},
  4189. {"message", "Model conversion queued successfully"},
  4190. {"model_name", modelName},
  4191. {"input_path", modelInfo.fullPath},
  4192. {"output_path", outputPath},
  4193. {"quantization_type", quantizationType}
  4194. };
  4195. sendJsonResponse(res, response, 202);
  4196. } catch (const std::exception& e) {
  4197. sendErrorResponse(res, std::string("Conversion request failed: ") + e.what(), 500, "INTERNAL_ERROR", requestId);
  4198. }
  4199. }
  4200. void Server::handleModelStats(const httplib::Request& /*req*/, httplib::Response& res) {
  4201. std::string requestId = generateRequestId();
  4202. try {
  4203. if (!m_modelManager) {
  4204. sendErrorResponse(res, "Model manager not available", 500, "MODEL_MANAGER_UNAVAILABLE", requestId);
  4205. return;
  4206. }
  4207. auto allModels = m_modelManager->getAllModels();
  4208. nlohmann::json response = {
  4209. {"statistics", {
  4210. {"total_models", allModels.size()},
  4211. {"loaded_models", m_modelManager->getLoadedModelsCount()},
  4212. {"available_models", m_modelManager->getAvailableModelsCount()},
  4213. {"model_types", getModelTypeStatistics()},
  4214. {"largest_model", getLargestModel(allModels)},
  4215. {"smallest_model", getSmallestModel(allModels)}
  4216. }},
  4217. {"request_id", requestId}
  4218. };
  4219. sendJsonResponse(res, response);
  4220. } catch (const std::exception& e) {
  4221. sendErrorResponse(res, std::string("Failed to get model stats: ") + e.what(), 500, "MODEL_STATS_ERROR", requestId);
  4222. }
  4223. }
  4224. void Server::handleBatchModels(const httplib::Request& req, httplib::Response& res) {
  4225. std::string requestId = generateRequestId();
  4226. try {
  4227. if (!m_modelManager) {
  4228. sendErrorResponse(res, "Model manager not available", 500, "MODEL_MANAGER_UNAVAILABLE", requestId);
  4229. return;
  4230. }
  4231. // Parse JSON request body
  4232. nlohmann::json requestJson = nlohmann::json::parse(req.body);
  4233. if (!requestJson.contains("operation") || !requestJson["operation"].is_string()) {
  4234. sendErrorResponse(res, "Missing or invalid 'operation' field", 400, "INVALID_OPERATION", requestId);
  4235. return;
  4236. }
  4237. if (!requestJson.contains("models") || !requestJson["models"].is_array()) {
  4238. sendErrorResponse(res, "Missing or invalid 'models' field", 400, "INVALID_MODELS", requestId);
  4239. return;
  4240. }
  4241. std::string operation = requestJson["operation"];
  4242. nlohmann::json models = requestJson["models"];
  4243. nlohmann::json results = nlohmann::json::array();
  4244. for (const auto& model : models) {
  4245. if (!model.is_string()) {
  4246. results.push_back({
  4247. {"model", model},
  4248. {"success", false},
  4249. {"error", "Invalid model name"}
  4250. });
  4251. continue;
  4252. }
  4253. std::string modelName = model;
  4254. bool success = false;
  4255. std::string error = "";
  4256. if (operation == "load") {
  4257. success = m_modelManager->loadModel(modelName);
  4258. if (!success) error = "Failed to load model";
  4259. } else if (operation == "unload") {
  4260. success = m_modelManager->unloadModel(modelName);
  4261. if (!success) error = "Failed to unload model";
  4262. } else {
  4263. error = "Unsupported operation";
  4264. }
  4265. results.push_back({
  4266. {"model", modelName},
  4267. {"success", success},
  4268. {"error", error.empty() ? nlohmann::json(nullptr) : nlohmann::json(error)}
  4269. });
  4270. }
  4271. nlohmann::json response = {
  4272. {"operation", operation},
  4273. {"results", results},
  4274. {"successful_count", std::count_if(results.begin(), results.end(),
  4275. [](const nlohmann::json& result) { return result["success"].get<bool>(); })},
  4276. {"failed_count", std::count_if(results.begin(), results.end(),
  4277. [](const nlohmann::json& result) { return !result["success"].get<bool>(); })},
  4278. {"request_id", requestId}
  4279. };
  4280. sendJsonResponse(res, response);
  4281. } catch (const nlohmann::json::parse_error& e) {
  4282. sendErrorResponse(res, std::string("Invalid JSON: ") + e.what(), 400, "JSON_PARSE_ERROR", requestId);
  4283. } catch (const std::exception& e) {
  4284. sendErrorResponse(res, std::string("Batch operation failed: ") + e.what(), 500, "BATCH_OPERATION_ERROR", requestId);
  4285. }
  4286. }
  4287. void Server::handleValidateModel(const httplib::Request& req, httplib::Response& res) {
  4288. std::string requestId = generateRequestId();
  4289. try {
  4290. // Parse JSON request body
  4291. nlohmann::json requestJson = nlohmann::json::parse(req.body);
  4292. if (!requestJson.contains("model_path") || !requestJson["model_path"].is_string()) {
  4293. sendErrorResponse(res, "Missing or invalid 'model_path' field", 400, "INVALID_MODEL_PATH", requestId);
  4294. return;
  4295. }
  4296. std::string modelPath = requestJson["model_path"];
  4297. std::string modelType = requestJson.value("model_type", "checkpoint");
  4298. // Validate model file
  4299. nlohmann::json validation = validateModelFile(modelPath, modelType);
  4300. nlohmann::json response = {
  4301. {"validation", validation},
  4302. {"request_id", requestId}
  4303. };
  4304. sendJsonResponse(res, response);
  4305. } catch (const nlohmann::json::parse_error& e) {
  4306. sendErrorResponse(res, std::string("Invalid JSON: ") + e.what(), 400, "JSON_PARSE_ERROR", requestId);
  4307. } catch (const std::exception& e) {
  4308. sendErrorResponse(res, std::string("Model validation failed: ") + e.what(), 500, "MODEL_VALIDATION_ERROR", requestId);
  4309. }
  4310. }
  4311. void Server::handleCheckCompatibility(const httplib::Request& req, httplib::Response& res) {
  4312. std::string requestId = generateRequestId();
  4313. try {
  4314. if (!m_modelManager) {
  4315. sendErrorResponse(res, "Model manager not available", 500, "MODEL_MANAGER_UNAVAILABLE", requestId);
  4316. return;
  4317. }
  4318. // Parse JSON request body
  4319. nlohmann::json requestJson = nlohmann::json::parse(req.body);
  4320. if (!requestJson.contains("model_name") || !requestJson["model_name"].is_string()) {
  4321. sendErrorResponse(res, "Missing or invalid 'model_name' field", 400, "INVALID_MODEL_NAME", requestId);
  4322. return;
  4323. }
  4324. std::string modelName = requestJson["model_name"];
  4325. std::string systemInfo = requestJson.value("system_info", "auto");
  4326. // Get model information
  4327. auto modelInfo = m_modelManager->getModelInfo(modelName);
  4328. if (modelInfo.name.empty()) {
  4329. sendErrorResponse(res, "Model not found", 404, "MODEL_NOT_FOUND", requestId);
  4330. return;
  4331. }
  4332. // Check compatibility
  4333. nlohmann::json compatibility = checkModelCompatibility(modelInfo, systemInfo);
  4334. nlohmann::json response = {
  4335. {"model", modelName},
  4336. {"compatibility", compatibility},
  4337. {"request_id", requestId}
  4338. };
  4339. sendJsonResponse(res, response);
  4340. } catch (const nlohmann::json::parse_error& e) {
  4341. sendErrorResponse(res, std::string("Invalid JSON: ") + e.what(), 400, "JSON_PARSE_ERROR", requestId);
  4342. } catch (const std::exception& e) {
  4343. sendErrorResponse(res, std::string("Compatibility check failed: ") + e.what(), 500, "COMPATIBILITY_CHECK_ERROR", requestId);
  4344. }
  4345. }
  4346. void Server::handleModelRequirements(const httplib::Request& req, httplib::Response& res) {
  4347. std::string requestId = generateRequestId();
  4348. try {
  4349. // Parse JSON request body
  4350. nlohmann::json requestJson = nlohmann::json::parse(req.body);
  4351. std::string modelType = requestJson.value("model_type", "checkpoint");
  4352. std::string resolution = requestJson.value("resolution", "512x512");
  4353. std::string batchSize = requestJson.value("batch_size", "1");
  4354. // Calculate specific requirements
  4355. nlohmann::json requirements = calculateSpecificRequirements(modelType, resolution, batchSize);
  4356. // Get general requirements for model type
  4357. ModelType type = ModelManager::stringToModelType(modelType);
  4358. nlohmann::json generalRequirements = getModelRequirements(type);
  4359. nlohmann::json response = {
  4360. {"model_type", modelType},
  4361. {"configuration", {
  4362. {"resolution", resolution},
  4363. {"batch_size", batchSize}
  4364. }},
  4365. {"specific_requirements", requirements},
  4366. {"general_requirements", generalRequirements},
  4367. {"request_id", requestId}
  4368. };
  4369. sendJsonResponse(res, response);
  4370. } catch (const nlohmann::json::parse_error& e) {
  4371. sendErrorResponse(res, std::string("Invalid JSON: ") + e.what(), 400, "JSON_PARSE_ERROR", requestId);
  4372. } catch (const std::exception& e) {
  4373. sendErrorResponse(res, std::string("Requirements calculation failed: ") + e.what(), 500, "REQUIREMENTS_ERROR", requestId);
  4374. }
  4375. }
  4376. void Server::serverThreadFunction(const std::string& host, int port) {
  4377. try {
  4378. std::cout << "Server thread starting, attempting to bind to " << host << ":" << port << std::endl;
  4379. // Check if port is available before attempting to bind
  4380. std::cout << "Checking if port " << port << " is available..." << std::endl;
  4381. // Try to create a test socket to check if port is in use
  4382. int test_socket = socket(AF_INET, SOCK_STREAM, 0);
  4383. if (test_socket >= 0) {
  4384. // Set SO_REUSEADDR to avoid TIME_WAIT issues
  4385. int opt = 1;
  4386. if (setsockopt(test_socket, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt)) < 0) {
  4387. std::cerr << "Warning: Failed to set SO_REUSEADDR on test socket: " << strerror(errno) << std::endl;
  4388. }
  4389. // Also set SO_REUSEPORT if available (for better concurrent binding handling)
  4390. #ifdef SO_REUSEPORT
  4391. int reuseport = 1;
  4392. if (setsockopt(test_socket, SOL_SOCKET, SO_REUSEPORT, &reuseport, sizeof(reuseport)) < 0) {
  4393. std::cerr << "Warning: Failed to set SO_REUSEPORT on test socket: " << strerror(errno) << std::endl;
  4394. }
  4395. #endif
  4396. struct sockaddr_in addr;
  4397. addr.sin_family = AF_INET;
  4398. addr.sin_port = htons(port);
  4399. addr.sin_addr.s_addr = INADDR_ANY;
  4400. // Try to bind to the port
  4401. if (bind(test_socket, (struct sockaddr*)&addr, sizeof(addr)) < 0) {
  4402. close(test_socket);
  4403. std::cerr << "ERROR: Port " << port << " is already in use! Cannot start server." << std::endl;
  4404. std::cerr << "This could be due to:" << std::endl;
  4405. std::cerr << "1. Another instance is already running on this port" << std::endl;
  4406. std::cerr << "2. A previous instance crashed and the socket is in TIME_WAIT state" << std::endl;
  4407. std::cerr << "3. The port is being used by another application" << std::endl;
  4408. std::cerr << std::endl;
  4409. std::cerr << "Solutions:" << std::endl;
  4410. std::cerr << "- Wait 30-60 seconds for TIME_WAIT to expire (if server crashed)" << std::endl;
  4411. std::cerr << "- Kill any existing processes: sudo lsof -ti:" << port << " | xargs kill -9" << std::endl;
  4412. std::cerr << "- Use a different port with -p <port>" << std::endl;
  4413. m_isRunning.store(false);
  4414. m_startupFailed.store(true);
  4415. return;
  4416. }
  4417. close(test_socket);
  4418. }
  4419. std::cout << "Port " << port << " is available, proceeding with server startup..." << std::endl;
  4420. std::cout << "Calling listen()..." << std::endl;
  4421. // We need to set m_isRunning after successful bind but before blocking
  4422. // cpp-httplib doesn't provide a callback, so we set it optimistically
  4423. // and clear it if listen() returns false
  4424. m_isRunning.store(true);
  4425. bool listenResult = m_httpServer->listen(host.c_str(), port);
  4426. std::cout << "listen() returned: " << (listenResult ? "true" : "false") << std::endl;
  4427. // If we reach here, server has stopped (either normally or due to error)
  4428. m_isRunning.store(false);
  4429. if (!listenResult) {
  4430. std::cerr << "Server listen failed! This usually means port is in use or permission denied." << std::endl;
  4431. }
  4432. } catch (const std::exception& e) {
  4433. std::cerr << "Exception in server thread: " << e.what() << std::endl;
  4434. m_isRunning.store(false);
  4435. }
  4436. }