server.cpp 211 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798991001011021031041051061071081091101111121131141151161171181191201211221231241251261271281291301311321331341351361371381391401411421431441451461471481491501511521531541551561571581591601611621631641651661671681691701711721731741751761771781791801811821831841851861871881891901911921931941951961971981992002012022032042052062072082092102112122132142152162172182192202212222232242252262272282292302312322332342352362372382392402412422432442452462472482492502512522532542552562572582592602612622632642652662672682692702712722732742752762772782792802812822832842852862872882892902912922932942952962972982993003013023033043053063073083093103113123133143153163173183193203213223233243253263273283293303313323333343353363373383393403413423433443453463473483493503513523533543553563573583593603613623633643653663673683693703713723733743753763773783793803813823833843853863873883893903913923933943953963973983994004014024034044054064074084094104114124134144154164174184194204214224234244254264274284294304314324334344354364374384394404414424434444454464474484494504514524534544554564574584594604614624634644654664674684694704714724734744754764774784794804814824834844854864874884894904914924934944954964974984995005015025035045055065075085095105115125135145155165175185195205215225235245255265275285295305315325335345355365375385395405415425435445455465475485495505515525535545555565575585595605615625635645655665675685695705715725735745755765775785795805815825835845855865875885895905915925935945955965975985996006016026036046056066076086096106116126136146156166176186196206216226236246256266276286296306316326336346356366376386396406416426436446456466476486496506516526536546556566576586596606616626636646656666676686696706716726736746756766776786796806816826836846856866876886896906916926936946956966976986997007017027037047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642164316441645164616471648164916501651165216531654165516561657165816591660166116621663166416651666166716681669167016711672167316741675167616771678167916801681168216831684168516861687168816891690169116921693169416951696169716981699170017011702170317041705170617071708170917101711171217131714171517161717171817191720172117221723172417251726172717281729173017311732173317341735173617371738173917401741174217431744174517461747174817491750175117521753175417551756175717581759176017611762176317641765176617671768176917701771177217731774177517761777177817791780178117821783178417851786178717881789179017911792179317941795179617971798179918001801180218031804180518061807180818091810181118121813181418151816181718181819182018211822182318241825182618271828182918301831183218331834183518361837183818391840184118421843184418451846184718481849185018511852185318541855185618571858185918601861186218631864186518661867186818691870187118721873187418751876187718781879188018811882188318841885188618871888188918901891189218931894189518961897189818991900190119021903190419051906190719081909191019111912191319141915191619171918191919201921192219231924192519261927192819291930193119321933193419351936193719381939194019411942194319441945194619471948194919501951195219531954195519561957195819591960196119621963196419651966196719681969197019711972197319741975197619771978197919801981198219831984198519861987198819891990199119921993199419951996199719981999200020012002200320042005200620072008200920102011201220132014201520162017201820192020202120222023202420252026202720282029203020312032203320342035203620372038203920402041204220432044204520462047204820492050205120522053205420552056205720582059206020612062206320642065206620672068206920702071207220732074207520762077207820792080208120822083208420852086208720882089209020912092209320942095209620972098209921002101210221032104210521062107210821092110211121122113211421152116211721182119212021212122212321242125212621272128212921302131213221332134213521362137213821392140214121422143214421452146214721482149215021512152215321542155215621572158215921602161216221632164216521662167216821692170217121722173217421752176217721782179218021812182218321842185218621872188218921902191219221932194219521962197219821992200220122022203220422052206220722082209221022112212221322142215221622172218221922202221222222232224222522262227222822292230223122322233223422352236223722382239224022412242224322442245224622472248224922502251225222532254225522562257225822592260226122622263226422652266226722682269227022712272227322742275227622772278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410241124122413241424152416241724182419242024212422242324242425242624272428242924302431243224332434243524362437243824392440244124422443244424452446244724482449245024512452245324542455245624572458245924602461246224632464246524662467246824692470247124722473247424752476247724782479248024812482248324842485248624872488248924902491249224932494249524962497249824992500250125022503250425052506250725082509251025112512251325142515251625172518251925202521252225232524252525262527252825292530253125322533253425352536253725382539254025412542254325442545254625472548254925502551255225532554255525562557255825592560256125622563256425652566256725682569257025712572257325742575257625772578257925802581258225832584258525862587258825892590259125922593259425952596259725982599260026012602260326042605260626072608260926102611261226132614261526162617261826192620262126222623262426252626262726282629263026312632263326342635263626372638263926402641264226432644264526462647264826492650265126522653265426552656265726582659266026612662266326642665266626672668266926702671267226732674267526762677267826792680268126822683268426852686268726882689269026912692269326942695269626972698269927002701270227032704270527062707270827092710271127122713271427152716271727182719272027212722272327242725272627272728272927302731273227332734273527362737273827392740274127422743274427452746274727482749275027512752275327542755275627572758275927602761276227632764276527662767276827692770277127722773277427752776277727782779278027812782278327842785278627872788278927902791279227932794279527962797279827992800280128022803280428052806280728082809281028112812281328142815281628172818281928202821282228232824282528262827282828292830283128322833283428352836283728382839284028412842284328442845284628472848284928502851285228532854285528562857285828592860286128622863286428652866286728682869287028712872287328742875287628772878287928802881288228832884288528862887288828892890289128922893289428952896289728982899290029012902290329042905290629072908290929102911291229132914291529162917291829192920292129222923292429252926292729282929293029312932293329342935293629372938293929402941294229432944294529462947294829492950295129522953295429552956295729582959296029612962296329642965296629672968296929702971297229732974297529762977297829792980298129822983298429852986298729882989299029912992299329942995299629972998299930003001300230033004300530063007300830093010301130123013301430153016301730183019302030213022302330243025302630273028302930303031303230333034303530363037303830393040304130423043304430453046304730483049305030513052305330543055305630573058305930603061306230633064306530663067306830693070307130723073307430753076307730783079308030813082308330843085308630873088308930903091309230933094309530963097309830993100310131023103310431053106310731083109311031113112311331143115311631173118311931203121312231233124312531263127312831293130313131323133313431353136313731383139314031413142314331443145314631473148314931503151315231533154315531563157315831593160316131623163316431653166316731683169317031713172317331743175317631773178317931803181318231833184318531863187318831893190319131923193319431953196319731983199320032013202320332043205320632073208320932103211321232133214321532163217321832193220322132223223322432253226322732283229323032313232323332343235323632373238323932403241324232433244324532463247324832493250325132523253325432553256325732583259326032613262326332643265326632673268326932703271327232733274327532763277327832793280328132823283328432853286328732883289329032913292329332943295329632973298329933003301330233033304330533063307330833093310331133123313331433153316331733183319332033213322332333243325332633273328332933303331333233333334333533363337333833393340334133423343334433453346334733483349335033513352335333543355335633573358335933603361336233633364336533663367336833693370337133723373337433753376337733783379338033813382338333843385338633873388338933903391339233933394339533963397339833993400340134023403340434053406340734083409341034113412341334143415341634173418341934203421342234233424342534263427342834293430343134323433343434353436343734383439344034413442344334443445344634473448344934503451345234533454345534563457345834593460346134623463346434653466346734683469347034713472347334743475347634773478347934803481348234833484348534863487348834893490349134923493349434953496349734983499350035013502350335043505350635073508350935103511351235133514351535163517351835193520352135223523352435253526352735283529353035313532353335343535353635373538353935403541354235433544354535463547354835493550355135523553355435553556355735583559356035613562356335643565356635673568356935703571357235733574357535763577357835793580358135823583358435853586358735883589359035913592359335943595359635973598359936003601360236033604360536063607360836093610361136123613361436153616361736183619362036213622362336243625362636273628362936303631363236333634363536363637363836393640364136423643364436453646364736483649365036513652365336543655365636573658365936603661366236633664366536663667366836693670367136723673367436753676367736783679368036813682368336843685368636873688368936903691369236933694369536963697369836993700370137023703370437053706370737083709371037113712371337143715371637173718371937203721372237233724372537263727372837293730373137323733373437353736373737383739374037413742374337443745374637473748374937503751375237533754375537563757375837593760376137623763376437653766376737683769377037713772377337743775377637773778377937803781378237833784378537863787378837893790379137923793379437953796379737983799380038013802380338043805380638073808380938103811381238133814381538163817381838193820382138223823382438253826382738283829383038313832383338343835383638373838383938403841384238433844384538463847384838493850385138523853385438553856385738583859386038613862386338643865386638673868386938703871387238733874387538763877387838793880388138823883388438853886388738883889389038913892389338943895389638973898389939003901390239033904390539063907390839093910391139123913391439153916391739183919392039213922392339243925392639273928392939303931393239333934393539363937393839393940394139423943394439453946394739483949395039513952395339543955395639573958395939603961396239633964396539663967396839693970397139723973397439753976397739783979398039813982398339843985398639873988398939903991399239933994399539963997399839994000400140024003400440054006400740084009401040114012401340144015401640174018401940204021402240234024402540264027402840294030403140324033403440354036403740384039404040414042404340444045404640474048404940504051405240534054405540564057405840594060406140624063406440654066406740684069407040714072407340744075407640774078407940804081408240834084408540864087408840894090409140924093409440954096409740984099410041014102410341044105410641074108410941104111411241134114411541164117411841194120412141224123412441254126412741284129413041314132413341344135413641374138413941404141414241434144414541464147414841494150415141524153415441554156415741584159416041614162416341644165416641674168416941704171417241734174417541764177417841794180418141824183418441854186418741884189419041914192419341944195419641974198419942004201420242034204420542064207420842094210421142124213421442154216421742184219422042214222422342244225422642274228422942304231423242334234423542364237423842394240424142424243424442454246424742484249425042514252425342544255425642574258425942604261426242634264426542664267426842694270427142724273427442754276427742784279428042814282428342844285428642874288428942904291429242934294429542964297429842994300430143024303430443054306430743084309431043114312431343144315431643174318431943204321432243234324432543264327432843294330433143324333433443354336433743384339434043414342434343444345434643474348434943504351435243534354435543564357435843594360436143624363436443654366436743684369437043714372437343744375437643774378437943804381438243834384438543864387438843894390439143924393439443954396439743984399440044014402440344044405440644074408440944104411441244134414441544164417441844194420442144224423442444254426442744284429443044314432443344344435443644374438443944404441444244434444444544464447444844494450445144524453445444554456445744584459446044614462446344644465446644674468446944704471447244734474447544764477447844794480448144824483448444854486448744884489449044914492449344944495449644974498449945004501450245034504450545064507450845094510451145124513451445154516451745184519452045214522452345244525452645274528452945304531453245334534453545364537453845394540454145424543454445454546454745484549455045514552455345544555455645574558455945604561456245634564456545664567456845694570457145724573457445754576457745784579458045814582458345844585458645874588458945904591459245934594459545964597459845994600460146024603460446054606460746084609461046114612461346144615461646174618461946204621462246234624462546264627462846294630463146324633463446354636463746384639464046414642464346444645464646474648464946504651465246534654465546564657465846594660466146624663466446654666466746684669467046714672467346744675467646774678467946804681468246834684468546864687468846894690469146924693469446954696469746984699470047014702470347044705470647074708470947104711471247134714471547164717471847194720472147224723472447254726472747284729473047314732473347344735473647374738473947404741474247434744474547464747474847494750475147524753475447554756475747584759476047614762476347644765476647674768476947704771477247734774477547764777477847794780478147824783478447854786478747884789479047914792479347944795479647974798479948004801480248034804480548064807480848094810481148124813481448154816481748184819482048214822482348244825482648274828482948304831483248334834483548364837483848394840484148424843484448454846484748484849485048514852485348544855485648574858485948604861486248634864486548664867486848694870487148724873487448754876487748784879488048814882488348844885488648874888488948904891489248934894489548964897489848994900490149024903490449054906490749084909491049114912491349144915491649174918491949204921492249234924492549264927492849294930493149324933493449354936493749384939494049414942494349444945494649474948494949504951495249534954495549564957495849594960496149624963496449654966496749684969497049714972497349744975497649774978497949804981498249834984498549864987498849894990499149924993499449954996499749984999500050015002500350045005500650075008500950105011501250135014501550165017501850195020502150225023502450255026502750285029503050315032503350345035503650375038503950405041504250435044504550465047504850495050505150525053505450555056505750585059506050615062506350645065506650675068506950705071507250735074507550765077507850795080508150825083508450855086508750885089509050915092509350945095509650975098509951005101510251035104510551065107510851095110511151125113511451155116511751185119512051215122512351245125512651275128512951305131513251335134513551365137513851395140514151425143514451455146514751485149515051515152515351545155515651575158515951605161516251635164516551665167516851695170517151725173517451755176517751785179518051815182518351845185518651875188518951905191519251935194519551965197519851995200520152025203520452055206520752085209521052115212521352145215521652175218521952205221522252235224522552265227522852295230523152325233523452355236523752385239524052415242524352445245524652475248524952505251525252535254525552565257525852595260526152625263526452655266526752685269527052715272527352745275527652775278527952805281528252835284528552865287
  1. #include "server.h"
  2. #include "model_manager.h"
  3. #include "generation_queue.h"
  4. #include "utils.h"
  5. #include "auth_middleware.h"
  6. #include "user_manager.h"
  7. #include "version.h"
  8. #include <httplib.h>
  9. #include <nlohmann/json.hpp>
  10. #include <iostream>
  11. #include <sstream>
  12. #include <fstream>
  13. #include <chrono>
  14. #include <random>
  15. #include <algorithm>
  16. #include <thread>
  17. #include <filesystem>
  18. // Include stb_image for loading images (implementation is in generation_queue.cpp)
  19. #include "../stable-diffusion.cpp-src/thirdparty/stb_image.h"
  20. #include <sys/socket.h>
  21. #include <netinet/in.h>
  22. #include <unistd.h>
  23. #include <arpa/inet.h>
  24. Server::Server(ModelManager* modelManager, GenerationQueue* generationQueue, const std::string& outputDir, const std::string& uiDir, const ServerConfig& config)
  25. : m_modelManager(modelManager)
  26. , m_generationQueue(generationQueue)
  27. , m_isRunning(false)
  28. , m_startupFailed(false)
  29. , m_port(config.port)
  30. , m_outputDir(outputDir)
  31. , m_uiDir(uiDir)
  32. , m_userManager(nullptr)
  33. , m_authMiddleware(nullptr)
  34. , m_config(config)
  35. {
  36. m_httpServer = std::make_unique<httplib::Server>();
  37. }
  38. Server::~Server() {
  39. stop();
  40. }
  41. bool Server::start(const std::string& host, int port) {
  42. if (m_isRunning.load()) {
  43. return false;
  44. }
  45. m_host = host;
  46. m_port = port;
  47. // Validate host and port
  48. if (host.empty() || (port < 1 || port > 65535)) {
  49. return false;
  50. }
  51. // Set up CORS headers
  52. setupCORS();
  53. // Register API endpoints
  54. registerEndpoints();
  55. // Reset startup flags
  56. m_startupFailed.store(false);
  57. // Start server in a separate thread
  58. m_serverThread = std::thread(&Server::serverThreadFunction, this, host, port);
  59. // Wait for server to actually start and bind to the port
  60. // Give more time for server to actually start and bind
  61. for (int i = 0; i < 100; i++) { // Wait up to 10 seconds
  62. std::this_thread::sleep_for(std::chrono::milliseconds(100));
  63. // Check if startup failed early
  64. if (m_startupFailed.load()) {
  65. if (m_serverThread.joinable()) {
  66. m_serverThread.join();
  67. }
  68. return false;
  69. }
  70. if (m_isRunning.load()) {
  71. // Give it a moment more to ensure server is fully started
  72. std::this_thread::sleep_for(std::chrono::milliseconds(500));
  73. if (m_isRunning.load()) {
  74. return true;
  75. }
  76. }
  77. }
  78. if (m_isRunning.load()) {
  79. return true;
  80. } else {
  81. if (m_serverThread.joinable()) {
  82. m_serverThread.join();
  83. }
  84. return false;
  85. }
  86. }
  87. void Server::stop() {
  88. // Use atomic check to ensure thread safety
  89. bool wasRunning = m_isRunning.exchange(false);
  90. if (!wasRunning) {
  91. return; // Already stopped
  92. }
  93. if (m_httpServer) {
  94. m_httpServer->stop();
  95. // Give the server a moment to stop the blocking listen call
  96. std::this_thread::sleep_for(std::chrono::milliseconds(100));
  97. // If server thread is still running, try to force unblock the listen call
  98. // by making a quick connection to the server port
  99. if (m_serverThread.joinable()) {
  100. try {
  101. // Create a quick connection to interrupt the blocking listen
  102. httplib::Client client("127.0.0.1", m_port);
  103. client.set_connection_timeout(0, m_config.connectionTimeoutMs * 1000); // Convert ms to microseconds
  104. client.set_read_timeout(0, m_config.readTimeoutMs * 1000); // Convert ms to microseconds
  105. client.set_write_timeout(0, m_config.writeTimeoutMs * 1000); // Convert ms to microseconds
  106. auto res = client.Get("/api/health");
  107. // We don't care about the response, just trying to unblock
  108. } catch (...) {
  109. // Ignore any connection errors - we're just trying to unblock
  110. }
  111. }
  112. }
  113. if (m_serverThread.joinable()) {
  114. m_serverThread.join();
  115. }
  116. }
  117. bool Server::isRunning() const {
  118. return m_isRunning.load();
  119. }
  120. void Server::waitForStop() {
  121. if (m_serverThread.joinable()) {
  122. m_serverThread.join();
  123. }
  124. }
  125. void Server::registerEndpoints() {
  126. // Register authentication endpoints first (before applying middleware)
  127. registerAuthEndpoints();
  128. // Health check endpoint (public)
  129. m_httpServer->Get("/api/health", [this](const httplib::Request& req, httplib::Response& res) {
  130. handleHealthCheck(req, res);
  131. });
  132. // API status endpoint (public)
  133. m_httpServer->Get("/api/status", [this](const httplib::Request& req, httplib::Response& res) {
  134. handleApiStatus(req, res);
  135. });
  136. // Version information endpoint (public)
  137. m_httpServer->Get("/api/version", [this](const httplib::Request& req, httplib::Response& res) {
  138. handleVersion(req, res);
  139. });
  140. // Apply authentication middleware to protected endpoints
  141. auto withAuth = [this](std::function<void(const httplib::Request&, httplib::Response&)> handler) {
  142. return [this, handler](const httplib::Request& req, httplib::Response& res) {
  143. if (m_authMiddleware) {
  144. AuthContext authContext = m_authMiddleware->authenticate(req, res);
  145. if (!authContext.authenticated) {
  146. m_authMiddleware->sendAuthError(res, authContext.errorMessage, authContext.errorCode);
  147. return;
  148. }
  149. }
  150. handler(req, res);
  151. };
  152. };
  153. // Specialized generation endpoints (protected)
  154. m_httpServer->Post("/api/generate/text2img", withAuth([this](const httplib::Request& req, httplib::Response& res) {
  155. handleText2Img(req, res);
  156. }));
  157. m_httpServer->Post("/api/generate/img2img", withAuth([this](const httplib::Request& req, httplib::Response& res) {
  158. handleImg2Img(req, res);
  159. }));
  160. m_httpServer->Post("/api/generate/controlnet", withAuth([this](const httplib::Request& req, httplib::Response& res) {
  161. handleControlNet(req, res);
  162. }));
  163. m_httpServer->Post("/api/generate/upscale", withAuth([this](const httplib::Request& req, httplib::Response& res) {
  164. handleUpscale(req, res);
  165. }));
  166. m_httpServer->Post("/api/generate/inpainting", withAuth([this](const httplib::Request& req, httplib::Response& res) {
  167. handleInpainting(req, res);
  168. }));
  169. // Utility endpoints (now protected - require authentication)
  170. m_httpServer->Get("/api/samplers", withAuth([this](const httplib::Request& req, httplib::Response& res) {
  171. handleSamplers(req, res);
  172. }));
  173. m_httpServer->Get("/api/schedulers", withAuth([this](const httplib::Request& req, httplib::Response& res) {
  174. handleSchedulers(req, res);
  175. }));
  176. m_httpServer->Get("/api/parameters", withAuth([this](const httplib::Request& req, httplib::Response& res) {
  177. handleParameters(req, res);
  178. }));
  179. m_httpServer->Post("/api/validate", [this](const httplib::Request& req, httplib::Response& res) {
  180. handleValidate(req, res);
  181. });
  182. m_httpServer->Post("/api/estimate", withAuth([this](const httplib::Request& req, httplib::Response& res) {
  183. handleEstimate(req, res);
  184. }));
  185. m_httpServer->Get("/api/config", withAuth([this](const httplib::Request& req, httplib::Response& res) {
  186. handleConfig(req, res);
  187. }));
  188. m_httpServer->Get("/api/system", withAuth([this](const httplib::Request& req, httplib::Response& res) {
  189. handleSystem(req, res);
  190. }));
  191. m_httpServer->Post("/api/system/restart", withAuth([this](const httplib::Request& req, httplib::Response& res) {
  192. handleSystemRestart(req, res);
  193. }));
  194. // Models list endpoint (now protected - require authentication)
  195. m_httpServer->Get("/api/models", withAuth([this](const httplib::Request& req, httplib::Response& res) {
  196. handleModelsList(req, res);
  197. }));
  198. // Model-specific endpoints
  199. m_httpServer->Get("/api/models/(.*)", [this](const httplib::Request& req, httplib::Response& res) {
  200. handleModelInfo(req, res);
  201. });
  202. m_httpServer->Post("/api/models/(.*)/load", withAuth([this](const httplib::Request& req, httplib::Response& res) {
  203. handleLoadModelById(req, res);
  204. }));
  205. m_httpServer->Post("/api/models/(.*)/unload", withAuth([this](const httplib::Request& req, httplib::Response& res) {
  206. handleUnloadModelById(req, res);
  207. }));
  208. // Model management endpoints (now protected - require authentication)
  209. m_httpServer->Get("/api/models/types", withAuth([this](const httplib::Request& req, httplib::Response& res) {
  210. handleModelTypes(req, res);
  211. }));
  212. m_httpServer->Get("/api/models/directories", withAuth([this](const httplib::Request& req, httplib::Response& res) {
  213. handleModelDirectories(req, res);
  214. }));
  215. m_httpServer->Post("/api/models/refresh", withAuth([this](const httplib::Request& req, httplib::Response& res) {
  216. handleRefreshModels(req, res);
  217. }));
  218. m_httpServer->Post("/api/models/hash", withAuth([this](const httplib::Request& req, httplib::Response& res) {
  219. handleHashModels(req, res);
  220. }));
  221. m_httpServer->Post("/api/models/convert", withAuth([this](const httplib::Request& req, httplib::Response& res) {
  222. handleConvertModel(req, res);
  223. }));
  224. m_httpServer->Get("/api/models/stats", withAuth([this](const httplib::Request& req, httplib::Response& res) {
  225. handleModelStats(req, res);
  226. }));
  227. m_httpServer->Post("/api/models/batch", withAuth([this](const httplib::Request& req, httplib::Response& res) {
  228. handleBatchModels(req, res);
  229. }));
  230. // Model validation endpoints (already protected with withAuth)
  231. m_httpServer->Post("/api/models/validate", withAuth([this](const httplib::Request& req, httplib::Response& res) {
  232. handleValidateModel(req, res);
  233. }));
  234. m_httpServer->Post("/api/models/compatible", withAuth([this](const httplib::Request& req, httplib::Response& res) {
  235. handleCheckCompatibility(req, res);
  236. }));
  237. m_httpServer->Post("/api/models/requirements", withAuth([this](const httplib::Request& req, httplib::Response& res) {
  238. handleModelRequirements(req, res);
  239. }));
  240. // Queue status endpoint (now protected - require authentication)
  241. m_httpServer->Get("/api/queue/status", withAuth([this](const httplib::Request& req, httplib::Response& res) {
  242. handleQueueStatus(req, res);
  243. }));
  244. // Download job output file endpoint (must be before job status endpoint to match more specific pattern first)
  245. // Note: This endpoint is public to allow frontend to display generated images without authentication
  246. m_httpServer->Get("/api/queue/job/(.*)/output/(.*)", [this](const httplib::Request& req, httplib::Response& res) {
  247. handleDownloadOutput(req, res);
  248. });
  249. // Get job output by job ID endpoint (public to allow frontend to display generated images without authentication)
  250. m_httpServer->Get("/api/v1/jobs/(.*)/output", [this](const httplib::Request& req, httplib::Response& res) {
  251. handleJobOutput(req, res);
  252. });
  253. // Download image from URL endpoint (public for CORS-free image handling)
  254. m_httpServer->Get("/api/image/download", [this](const httplib::Request& req, httplib::Response& res) {
  255. handleDownloadImageFromUrl(req, res);
  256. });
  257. // Image resize endpoint (protected)
  258. m_httpServer->Post("/api/image/resize", withAuth([this](const httplib::Request& req, httplib::Response& res) {
  259. handleImageResize(req, res);
  260. }));
  261. // Image crop endpoint (protected)
  262. m_httpServer->Post("/api/image/crop", withAuth([this](const httplib::Request& req, httplib::Response& res) {
  263. handleImageCrop(req, res);
  264. }));
  265. // Job status endpoint (now protected - require authentication)
  266. m_httpServer->Get("/api/queue/job/(.*)", withAuth([this](const httplib::Request& req, httplib::Response& res) {
  267. handleJobStatus(req, res);
  268. }));
  269. // Cancel job endpoint (protected)
  270. m_httpServer->Post("/api/queue/cancel", withAuth([this](const httplib::Request& req, httplib::Response& res) {
  271. handleCancelJob(req, res);
  272. }));
  273. // Clear queue endpoint (protected)
  274. m_httpServer->Post("/api/queue/clear", withAuth([this](const httplib::Request& req, httplib::Response& res) {
  275. handleClearQueue(req, res);
  276. }));
  277. // Serve static web UI files if uiDir is configured
  278. if (!m_uiDir.empty() && std::filesystem::exists(m_uiDir)) {
  279. std::cout << "Serving static UI files from: " << m_uiDir << " at /ui" << std::endl;
  280. // Read UI version from version.nlohmann::json if available
  281. std::string uiVersion = "unknown";
  282. std::string versionFilePath = m_uiDir + "/version.nlohmann::json";
  283. if (std::filesystem::exists(versionFilePath)) {
  284. try {
  285. std::ifstream versionFile(versionFilePath);
  286. if (versionFile.is_open()) {
  287. nlohmann::json versionData = nlohmann::json::parse(versionFile);
  288. if (versionData.contains("version")) {
  289. uiVersion = versionData["version"].get<std::string>();
  290. }
  291. versionFile.close();
  292. }
  293. } catch (const std::exception& e) {
  294. std::cerr << "Failed to read UI version: " << e.what() << std::endl;
  295. }
  296. }
  297. std::cout << "UI version: " << uiVersion << std::endl;
  298. // Serve dynamic config.js that provides runtime configuration to the web UI
  299. m_httpServer->Get("/ui/config.js", [this, uiVersion](const httplib::Request& /*req*/, httplib::Response& res) {
  300. // Generate JavaScript configuration with current server settings
  301. std::ostringstream configJs;
  302. configJs << "// Auto-generated configuration\n"
  303. << "window.__SERVER_CONFIG__ = {\n"
  304. << " apiUrl: 'http://" << m_host << ":" << m_port << "',\n"
  305. << " apiBasePath: '/api',\n"
  306. << " host: '" << m_host << "',\n"
  307. << " port: " << m_port << ",\n"
  308. << " uiVersion: '" << uiVersion << "',\n";
  309. // Add authentication method information
  310. if (m_authMiddleware) {
  311. auto authConfig = m_authMiddleware->getConfig();
  312. std::string authMethod = "none";
  313. switch (authConfig.authMethod) {
  314. case AuthMethod::UNIX:
  315. authMethod = "unix";
  316. break;
  317. case AuthMethod::JWT:
  318. authMethod = "jwt";
  319. break;
  320. default:
  321. authMethod = "none";
  322. break;
  323. }
  324. configJs << " authMethod: '" << authMethod << "',\n"
  325. << " authEnabled: " << (authConfig.authMethod != AuthMethod::NONE ? "true" : "false") << "\n";
  326. } else {
  327. configJs << " authMethod: 'none',\n"
  328. << " authEnabled: false\n";
  329. }
  330. configJs << "};\n";
  331. // No cache for config.js - always fetch fresh
  332. res.set_header("Cache-Control", "no-cache, no-store, must-revalidate");
  333. res.set_header("Pragma", "no-cache");
  334. res.set_header("Expires", "0");
  335. res.set_content(configJs.str(), "application/javascript");
  336. });
  337. // Set up file request handler for caching static assets
  338. m_httpServer->set_file_request_handler([uiVersion](const httplib::Request& req, httplib::Response& res) {
  339. // Add cache headers based on file type and version
  340. std::string path = req.path;
  341. // For versioned static assets (.js, .css, images), use long cache
  342. if (path.find("/_next/") != std::string::npos ||
  343. path.find(".js") != std::string::npos ||
  344. path.find(".css") != std::string::npos ||
  345. path.find(".png") != std::string::npos ||
  346. path.find(".jpg") != std::string::npos ||
  347. path.find(".svg") != std::string::npos ||
  348. path.find(".ico") != std::string::npos ||
  349. path.find(".woff") != std::string::npos ||
  350. path.find(".woff2") != std::string::npos ||
  351. path.find(".ttf") != std::string::npos) {
  352. // Long cache (1 year) for static assets
  353. res.set_header("Cache-Control", "public, max-age=31536000, immutable");
  354. // Add ETag based on UI version for cache validation
  355. res.set_header("ETag", "\"" + uiVersion + "\"");
  356. // Check If-None-Match for conditional requests
  357. if (req.has_header("If-None-Match")) {
  358. std::string clientETag = req.get_header_value("If-None-Match");
  359. if (clientETag == "\"" + uiVersion + "\"") {
  360. res.status = 304; // Not Modified
  361. return;
  362. }
  363. }
  364. } else if (path.find(".html") != std::string::npos || path == "/ui/" || path == "/ui") {
  365. // HTML files should revalidate but can be cached briefly
  366. res.set_header("Cache-Control", "public, max-age=0, must-revalidate");
  367. res.set_header("ETag", "\"" + uiVersion + "\"");
  368. }
  369. });
  370. // Create a handler for UI routes with authentication check
  371. auto uiHandler = [this](const httplib::Request& req, httplib::Response& res) {
  372. // Check if authentication is enabled
  373. if (m_authMiddleware) {
  374. auto authConfig = m_authMiddleware->getConfig();
  375. if (authConfig.authMethod != AuthMethod::NONE) {
  376. // Authentication is enabled, check if user is authenticated
  377. AuthContext authContext = m_authMiddleware->authenticate(req, res);
  378. // For Unix auth, we need to check if the user is authenticated
  379. // The authenticateUnix function will return a guest context for UI requests
  380. // when no Authorization header is present, but we still need to show the login page
  381. if (!authContext.authenticated) {
  382. // Check if this is a request for a static asset (JS, CSS, images)
  383. // These should be served even without authentication to allow the login page to work
  384. bool isStaticAsset = false;
  385. std::string path = req.path;
  386. if (path.find(".js") != std::string::npos ||
  387. path.find(".css") != std::string::npos ||
  388. path.find(".png") != std::string::npos ||
  389. path.find(".jpg") != std::string::npos ||
  390. path.find(".jpeg") != std::string::npos ||
  391. path.find(".svg") != std::string::npos ||
  392. path.find(".ico") != std::string::npos ||
  393. path.find("/_next/") != std::string::npos) {
  394. isStaticAsset = true;
  395. }
  396. // For static assets, allow them to be served without authentication
  397. if (isStaticAsset) {
  398. // Continue to serve the file
  399. } else {
  400. // For HTML requests, redirect to login page
  401. if (req.path.find(".html") != std::string::npos ||
  402. req.path == "/ui/" || req.path == "/ui") {
  403. // Serve the login page instead of the requested page
  404. std::string loginPagePath = m_uiDir + "/login.html";
  405. if (std::filesystem::exists(loginPagePath)) {
  406. std::ifstream loginFile(loginPagePath);
  407. if (loginFile.is_open()) {
  408. std::string content((std::istreambuf_iterator<char>(loginFile)),
  409. std::istreambuf_iterator<char>());
  410. res.set_content(content, "text/html");
  411. return;
  412. }
  413. }
  414. // If login.html doesn't exist, serve a simple login page
  415. std::string simpleLoginPage = R"(
  416. <!DOCTYPE html>
  417. <html>
  418. <head>
  419. <title>Login Required</title>
  420. <style>
  421. body { font-family: Arial, sans-serif; max-width: 500px; margin: 100px auto; padding: 20px; }
  422. .form-group { margin-bottom: 15px; }
  423. label { display: block; margin-bottom: 5px; }
  424. input { width: 100%; padding: 8px; box-sizing: border-box; }
  425. button { background-color: #007bff; color: white; padding: 10px 15px; border: none; cursor: pointer; }
  426. .error { color: red; margin-top: 10px; }
  427. </style>
  428. </head>
  429. <body>
  430. <h1>Login Required</h1>
  431. <p>Please enter your username to continue.</p>
  432. <form id="loginForm">
  433. <div class="form-group">
  434. <label for="username">Username:</label>
  435. <input type="text" id="username" name="username" required>
  436. </div>
  437. <button type="submit">Login</button>
  438. </form>
  439. <div id="error" class="error"></div>
  440. <script>
  441. document.getElementById('loginForm').addEventListener('submit', async (e) => {
  442. e.preventDefault();
  443. const username = document.getElementById('username').value;
  444. const errorDiv = document.getElementById('error');
  445. try {
  446. const response = await fetch('/api/auth/login', {
  447. method: 'POST',
  448. headers: { 'Content-Type': 'application/nlohmann::json' },
  449. body: JSON.stringify({ username })
  450. });
  451. if (response.ok) {
  452. const data = await response.nlohmann::json();
  453. localStorage.setItem('auth_token', data.token);
  454. localStorage.setItem('unix_user', username);
  455. window.location.reload();
  456. } else {
  457. const error = await response.nlohmann::json();
  458. errorDiv.textContent = error.message || 'Login failed';
  459. }
  460. } catch (err) {
  461. errorDiv.textContent = 'Login failed: ' + err.message;
  462. }
  463. });
  464. </script>
  465. </body>
  466. </html>
  467. )";
  468. res.set_content(simpleLoginPage, "text/html");
  469. return;
  470. } else {
  471. // For non-HTML files, return unauthorized
  472. m_authMiddleware->sendAuthError(res, "Authentication required", "AUTH_REQUIRED");
  473. return;
  474. }
  475. }
  476. }
  477. }
  478. }
  479. // If we get here, either auth is disabled or user is authenticated
  480. // Serve the requested file
  481. std::string filePath = req.path.substr(3); // Remove "/ui" prefix
  482. if (filePath.empty() || filePath == "/") {
  483. filePath = "/index.html";
  484. }
  485. std::string fullPath = m_uiDir + filePath;
  486. if (std::filesystem::exists(fullPath) && std::filesystem::is_regular_file(fullPath)) {
  487. std::ifstream file(fullPath, std::ios::binary);
  488. if (file.is_open()) {
  489. std::string content((std::istreambuf_iterator<char>(file)),
  490. std::istreambuf_iterator<char>());
  491. // Determine content type based on file extension
  492. std::string contentType = "text/plain";
  493. if (filePath.find(".html") != std::string::npos) {
  494. contentType = "text/html";
  495. } else if (filePath.find(".js") != std::string::npos) {
  496. contentType = "application/javascript";
  497. } else if (filePath.find(".css") != std::string::npos) {
  498. contentType = "text/css";
  499. } else if (filePath.find(".png") != std::string::npos) {
  500. contentType = "image/png";
  501. } else if (filePath.find(".jpg") != std::string::npos || filePath.find(".jpeg") != std::string::npos) {
  502. contentType = "image/jpeg";
  503. } else if (filePath.find(".svg") != std::string::npos) {
  504. contentType = "image/svg+xml";
  505. }
  506. res.set_content(content, contentType);
  507. } else {
  508. res.status = 404;
  509. res.set_content("File not found", "text/plain");
  510. }
  511. } else {
  512. // For SPA routing, if the file doesn't exist, serve index.html
  513. // This allows Next.js to handle client-side routing
  514. std::string indexPath = m_uiDir + "/index.html";
  515. if (std::filesystem::exists(indexPath)) {
  516. std::ifstream indexFile(indexPath, std::ios::binary);
  517. if (indexFile.is_open()) {
  518. std::string content((std::istreambuf_iterator<char>(indexFile)),
  519. std::istreambuf_iterator<char>());
  520. res.set_content(content, "text/html");
  521. } else {
  522. res.status = 404;
  523. res.set_content("File not found", "text/plain");
  524. }
  525. } else {
  526. res.status = 404;
  527. res.set_content("File not found", "text/plain");
  528. }
  529. }
  530. };
  531. // Set up UI routes with authentication
  532. m_httpServer->Get("/ui/.*", uiHandler);
  533. // Redirect /ui to /ui/ to ensure proper routing
  534. m_httpServer->Get("/ui", [](const httplib::Request& /*req*/, httplib::Response& res) {
  535. res.set_redirect("/ui/");
  536. });
  537. }
  538. }
  539. void Server::setAuthComponents(std::shared_ptr<UserManager> userManager, std::shared_ptr<AuthMiddleware> authMiddleware) {
  540. m_userManager = userManager;
  541. m_authMiddleware = authMiddleware;
  542. }
  543. void Server::registerAuthEndpoints() {
  544. // Login endpoint
  545. m_httpServer->Post("/api/auth/login", [this](const httplib::Request& req, httplib::Response& res) {
  546. handleLogin(req, res);
  547. });
  548. // Logout endpoint
  549. m_httpServer->Post("/api/auth/logout", [this](const httplib::Request& req, httplib::Response& res) {
  550. handleLogout(req, res);
  551. });
  552. // Token validation endpoint
  553. m_httpServer->Get("/api/auth/validate", [this](const httplib::Request& req, httplib::Response& res) {
  554. handleValidateToken(req, res);
  555. });
  556. // Refresh token endpoint
  557. m_httpServer->Post("/api/auth/refresh", [this](const httplib::Request& req, httplib::Response& res) {
  558. handleRefreshToken(req, res);
  559. });
  560. // Get current user endpoint
  561. m_httpServer->Get("/api/auth/me", [this](const httplib::Request& req, httplib::Response& res) {
  562. handleGetCurrentUser(req, res);
  563. });
  564. }
  565. void Server::handleLogin(const httplib::Request& req, httplib::Response& res) {
  566. std::string requestId = generateRequestId();
  567. try {
  568. if (!m_userManager || !m_authMiddleware) {
  569. sendErrorResponse(res, "Authentication system not available", 500, "AUTH_UNAVAILABLE", requestId);
  570. return;
  571. }
  572. // Parse request body
  573. nlohmann::json requestJson;
  574. try {
  575. requestJson = nlohmann::json::parse(req.body);
  576. } catch (const nlohmann::json::parse_error& e) {
  577. sendErrorResponse(res, std::string("Invalid JSON: ") + e.what(), 400, "JSON_PARSE_ERROR", requestId);
  578. return;
  579. }
  580. // Check if using Unix authentication
  581. if (m_authMiddleware->getConfig().authMethod == AuthMethod::UNIX) {
  582. // For Unix auth, get username and password from request body
  583. std::string username = requestJson.value("username", "");
  584. std::string password = requestJson.value("password", "");
  585. if (username.empty()) {
  586. sendErrorResponse(res, "Missing username", 400, "MISSING_USERNAME", requestId);
  587. return;
  588. }
  589. // Check if PAM is enabled - if so, password is required
  590. if (m_userManager->isPamAuthEnabled() && password.empty()) {
  591. sendErrorResponse(res, "Password is required for Unix authentication", 400, "MISSING_PASSWORD", requestId);
  592. return;
  593. }
  594. // Authenticate Unix user (with or without password depending on PAM)
  595. auto result = m_userManager->authenticateUnix(username, password);
  596. if (!result.success) {
  597. sendErrorResponse(res, result.errorMessage, 401, "UNIX_AUTH_FAILED", requestId);
  598. return;
  599. }
  600. // Generate simple token for Unix auth
  601. std::string token = "unix_token_" + std::to_string(std::chrono::duration_cast<std::chrono::seconds>(
  602. std::chrono::system_clock::now().time_since_epoch()).count()) + "_" + username;
  603. nlohmann::json response = {
  604. {"token", token},
  605. {"user", {
  606. {"id", result.userId},
  607. {"username", result.username},
  608. {"role", result.role},
  609. {"permissions", result.permissions}
  610. }},
  611. {"message", "Unix authentication successful"}
  612. };
  613. sendJsonResponse(res, response);
  614. return;
  615. }
  616. // For non-Unix auth, validate required fields
  617. if (!requestJson.contains("username") || !requestJson.contains("password")) {
  618. sendErrorResponse(res, "Missing username or password", 400, "MISSING_CREDENTIALS", requestId);
  619. return;
  620. }
  621. std::string username = requestJson["username"];
  622. std::string password = requestJson["password"];
  623. // Authenticate user
  624. auto result = m_userManager->authenticateUser(username, password);
  625. if (!result.success) {
  626. sendErrorResponse(res, result.errorMessage, 401, "INVALID_CREDENTIALS", requestId);
  627. return;
  628. }
  629. // Generate JWT token if using JWT auth
  630. std::string token;
  631. if (m_authMiddleware->getConfig().authMethod == AuthMethod::JWT) {
  632. // For now, create a simple token (in a real implementation, use JWT)
  633. token = "token_" + std::to_string(std::chrono::duration_cast<std::chrono::seconds>(
  634. std::chrono::system_clock::now().time_since_epoch()).count()) + "_" + username;
  635. }
  636. nlohmann::json response = {
  637. {"token", token},
  638. {"user", {
  639. {"id", result.userId},
  640. {"username", result.username},
  641. {"role", result.role},
  642. {"permissions", result.permissions}
  643. }},
  644. {"message", "Login successful"}
  645. };
  646. sendJsonResponse(res, response);
  647. } catch (const std::exception& e) {
  648. sendErrorResponse(res, std::string("Login failed: ") + e.what(), 500, "LOGIN_ERROR", requestId);
  649. }
  650. }
  651. void Server::handleLogout(const httplib::Request& /*req*/, httplib::Response& res) {
  652. std::string requestId = generateRequestId();
  653. try {
  654. // For now, just return success (in a real implementation, invalidate the token)
  655. nlohmann::json response = {
  656. {"message", "Logout successful"}
  657. };
  658. sendJsonResponse(res, response);
  659. } catch (const std::exception& e) {
  660. sendErrorResponse(res, std::string("Logout failed: ") + e.what(), 500, "LOGOUT_ERROR", requestId);
  661. }
  662. }
  663. void Server::handleValidateToken(const httplib::Request& req, httplib::Response& res) {
  664. std::string requestId = generateRequestId();
  665. try {
  666. if (!m_userManager || !m_authMiddleware) {
  667. sendErrorResponse(res, "Authentication system not available", 500, "AUTH_UNAVAILABLE", requestId);
  668. return;
  669. }
  670. // Extract token from header
  671. std::string authHeader = req.get_header_value("Authorization");
  672. if (authHeader.empty()) {
  673. sendErrorResponse(res, "Missing authorization token", 401, "MISSING_TOKEN", requestId);
  674. return;
  675. }
  676. // Simple token validation (in a real implementation, validate JWT)
  677. // For now, just check if it starts with "token_"
  678. if (authHeader.find("Bearer ") != 0) {
  679. sendErrorResponse(res, "Invalid authorization header format", 401, "INVALID_HEADER", requestId);
  680. return;
  681. }
  682. std::string token = authHeader.substr(7); // Remove "Bearer "
  683. if (token.find("token_") != 0) {
  684. sendErrorResponse(res, "Invalid token", 401, "INVALID_TOKEN", requestId);
  685. return;
  686. }
  687. // Extract username from token (simple format: token_timestamp_username)
  688. size_t last_underscore = token.find_last_of('_');
  689. if (last_underscore == std::string::npos) {
  690. sendErrorResponse(res, "Invalid token format", 401, "INVALID_TOKEN", requestId);
  691. return;
  692. }
  693. std::string username = token.substr(last_underscore + 1);
  694. // Get user info
  695. auto userInfo = m_userManager->getUserInfoByUsername(username);
  696. if (userInfo.id.empty()) {
  697. sendErrorResponse(res, "User not found", 401, "USER_NOT_FOUND", requestId);
  698. return;
  699. }
  700. nlohmann::json response = {
  701. {"user", {
  702. {"id", userInfo.id},
  703. {"username", userInfo.username},
  704. {"role", userInfo.role},
  705. {"permissions", userInfo.permissions}
  706. }},
  707. {"valid", true}
  708. };
  709. sendJsonResponse(res, response);
  710. } catch (const std::exception& e) {
  711. sendErrorResponse(res, std::string("Token validation failed: ") + e.what(), 500, "VALIDATION_ERROR", requestId);
  712. }
  713. }
  714. void Server::handleRefreshToken(const httplib::Request& /*req*/, httplib::Response& res) {
  715. std::string requestId = generateRequestId();
  716. try {
  717. // For now, just return a new token (in a real implementation, refresh JWT)
  718. nlohmann::json response = {
  719. {"token", "new_token_" + std::to_string(std::chrono::duration_cast<std::chrono::seconds>(
  720. std::chrono::system_clock::now().time_since_epoch()).count())},
  721. {"message", "Token refreshed successfully"}
  722. };
  723. sendJsonResponse(res, response);
  724. } catch (const std::exception& e) {
  725. sendErrorResponse(res, std::string("Token refresh failed: ") + e.what(), 500, "REFRESH_ERROR", requestId);
  726. }
  727. }
  728. void Server::handleGetCurrentUser(const httplib::Request& req, httplib::Response& res) {
  729. std::string requestId = generateRequestId();
  730. try {
  731. if (!m_userManager || !m_authMiddleware) {
  732. sendErrorResponse(res, "Authentication system not available", 500, "AUTH_UNAVAILABLE", requestId);
  733. return;
  734. }
  735. // Authenticate the request
  736. AuthContext authContext = m_authMiddleware->authenticate(req, res);
  737. if (!authContext.authenticated) {
  738. sendErrorResponse(res, "Authentication required", 401, "AUTH_REQUIRED", requestId);
  739. return;
  740. }
  741. nlohmann::json response = {
  742. {"user", {
  743. {"id", authContext.userId},
  744. {"username", authContext.username},
  745. {"role", authContext.role},
  746. {"permissions", authContext.permissions}
  747. }}
  748. };
  749. sendJsonResponse(res, response);
  750. } catch (const std::exception& e) {
  751. sendErrorResponse(res, std::string("Get current user failed: ") + e.what(), 500, "USER_ERROR", requestId);
  752. }
  753. }
  754. void Server::setupCORS() {
  755. // Use post-routing handler to set CORS headers after the response is generated
  756. // This ensures we don't duplicate headers that may be set by other handlers
  757. m_httpServer->set_post_routing_handler([](const httplib::Request& /*req*/, httplib::Response& res) {
  758. // Only add CORS headers if they haven't been set already
  759. if (!res.has_header("Access-Control-Allow-Origin")) {
  760. res.set_header("Access-Control-Allow-Origin", "*");
  761. }
  762. if (!res.has_header("Access-Control-Allow-Methods")) {
  763. res.set_header("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS");
  764. }
  765. if (!res.has_header("Access-Control-Allow-Headers")) {
  766. res.set_header("Access-Control-Allow-Headers", "Content-Type, Authorization");
  767. }
  768. });
  769. // Handle OPTIONS requests for CORS preflight (API endpoints only)
  770. m_httpServer->Options("/api/.*", [](const httplib::Request&, httplib::Response& res) {
  771. res.set_header("Access-Control-Allow-Origin", "*");
  772. res.set_header("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS");
  773. res.set_header("Access-Control-Allow-Headers", "Content-Type, Authorization");
  774. res.status = 200;
  775. });
  776. }
  777. void Server::handleHealthCheck(const httplib::Request& /*req*/, httplib::Response& res) {
  778. try {
  779. nlohmann::json response = {
  780. {"status", "healthy"},
  781. {"timestamp", std::chrono::duration_cast<std::chrono::seconds>(
  782. std::chrono::system_clock::now().time_since_epoch()).count()},
  783. {"version", sd_rest::VERSION_INFO.version_full}
  784. };
  785. sendJsonResponse(res, response);
  786. } catch (const std::exception& e) {
  787. sendErrorResponse(res, std::string("Health check failed: ") + e.what(), 500);
  788. }
  789. }
  790. void Server::handleApiStatus(const httplib::Request& /*req*/, httplib::Response& res) {
  791. try {
  792. nlohmann::json response = {
  793. {"server", {
  794. {"running", m_isRunning.load()},
  795. {"host", m_host},
  796. {"port", m_port}
  797. }},
  798. {"generation_queue", {
  799. {"running", m_generationQueue ? m_generationQueue->isRunning() : false},
  800. {"queue_size", m_generationQueue ? m_generationQueue->getQueueSize() : 0},
  801. {"active_generations", m_generationQueue ? m_generationQueue->getActiveGenerations() : 0}
  802. }},
  803. {"models", {
  804. {"loaded_count", m_modelManager ? m_modelManager->getLoadedModelsCount() : 0},
  805. {"available_count", m_modelManager ? m_modelManager->getAvailableModelsCount() : 0}
  806. }}
  807. };
  808. sendJsonResponse(res, response);
  809. } catch (const std::exception& e) {
  810. sendErrorResponse(res, std::string("Status check failed: ") + e.what(), 500);
  811. }
  812. }
  813. void Server::handleVersion(const httplib::Request& /*req*/, httplib::Response& res) {
  814. try {
  815. nlohmann::json response = {
  816. {"version", sd_rest::VERSION_INFO.version_full},
  817. {"type", sd_rest::VERSION_INFO.version_type},
  818. {"commit", {
  819. {"short", sd_rest::VERSION_INFO.commit_short},
  820. {"full", sd_rest::VERSION_INFO.commit_full}
  821. }},
  822. {"branch", sd_rest::VERSION_INFO.branch},
  823. {"clean", sd_rest::VERSION_INFO.is_clean},
  824. {"build_time", sd_rest::VERSION_INFO.build_time}
  825. };
  826. sendJsonResponse(res, response);
  827. } catch (const std::exception& e) {
  828. sendErrorResponse(res, std::string("Version check failed: ") + e.what(), 500);
  829. }
  830. }
  831. // Helper function to convert ModelDetails vector to JSON array
  832. nlohmann::json Server::modelDetailsToJson(const std::vector<ModelManager::ModelDetails>& modelDetails) {
  833. nlohmann::json jsonArray = nlohmann::json::array();
  834. for (const auto& detail : modelDetails) {
  835. nlohmann::json modelJson = {
  836. {"name", detail.name},
  837. {"exists", detail.exists},
  838. {"type", detail.type},
  839. {"file_size", detail.file_size}
  840. };
  841. // Handle path and sha256 separately to avoid type mismatch
  842. if (detail.exists) {
  843. modelJson["path"] = detail.path;
  844. modelJson["sha256"] = detail.sha256;
  845. } else {
  846. modelJson["path"] = nullptr;
  847. modelJson["sha256"] = "";
  848. }
  849. // Add conditional fields for required/recommended models
  850. if (detail.is_required) {
  851. modelJson["is_required"] = true;
  852. }
  853. if (detail.is_recommended) {
  854. modelJson["is_recommended"] = true;
  855. }
  856. jsonArray.push_back(modelJson);
  857. }
  858. return jsonArray;
  859. }
  860. // Helper function to determine which recommended fields to include based on architecture
  861. std::map<std::string, bool> Server::getRecommendedModelFields(const std::string& architecture) {
  862. std::map<std::string, bool> recommendedFields;
  863. // Initialize all fields as false (will be set to null if not applicable)
  864. recommendedFields["recommended_vae"] = false;
  865. recommendedFields["recommended_clip_l"] = false;
  866. recommendedFields["recommended_clip_g"] = false;
  867. recommendedFields["recommended_t5xxl"] = false;
  868. recommendedFields["recommended_clip_vision"] = false;
  869. recommendedFields["recommended_qwen2vl"] = false;
  870. // Architecture-specific field inclusion based on actual architecture strings
  871. if (architecture.find("Stable Diffusion 1.5") != std::string::npos) {
  872. // SD 1.x: recommended_vae only
  873. recommendedFields["recommended_vae"] = true;
  874. } else if (architecture.find("Stable Diffusion XL") != std::string::npos) {
  875. // SDXL: recommended_vae only
  876. recommendedFields["recommended_vae"] = true;
  877. } else if (architecture.find("Modern Architecture") != std::string::npos ||
  878. architecture.find("Flux Dev") != std::string::npos ||
  879. architecture.find("Flux Chroma") != std::string::npos) {
  880. // FLUX/SD3/Modern Architecture: recommended_vae, recommended_clip_l, recommended_t5xxl
  881. recommendedFields["recommended_vae"] = true;
  882. recommendedFields["recommended_clip_l"] = true;
  883. recommendedFields["recommended_t5xxl"] = true;
  884. } else if (architecture.find("SD 3") != std::string::npos) {
  885. // SD3: recommended_vae, recommended_clip_l, recommended_clip_g, recommended_t5xxl
  886. recommendedFields["recommended_vae"] = true;
  887. recommendedFields["recommended_clip_l"] = true;
  888. recommendedFields["recommended_clip_g"] = true;
  889. recommendedFields["recommended_t5xxl"] = true;
  890. } else if (architecture.find("Wan") != std::string::npos) {
  891. // Wan models: recommended_vae, recommended_t5xxl, recommended_clip_vision
  892. recommendedFields["recommended_vae"] = true;
  893. recommendedFields["recommended_t5xxl"] = true;
  894. recommendedFields["recommended_clip_vision"] = true;
  895. } else if (architecture.find("Qwen") != std::string::npos) {
  896. // Qwen models: recommended_vae, recommended_qwen2vl
  897. recommendedFields["recommended_vae"] = true;
  898. recommendedFields["recommended_qwen2vl"] = true;
  899. }
  900. // For UNKNOWN architecture, keep all fields false
  901. return recommendedFields;
  902. }
  903. // Helper function to populate recommended models with existence information
  904. void Server::populateRecommendedModels(nlohmann::json& response, const ModelManager::ModelInfo& modelInfo) {
  905. if (modelInfo.requiredModels.empty()) {
  906. return;
  907. }
  908. // Check existence of required models
  909. auto requiredModelsDetails = m_modelManager->checkRequiredModelsExistence(modelInfo.requiredModels);
  910. // Get the recommended fields for this architecture
  911. auto recommendedFields = getRecommendedModelFields(modelInfo.architecture);
  912. // Group models by type
  913. std::map<std::string, std::vector<ModelManager::ModelDetails>> modelsByType;
  914. for (const auto& detail : requiredModelsDetails) {
  915. modelsByType[detail.type].push_back(detail);
  916. }
  917. // Populate recommended fields based on model types and architecture requirements
  918. for (const auto& [type, models] : modelsByType) {
  919. if (type == "VAE" && recommendedFields["recommended_vae"]) {
  920. response["recommended_vae"] = modelDetailsToJson(models);
  921. } else if (type == "CLIP-L" && recommendedFields["recommended_clip_l"]) {
  922. response["recommended_clip_l"] = modelDetailsToJson(models);
  923. } else if (type == "CLIP-G" && recommendedFields["recommended_clip_g"]) {
  924. response["recommended_clip_g"] = modelDetailsToJson(models);
  925. } else if (type == "T5XXL" && recommendedFields["recommended_t5xxl"]) {
  926. response["recommended_t5xxl"] = modelDetailsToJson(models);
  927. } else if (type == "CLIP-Vision" && recommendedFields["recommended_clip_vision"]) {
  928. response["recommended_clip_vision"] = modelDetailsToJson(models);
  929. } else if (type == "Qwen2VL" && recommendedFields["recommended_qwen2vl"]) {
  930. response["recommended_qwen2vl"] = modelDetailsToJson(models);
  931. }
  932. }
  933. // Set non-applicable fields to null
  934. for (const auto& [fieldName, shouldInclude] : recommendedFields) {
  935. if (!shouldInclude || !response.contains(fieldName)) {
  936. response[fieldName] = nlohmann::json(nullptr);
  937. }
  938. }
  939. }
  940. void Server::handleModelsList(const httplib::Request& req, httplib::Response& res) {
  941. std::string requestId = generateRequestId();
  942. try {
  943. if (!m_modelManager) {
  944. sendErrorResponse(res, "Model manager not available", 500, "MODEL_MANAGER_UNAVAILABLE", requestId);
  945. return;
  946. }
  947. // Parse query parameters for enhanced filtering
  948. std::string typeFilter = req.get_param_value("type");
  949. std::string searchQuery = req.get_param_value("search");
  950. std::string sortBy = req.get_param_value("sort_by");
  951. std::string sortOrder = req.get_param_value("sort_order");
  952. std::string dateFilter = req.get_param_value("date");
  953. std::string sizeFilter = req.get_param_value("size");
  954. // Pagination parameters - only apply if limit is explicitly provided
  955. int page = 1;
  956. int limit = 50;
  957. bool usePagination = false;
  958. try {
  959. if (!req.get_param_value("limit").empty()) {
  960. limit = std::stoi(req.get_param_value("limit"));
  961. // Special case: limit<=0 means return all models (no pagination)
  962. if (limit <= 0) {
  963. usePagination = false;
  964. limit = INT_MAX; // Set to very large number to effectively disable pagination
  965. } else {
  966. usePagination = true;
  967. if (!req.get_param_value("page").empty()) {
  968. page = std::stoi(req.get_param_value("page"));
  969. if (page < 1) page = 1;
  970. }
  971. }
  972. }
  973. } catch (const std::exception& e) {
  974. sendErrorResponse(res, "Invalid pagination parameters", 400, "INVALID_PAGINATION", requestId);
  975. return;
  976. }
  977. // Filter parameters
  978. bool includeLoaded = req.get_param_value("loaded") == "true";
  979. bool includeUnloaded = req.get_param_value("unloaded") == "true";
  980. (void)req.get_param_value("include_metadata"); // unused but kept for API compatibility
  981. (void)req.get_param_value("include_thumbnails"); // unused but kept for API compatibility
  982. // Get all models
  983. auto allModels = m_modelManager->getAllModels();
  984. nlohmann::json models = nlohmann::json::array();
  985. // Apply filters and build response
  986. for (const auto& pair : allModels) {
  987. const auto& modelInfo = pair.second;
  988. // Apply type filter
  989. if (!typeFilter.empty()) {
  990. ModelType filterType = ModelManager::stringToModelType(typeFilter);
  991. if (modelInfo.type != filterType) continue;
  992. }
  993. // Apply loaded/unloaded filters
  994. if (includeLoaded && !modelInfo.isLoaded) continue;
  995. if (includeUnloaded && modelInfo.isLoaded) continue;
  996. // Apply search filter (case-insensitive search in name and description)
  997. if (!searchQuery.empty()) {
  998. std::string searchLower = searchQuery;
  999. std::transform(searchLower.begin(), searchLower.end(), searchLower.begin(), ::tolower);
  1000. std::string nameLower = modelInfo.name;
  1001. std::transform(nameLower.begin(), nameLower.end(), nameLower.begin(), ::tolower);
  1002. std::string descLower = modelInfo.description;
  1003. std::transform(descLower.begin(), descLower.end(), descLower.begin(), ::tolower);
  1004. if (nameLower.find(searchLower) == std::string::npos &&
  1005. descLower.find(searchLower) == std::string::npos) {
  1006. continue;
  1007. }
  1008. }
  1009. // Apply date filter (simplified - expects "recent", "old", or YYYY-MM-DD)
  1010. if (!dateFilter.empty()) {
  1011. auto now = std::filesystem::file_time_type::clock::now();
  1012. auto modelTime = modelInfo.modifiedAt;
  1013. auto duration = std::chrono::duration_cast<std::chrono::hours>(now - modelTime).count();
  1014. if (dateFilter == "recent" && duration > 24 * 7) continue; // Older than 1 week
  1015. if (dateFilter == "old" && duration < 24 * 30) continue; // Newer than 1 month
  1016. }
  1017. // Apply size filter (expects "small", "medium", "large", or size in MB)
  1018. if (!sizeFilter.empty()) {
  1019. double sizeMB = modelInfo.fileSize / (1024.0 * 1024.0);
  1020. if (sizeFilter == "small" && sizeMB > 1024) continue; // > 1GB
  1021. if (sizeFilter == "medium" && (sizeMB < 1024 || sizeMB > 4096)) continue; // < 1GB or > 4GB
  1022. if (sizeFilter == "large" && sizeMB < 4096) continue; // < 4GB
  1023. // Try to parse as specific size in MB
  1024. try {
  1025. double maxSizeMB = std::stod(sizeFilter);
  1026. if (sizeMB > maxSizeMB) continue;
  1027. } catch (...) {
  1028. // Ignore if parsing fails
  1029. }
  1030. }
  1031. // Build model JSON with enhanced structure
  1032. nlohmann::json modelJson = {
  1033. {"name", modelInfo.name},
  1034. {"type", ModelManager::modelTypeToString(modelInfo.type)},
  1035. {"file_size", modelInfo.fileSize},
  1036. {"file_size_mb", modelInfo.fileSize / (1024.0 * 1024.0)},
  1037. {"sha256", modelInfo.sha256.empty() ? nullptr : nlohmann::json(modelInfo.sha256)},
  1038. {"sha256_short", (modelInfo.sha256.empty() || modelInfo.sha256.length() < 10) ? nullptr : nlohmann::json(modelInfo.sha256.substr(0, 10))}
  1039. };
  1040. // Add architecture information if available (checkpoints only)
  1041. if (!modelInfo.architecture.empty()) {
  1042. modelJson["architecture"] = modelInfo.architecture;
  1043. modelJson["recommended_vae"] = modelInfo.recommendedVAE.empty() ? nullptr : nlohmann::json(modelInfo.recommendedVAE);
  1044. if (modelInfo.recommendedWidth > 0) {
  1045. modelJson["recommended_width"] = modelInfo.recommendedWidth;
  1046. }
  1047. if (modelInfo.recommendedHeight > 0) {
  1048. modelJson["recommended_height"] = modelInfo.recommendedHeight;
  1049. }
  1050. if (modelInfo.recommendedSteps > 0) {
  1051. modelJson["recommended_steps"] = modelInfo.recommendedSteps;
  1052. }
  1053. if (!modelInfo.recommendedSampler.empty()) {
  1054. modelJson["recommended_sampler"] = modelInfo.recommendedSampler;
  1055. }
  1056. // Enhanced model information with existence checking
  1057. if (!modelInfo.requiredModels.empty()) {
  1058. auto requiredModelsDetails = m_modelManager->checkRequiredModelsExistence(modelInfo.requiredModels);
  1059. modelJson["required_models"] = modelDetailsToJson(requiredModelsDetails);
  1060. // Populate recommended models based on architecture
  1061. populateRecommendedModels(modelJson, modelInfo);
  1062. }
  1063. // Backward compatibility - keep existing fields
  1064. if (!modelInfo.missingModels.empty()) {
  1065. modelJson["missing_models"] = modelInfo.missingModels;
  1066. modelJson["has_missing_dependencies"] = true;
  1067. } else {
  1068. modelJson["has_missing_dependencies"] = false;
  1069. }
  1070. }
  1071. models.push_back(modelJson);
  1072. }
  1073. // Apply sorting
  1074. if (!sortBy.empty()) {
  1075. std::sort(models.begin(), models.end(), [&sortBy, &sortOrder](const nlohmann::json& a, const nlohmann::json& b) {
  1076. bool ascending = sortOrder != "desc";
  1077. if (sortBy == "name") {
  1078. return ascending ? a["name"] < b["name"] : a["name"] > b["name"];
  1079. } else if (sortBy == "size") {
  1080. return ascending ? a["file_size"] < b["file_size"] : a["file_size"] > b["file_size"];
  1081. } else if (sortBy == "date") {
  1082. return ascending ? a["last_modified"] < b["last_modified"] : a["last_modified"] > b["last_modified"];
  1083. } else if (sortBy == "type") {
  1084. return ascending ? a["type"] < b["type"] : a["type"] > b["type"];
  1085. } else if (sortBy == "loaded") {
  1086. return ascending ? a["is_loaded"] < b["is_loaded"] : a["is_loaded"] > b["is_loaded"];
  1087. }
  1088. return false;
  1089. });
  1090. }
  1091. // Apply pagination only if limit parameter was provided
  1092. int totalCount = models.size();
  1093. nlohmann::json paginatedModels = nlohmann::json::array();
  1094. nlohmann::json paginationInfo = nlohmann::json::object();
  1095. if (usePagination) {
  1096. // Apply pagination
  1097. int totalPages = (totalCount + limit - 1) / limit;
  1098. int startIndex = (page - 1) * limit;
  1099. int endIndex = std::min(startIndex + limit, totalCount);
  1100. for (int i = startIndex; i < endIndex; ++i) {
  1101. paginatedModels.push_back(models[i]);
  1102. }
  1103. paginationInfo = {
  1104. {"page", page},
  1105. {"limit", limit},
  1106. {"total_count", totalCount},
  1107. {"total_pages", totalPages},
  1108. {"has_next", page < totalPages},
  1109. {"has_prev", page > 1}
  1110. };
  1111. } else {
  1112. // Return all models without pagination
  1113. paginatedModels = models;
  1114. paginationInfo = {
  1115. {"page", 1},
  1116. {"limit", totalCount},
  1117. {"total_count", totalCount},
  1118. {"total_pages", 1},
  1119. {"has_next", false},
  1120. {"has_prev", false}
  1121. };
  1122. }
  1123. // Build comprehensive response
  1124. nlohmann::json response = {
  1125. {"models", paginatedModels},
  1126. {"pagination", paginationInfo},
  1127. {"filters_applied", {
  1128. {"type", typeFilter.empty() ? nlohmann::json(nullptr) : nlohmann::json(typeFilter)},
  1129. {"search", searchQuery.empty() ? nlohmann::json(nullptr) : nlohmann::json(searchQuery)},
  1130. {"date", dateFilter.empty() ? nlohmann::json(nullptr) : nlohmann::json(dateFilter)},
  1131. {"size", sizeFilter.empty() ? nlohmann::json(nullptr) : nlohmann::json(sizeFilter)},
  1132. {"loaded", includeLoaded ? nlohmann::json(true) : nlohmann::json(nullptr)},
  1133. {"unloaded", includeUnloaded ? nlohmann::json(true) : nlohmann::json(nullptr)}
  1134. }},
  1135. {"sorting", {
  1136. {"sort_by", sortBy.empty() ? "name" : nlohmann::json(sortBy)},
  1137. {"sort_order", sortOrder.empty() ? "asc" : nlohmann::json(sortOrder)}
  1138. }},
  1139. {"statistics", {
  1140. {"loaded_count", m_modelManager->getLoadedModelsCount()},
  1141. {"available_count", m_modelManager->getAvailableModelsCount()}
  1142. }},
  1143. {"request_id", requestId}
  1144. };
  1145. sendJsonResponse(res, response);
  1146. } catch (const std::exception& e) {
  1147. sendErrorResponse(res, std::string("Failed to list models: ") + e.what(), 500, "MODEL_LIST_ERROR", requestId);
  1148. }
  1149. }
  1150. void Server::handleQueueStatus(const httplib::Request& /*req*/, httplib::Response& res) {
  1151. try {
  1152. if (!m_generationQueue) {
  1153. sendErrorResponse(res, "Generation queue not available", 500);
  1154. return;
  1155. }
  1156. // Get detailed queue status
  1157. auto jobs = m_generationQueue->getQueueStatus();
  1158. // Convert jobs to JSON
  1159. nlohmann::json jobsJson = nlohmann::json::array();
  1160. for (const auto& job : jobs) {
  1161. std::string statusStr;
  1162. switch (job.status) {
  1163. case GenerationStatus::QUEUED: statusStr = "queued"; break;
  1164. case GenerationStatus::PROCESSING: statusStr = "processing"; break;
  1165. case GenerationStatus::COMPLETED: statusStr = "completed"; break;
  1166. case GenerationStatus::FAILED: statusStr = "failed"; break;
  1167. }
  1168. // Convert time points to timestamps
  1169. auto queuedTime = std::chrono::duration_cast<std::chrono::milliseconds>(
  1170. job.queuedTime.time_since_epoch()).count();
  1171. auto startTime = std::chrono::duration_cast<std::chrono::milliseconds>(
  1172. job.startTime.time_since_epoch()).count();
  1173. auto endTime = std::chrono::duration_cast<std::chrono::milliseconds>(
  1174. job.endTime.time_since_epoch()).count();
  1175. jobsJson.push_back({
  1176. {"id", job.id},
  1177. {"status", statusStr},
  1178. {"prompt", job.prompt},
  1179. {"queued_time", queuedTime},
  1180. {"start_time", startTime > 0 ? nlohmann::json(startTime) : nlohmann::json(nullptr)},
  1181. {"end_time", endTime > 0 ? nlohmann::json(endTime) : nlohmann::json(nullptr)},
  1182. {"position", job.position},
  1183. {"progress", job.progress}
  1184. });
  1185. }
  1186. nlohmann::json response = {
  1187. {"queue", {
  1188. {"size", m_generationQueue->getQueueSize()},
  1189. {"active_generations", m_generationQueue->getActiveGenerations()},
  1190. {"running", m_generationQueue->isRunning()},
  1191. {"jobs", jobsJson}
  1192. }}
  1193. };
  1194. sendJsonResponse(res, response);
  1195. } catch (const std::exception& e) {
  1196. sendErrorResponse(res, std::string("Queue status check failed: ") + e.what(), 500);
  1197. }
  1198. }
  1199. void Server::handleJobStatus(const httplib::Request& req, httplib::Response& res) {
  1200. try {
  1201. if (!m_generationQueue) {
  1202. sendErrorResponse(res, "Generation queue not available", 500);
  1203. return;
  1204. }
  1205. // Extract job ID from URL path
  1206. std::string jobId = req.matches[1].str();
  1207. if (jobId.empty()) {
  1208. sendErrorResponse(res, "Missing job ID", 400);
  1209. return;
  1210. }
  1211. // Get job information
  1212. auto jobInfo = m_generationQueue->getJobInfo(jobId);
  1213. if (jobInfo.id.empty()) {
  1214. sendErrorResponse(res, "Job not found", 404);
  1215. return;
  1216. }
  1217. // Convert status to string
  1218. std::string statusStr;
  1219. switch (jobInfo.status) {
  1220. case GenerationStatus::QUEUED: statusStr = "queued"; break;
  1221. case GenerationStatus::PROCESSING: statusStr = "processing"; break;
  1222. case GenerationStatus::COMPLETED: statusStr = "completed"; break;
  1223. case GenerationStatus::FAILED: statusStr = "failed"; break;
  1224. }
  1225. // Convert time points to timestamps
  1226. auto queuedTime = std::chrono::duration_cast<std::chrono::milliseconds>(
  1227. jobInfo.queuedTime.time_since_epoch()).count();
  1228. auto startTime = std::chrono::duration_cast<std::chrono::milliseconds>(
  1229. jobInfo.startTime.time_since_epoch()).count();
  1230. auto endTime = std::chrono::duration_cast<std::chrono::milliseconds>(
  1231. jobInfo.endTime.time_since_epoch()).count();
  1232. // Create download URLs for output files
  1233. nlohmann::json outputUrls = nlohmann::json::array();
  1234. for (const auto& filePath : jobInfo.outputFiles) {
  1235. // Extract filename from full path
  1236. std::filesystem::path p(filePath);
  1237. std::string filename = p.filename().string();
  1238. // Create download URL
  1239. std::string url = "/api/queue/job/" + jobInfo.id + "/output/" + filename;
  1240. nlohmann::json fileInfo = {
  1241. {"filename", filename},
  1242. {"url", url},
  1243. {"path", filePath}
  1244. };
  1245. outputUrls.push_back(fileInfo);
  1246. }
  1247. nlohmann::json response = {
  1248. {"job", {
  1249. {"id", jobInfo.id},
  1250. {"status", statusStr},
  1251. {"prompt", jobInfo.prompt},
  1252. {"queued_time", queuedTime},
  1253. {"start_time", startTime > 0 ? nlohmann::json(startTime) : nlohmann::json(nullptr)},
  1254. {"end_time", endTime > 0 ? nlohmann::json(endTime) : nlohmann::json(nullptr)},
  1255. {"position", jobInfo.position},
  1256. {"outputs", outputUrls},
  1257. {"error_message", jobInfo.errorMessage},
  1258. {"progress", jobInfo.progress}
  1259. }}
  1260. };
  1261. sendJsonResponse(res, response);
  1262. } catch (const std::exception& e) {
  1263. sendErrorResponse(res, std::string("Job status check failed: ") + e.what(), 500);
  1264. }
  1265. }
  1266. void Server::handleCancelJob(const httplib::Request& req, httplib::Response& res) {
  1267. try {
  1268. if (!m_generationQueue) {
  1269. sendErrorResponse(res, "Generation queue not available", 500);
  1270. return;
  1271. }
  1272. // Parse JSON request body
  1273. nlohmann::json requestJson = nlohmann::json::parse(req.body);
  1274. // Validate required fields
  1275. if (!requestJson.contains("job_id") || !requestJson["job_id"].is_string()) {
  1276. sendErrorResponse(res, "Missing or invalid 'job_id' field", 400);
  1277. return;
  1278. }
  1279. std::string jobId = requestJson["job_id"];
  1280. // Try to cancel the job
  1281. bool cancelled = m_generationQueue->cancelJob(jobId);
  1282. if (cancelled) {
  1283. nlohmann::json response = {
  1284. {"status", "success"},
  1285. {"message", "Job cancelled successfully"},
  1286. {"job_id", jobId}
  1287. };
  1288. sendJsonResponse(res, response);
  1289. } else {
  1290. nlohmann::json response = {
  1291. {"status", "error"},
  1292. {"message", "Job not found or already processing"},
  1293. {"job_id", jobId}
  1294. };
  1295. sendJsonResponse(res, response, 404);
  1296. }
  1297. } catch (const nlohmann::json::parse_error& e) {
  1298. sendErrorResponse(res, std::string("Invalid JSON: ") + e.what(), 400);
  1299. } catch (const std::exception& e) {
  1300. sendErrorResponse(res, std::string("Job cancellation failed: ") + e.what(), 500);
  1301. }
  1302. }
  1303. void Server::handleClearQueue(const httplib::Request& /*req*/, httplib::Response& res) {
  1304. try {
  1305. if (!m_generationQueue) {
  1306. sendErrorResponse(res, "Generation queue not available", 500);
  1307. return;
  1308. }
  1309. // Clear the queue
  1310. m_generationQueue->clearQueue();
  1311. nlohmann::json response = {
  1312. {"status", "success"},
  1313. {"message", "Queue cleared successfully"}
  1314. };
  1315. sendJsonResponse(res, response);
  1316. } catch (const std::exception& e) {
  1317. sendErrorResponse(res, std::string("Queue clear failed: ") + e.what(), 500);
  1318. }
  1319. }
  1320. void Server::handleDownloadOutput(const httplib::Request& req, httplib::Response& res) {
  1321. try {
  1322. // Extract job ID and filename from URL path
  1323. if (req.matches.size() < 3) {
  1324. sendErrorResponse(res, "Invalid request: job ID and filename required", 400, "INVALID_REQUEST", "");
  1325. return;
  1326. }
  1327. std::string jobId = req.matches[1];
  1328. std::string filename = req.matches[2];
  1329. // Validate inputs
  1330. if (jobId.empty() || filename.empty()) {
  1331. sendErrorResponse(res, "Job ID and filename cannot be empty", 400, "INVALID_PARAMETERS", "");
  1332. return;
  1333. }
  1334. // Construct absolute file path using the same logic as when saving:
  1335. // {outputDir}/{jobId}/{filename}
  1336. std::string fullPath = std::filesystem::absolute(m_outputDir + "/" + jobId + "/" + filename).string();
  1337. // Log the request for debugging
  1338. std::cout << "Image download request: jobId=" << jobId << ", filename=" << filename
  1339. << ", fullPath=" << fullPath << std::endl;
  1340. // Check if file exists
  1341. if (!std::filesystem::exists(fullPath)) {
  1342. std::cerr << "Output file not found: " << fullPath << std::endl;
  1343. sendErrorResponse(res, "Output file not found: " + filename, 404, "FILE_NOT_FOUND", "");
  1344. return;
  1345. }
  1346. // Check file size to detect zero-byte files
  1347. auto fileSize = std::filesystem::file_size(fullPath);
  1348. if (fileSize == 0) {
  1349. std::cerr << "Output file is zero bytes: " << fullPath << std::endl;
  1350. sendErrorResponse(res, "Output file is empty (corrupted generation)", 500, "EMPTY_FILE", "");
  1351. return;
  1352. }
  1353. // Check if file is accessible
  1354. std::ifstream file(fullPath, std::ios::binary);
  1355. if (!file.is_open()) {
  1356. std::cerr << "Failed to open output file: " << fullPath << std::endl;
  1357. sendErrorResponse(res, "Output file not accessible", 500, "FILE_ACCESS_ERROR", "");
  1358. return;
  1359. }
  1360. // Read file contents
  1361. std::string fileContent;
  1362. try {
  1363. fileContent = std::string(
  1364. std::istreambuf_iterator<char>(file),
  1365. std::istreambuf_iterator<char>()
  1366. );
  1367. file.close();
  1368. } catch (const std::exception& e) {
  1369. std::cerr << "Failed to read file content: " << e.what() << std::endl;
  1370. sendErrorResponse(res, "Failed to read file content", 500, "FILE_READ_ERROR", "");
  1371. return;
  1372. }
  1373. // Verify we actually read data
  1374. if (fileContent.empty()) {
  1375. std::cerr << "File content is empty after read: " << fullPath << std::endl;
  1376. sendErrorResponse(res, "File content is empty after read", 500, "EMPTY_CONTENT", "");
  1377. return;
  1378. }
  1379. // Determine content type based on file extension
  1380. std::string contentType = "application/octet-stream";
  1381. if (Utils::endsWith(filename, ".png")) {
  1382. contentType = "image/png";
  1383. } else if (Utils::endsWith(filename, ".jpg") || Utils::endsWith(filename, ".jpeg")) {
  1384. contentType = "image/jpeg";
  1385. } else if (Utils::endsWith(filename, ".mp4")) {
  1386. contentType = "video/mp4";
  1387. } else if (Utils::endsWith(filename, ".gif")) {
  1388. contentType = "image/gif";
  1389. } else if (Utils::endsWith(filename, ".webp")) {
  1390. contentType = "image/webp";
  1391. }
  1392. // Set response headers for proper browser handling
  1393. res.set_header("Content-Type", contentType);
  1394. res.set_header("Content-Length", std::to_string(fileContent.length()));
  1395. res.set_header("Cache-Control", "public, max-age=3600"); // Cache for 1 hour
  1396. res.set_header("Access-Control-Allow-Origin", "*"); // CORS for image access
  1397. // Uncomment if you want to force download instead of inline display:
  1398. // res.set_header("Content-Disposition", "attachment; filename=\"" + filename + "\"");
  1399. // Set the content
  1400. res.set_content(fileContent, contentType);
  1401. res.status = 200;
  1402. std::cout << "Successfully served image: " << filename << " (" << fileContent.length() << " bytes)" << std::endl;
  1403. } catch (const std::exception& e) {
  1404. std::cerr << "Exception in handleDownloadOutput: " << e.what() << std::endl;
  1405. sendErrorResponse(res, std::string("Failed to download file: ") + e.what(), 500, "DOWNLOAD_ERROR", "");
  1406. }
  1407. }
  1408. void Server::handleJobOutput(const httplib::Request& req, httplib::Response& res) {
  1409. std::string requestId = generateRequestId();
  1410. try {
  1411. // Extract job ID from URL path
  1412. if (req.matches.size() < 2) {
  1413. sendErrorResponse(res, "Invalid request: job ID required", 400, "INVALID_REQUEST", requestId);
  1414. return;
  1415. }
  1416. std::string jobId = req.matches[1].str();
  1417. // Validate job ID
  1418. if (jobId.empty()) {
  1419. sendErrorResponse(res, "Job ID cannot be empty", 400, "INVALID_PARAMETERS", requestId);
  1420. return;
  1421. }
  1422. // Log the request for debugging
  1423. std::cout << "Job output request: jobId=" << jobId << std::endl;
  1424. // Get job information to check if it exists and is completed
  1425. if (!m_generationQueue) {
  1426. sendErrorResponse(res, "Generation queue not available", 500, "QUEUE_UNAVAILABLE", requestId);
  1427. return;
  1428. }
  1429. auto jobInfo = m_generationQueue->getJobInfo(jobId);
  1430. if (jobInfo.id.empty()) {
  1431. sendErrorResponse(res, "Job not found", 404, "JOB_NOT_FOUND", requestId);
  1432. return;
  1433. }
  1434. // Check if job is completed
  1435. if (jobInfo.status != GenerationStatus::COMPLETED) {
  1436. std::string statusStr;
  1437. switch (jobInfo.status) {
  1438. case GenerationStatus::QUEUED: statusStr = "queued"; break;
  1439. case GenerationStatus::PROCESSING: statusStr = "processing"; break;
  1440. case GenerationStatus::FAILED: statusStr = "failed"; break;
  1441. default: statusStr = "unknown"; break;
  1442. }
  1443. nlohmann::json response = {
  1444. {"error", {
  1445. {"message", "Job not completed yet"},
  1446. {"status_code", 400},
  1447. {"error_code", "JOB_NOT_COMPLETED"},
  1448. {"request_id", requestId},
  1449. {"timestamp", std::chrono::duration_cast<std::chrono::seconds>(
  1450. std::chrono::system_clock::now().time_since_epoch()).count()},
  1451. {"job_status", statusStr}
  1452. }}
  1453. };
  1454. sendJsonResponse(res, response, 400);
  1455. return;
  1456. }
  1457. // Check if job has output files
  1458. if (jobInfo.outputFiles.empty()) {
  1459. sendErrorResponse(res, "No output files found for completed job", 404, "NO_OUTPUT_FILES", requestId);
  1460. return;
  1461. }
  1462. // For simplicity, return the first output file
  1463. // In a more complex implementation, we could return all files or allow file selection
  1464. std::string firstOutputFile = jobInfo.outputFiles[0];
  1465. // Extract filename from full path
  1466. std::filesystem::path filePath(firstOutputFile);
  1467. std::string filename = filePath.filename().string();
  1468. // Construct absolute file path
  1469. std::string fullPath = std::filesystem::absolute(firstOutputFile).string();
  1470. // Check if file exists
  1471. if (!std::filesystem::exists(fullPath)) {
  1472. std::cerr << "Output file not found: " << fullPath << std::endl;
  1473. sendErrorResponse(res, "Output file not found: " + filename, 404, "FILE_NOT_FOUND", requestId);
  1474. return;
  1475. }
  1476. // Check file size to detect zero-byte files
  1477. auto fileSize = std::filesystem::file_size(fullPath);
  1478. if (fileSize == 0) {
  1479. std::cerr << "Output file is zero bytes: " << fullPath << std::endl;
  1480. sendErrorResponse(res, "Output file is empty (corrupted generation)", 500, "EMPTY_FILE", requestId);
  1481. return;
  1482. }
  1483. // Check if file is accessible
  1484. std::ifstream file(fullPath, std::ios::binary);
  1485. if (!file.is_open()) {
  1486. std::cerr << "Failed to open output file: " << fullPath << std::endl;
  1487. sendErrorResponse(res, "Output file not accessible", 500, "FILE_ACCESS_ERROR", requestId);
  1488. return;
  1489. }
  1490. // Read file contents
  1491. std::string fileContent;
  1492. try {
  1493. fileContent = std::string(
  1494. std::istreambuf_iterator<char>(file),
  1495. std::istreambuf_iterator<char>()
  1496. );
  1497. file.close();
  1498. } catch (const std::exception& e) {
  1499. std::cerr << "Failed to read file content: " << e.what() << std::endl;
  1500. sendErrorResponse(res, "Failed to read file content", 500, "FILE_READ_ERROR", requestId);
  1501. return;
  1502. }
  1503. // Verify we actually read data
  1504. if (fileContent.empty()) {
  1505. std::cerr << "File content is empty after read: " << fullPath << std::endl;
  1506. sendErrorResponse(res, "File content is empty after read", 500, "EMPTY_CONTENT", requestId);
  1507. return;
  1508. }
  1509. // Determine content type based on file extension
  1510. std::string contentType = "application/octet-stream";
  1511. if (Utils::endsWith(filename, ".png")) {
  1512. contentType = "image/png";
  1513. } else if (Utils::endsWith(filename, ".jpg") || Utils::endsWith(filename, ".jpeg")) {
  1514. contentType = "image/jpeg";
  1515. } else if (Utils::endsWith(filename, ".mp4")) {
  1516. contentType = "video/mp4";
  1517. } else if (Utils::endsWith(filename, ".gif")) {
  1518. contentType = "image/gif";
  1519. } else if (Utils::endsWith(filename, ".webp")) {
  1520. contentType = "image/webp";
  1521. }
  1522. // Set response headers for proper browser handling
  1523. res.set_header("Content-Type", contentType);
  1524. res.set_header("Content-Length", std::to_string(fileContent.length()));
  1525. res.set_header("Cache-Control", "public, max-age=3600"); // Cache for 1 hour
  1526. res.set_header("Access-Control-Allow-Origin", "*"); // CORS for image access
  1527. // Set additional metadata headers
  1528. res.set_header("X-Job-ID", jobId);
  1529. res.set_header("X-Filename", filename);
  1530. res.set_header("X-File-Size", std::to_string(fileSize));
  1531. // If there are multiple files, indicate this
  1532. if (jobInfo.outputFiles.size() > 1) {
  1533. res.set_header("X-Total-Files", std::to_string(jobInfo.outputFiles.size()));
  1534. res.set_header("X-File-Index", "1");
  1535. }
  1536. // Set the content
  1537. res.set_content(fileContent, contentType);
  1538. res.status = 200;
  1539. std::cout << "Successfully served job output: jobId=" << jobId
  1540. << ", filename=" << filename
  1541. << " (" << fileContent.length() << " bytes)" << std::endl;
  1542. } catch (const std::exception& e) {
  1543. std::cerr << "Exception in handleJobOutput: " << e.what() << std::endl;
  1544. sendErrorResponse(res, std::string("Failed to get job output: ") + e.what(), 500, "OUTPUT_ERROR", requestId);
  1545. }
  1546. }
  1547. void Server::handleImageResize(const httplib::Request& req, httplib::Response& res) {
  1548. std::string requestId = generateRequestId();
  1549. try {
  1550. // Parse JSON request body
  1551. nlohmann::json requestJson = nlohmann::json::parse(req.body);
  1552. // Validate required fields
  1553. if (!requestJson.contains("image") || !requestJson["image"].is_string()) {
  1554. sendErrorResponse(res, "Missing or invalid 'image' field", 400, "INVALID_PARAMETERS", requestId);
  1555. return;
  1556. }
  1557. if (!requestJson.contains("width") || !requestJson["width"].is_number_integer()) {
  1558. sendErrorResponse(res, "Missing or invalid 'width' field", 400, "INVALID_PARAMETERS", requestId);
  1559. return;
  1560. }
  1561. if (!requestJson.contains("height") || !requestJson["height"].is_number_integer()) {
  1562. sendErrorResponse(res, "Missing or invalid 'height' field", 400, "INVALID_PARAMETERS", requestId);
  1563. return;
  1564. }
  1565. std::string imageInput = requestJson["image"];
  1566. int targetWidth = requestJson["width"];
  1567. int targetHeight = requestJson["height"];
  1568. // Validate dimensions
  1569. if (targetWidth < 1 || targetWidth > 4096) {
  1570. sendErrorResponse(res, "Width must be between 1 and 4096", 400, "INVALID_DIMENSIONS", requestId);
  1571. return;
  1572. }
  1573. if (targetHeight < 1 || targetHeight > 4096) {
  1574. sendErrorResponse(res, "Height must be between 1 and 4096", 400, "INVALID_DIMENSIONS", requestId);
  1575. return;
  1576. }
  1577. // Load the source image
  1578. auto [imageData, sourceWidth, sourceHeight, sourceChannels, success, loadError] = loadImageFromInput(imageInput);
  1579. if (!success) {
  1580. sendErrorResponse(res, "Failed to load image: " + loadError, 400, "IMAGE_LOAD_ERROR", requestId);
  1581. return;
  1582. }
  1583. // Convert image data to stb_image format for processing
  1584. int channels = 3; // Force RGB
  1585. size_t sourceSize = sourceWidth * sourceHeight * channels;
  1586. std::vector<uint8_t> sourcePixels(sourceSize);
  1587. std::memcpy(sourcePixels.data(), imageData.data(), std::min(imageData.size(), sourceSize));
  1588. // Resize the image using stb_image_resize if available, otherwise use simple scaling
  1589. std::vector<uint8_t> resizedPixels(targetWidth * targetHeight * channels);
  1590. // Simple nearest-neighbor scaling for now (can be improved with better algorithms)
  1591. float xScale = static_cast<float>(sourceWidth) / targetWidth;
  1592. float yScale = static_cast<float>(sourceHeight) / targetHeight;
  1593. for (int y = 0; y < targetHeight; y++) {
  1594. for (int x = 0; x < targetWidth; x++) {
  1595. int sourceX = static_cast<int>(x * xScale);
  1596. int sourceY = static_cast<int>(y * yScale);
  1597. // Clamp to source bounds
  1598. sourceX = std::min(sourceX, sourceWidth - 1);
  1599. sourceY = std::min(sourceY, sourceHeight - 1);
  1600. for (int c = 0; c < channels; c++) {
  1601. resizedPixels[(y * targetWidth + x) * channels + c] =
  1602. sourcePixels[(sourceY * sourceWidth + sourceX) * channels + c];
  1603. }
  1604. }
  1605. }
  1606. // Convert resized image to base64
  1607. std::string base64Data = Utils::base64Encode(resizedPixels);
  1608. // Determine MIME type based on input
  1609. std::string mimeType = "image/jpeg"; // default
  1610. if (Utils::startsWith(imageInput, "data:image/png")) {
  1611. mimeType = "image/png";
  1612. } else if (Utils::startsWith(imageInput, "data:image/gif")) {
  1613. mimeType = "image/gif";
  1614. } else if (Utils::startsWith(imageInput, "data:image/webp")) {
  1615. mimeType = "image/webp";
  1616. } else if (Utils::startsWith(imageInput, "data:image/bmp")) {
  1617. mimeType = "image/bmp";
  1618. }
  1619. // Create data URL format
  1620. std::string dataUrl = "data:" + mimeType + ";base64," + base64Data;
  1621. // Build response
  1622. nlohmann::json response = {
  1623. {"success", true},
  1624. {"original_width", sourceWidth},
  1625. {"original_height", sourceHeight},
  1626. {"resized_width", targetWidth},
  1627. {"resized_height", targetHeight},
  1628. {"mime_type", mimeType},
  1629. {"base64_data", dataUrl},
  1630. {"file_size_bytes", resizedPixels.size()},
  1631. {"request_id", requestId}
  1632. };
  1633. sendJsonResponse(res, response, 200);
  1634. std::cout << "Successfully resized image from " << sourceWidth << "x" << sourceHeight
  1635. << " to " << targetWidth << "x" << targetHeight
  1636. << " (" << resizedPixels.size() << " bytes)" << std::endl;
  1637. } catch (const nlohmann::json::parse_error& e) {
  1638. sendErrorResponse(res, std::string("Invalid JSON: ") + e.what(), 400, "JSON_PARSE_ERROR", requestId);
  1639. } catch (const std::exception& e) {
  1640. std::cerr << "Exception in handleImageResize: " << e.what() << std::endl;
  1641. sendErrorResponse(res, std::string("Failed to resize image: ") + e.what(), 500, "INTERNAL_ERROR", requestId);
  1642. }
  1643. }
  1644. void Server::handleImageCrop(const httplib::Request& req, httplib::Response& res) {
  1645. std::string requestId = generateRequestId();
  1646. try {
  1647. // Parse JSON request body
  1648. nlohmann::json requestJson = nlohmann::json::parse(req.body);
  1649. // Validate required fields
  1650. if (!requestJson.contains("image") || !requestJson["image"].is_string()) {
  1651. sendErrorResponse(res, "Missing or invalid 'image' field", 400, "INVALID_PARAMETERS", requestId);
  1652. return;
  1653. }
  1654. if (!requestJson.contains("x") || !requestJson["x"].is_number_integer()) {
  1655. sendErrorResponse(res, "Missing or invalid 'x' field", 400, "INVALID_PARAMETERS", requestId);
  1656. return;
  1657. }
  1658. if (!requestJson.contains("y") || !requestJson["y"].is_number_integer()) {
  1659. sendErrorResponse(res, "Missing or invalid 'y' field", 400, "INVALID_PARAMETERS", requestId);
  1660. return;
  1661. }
  1662. if (!requestJson.contains("width") || !requestJson["width"].is_number_integer()) {
  1663. sendErrorResponse(res, "Missing or invalid 'width' field", 400, "INVALID_PARAMETERS", requestId);
  1664. return;
  1665. }
  1666. if (!requestJson.contains("height") || !requestJson["height"].is_number_integer()) {
  1667. sendErrorResponse(res, "Missing or invalid 'height' field", 400, "INVALID_PARAMETERS", requestId);
  1668. return;
  1669. }
  1670. std::string imageInput = requestJson["image"];
  1671. int cropX = requestJson["x"];
  1672. int cropY = requestJson["y"];
  1673. int cropWidth = requestJson["width"];
  1674. int cropHeight = requestJson["height"];
  1675. // Load the source image
  1676. auto [imageData, sourceWidth, sourceHeight, sourceChannels, success, loadError] = loadImageFromInput(imageInput);
  1677. if (!success) {
  1678. sendErrorResponse(res, "Failed to load image: " + loadError, 400, "IMAGE_LOAD_ERROR", requestId);
  1679. return;
  1680. }
  1681. // Validate crop dimensions
  1682. if (cropX < 0 || cropY < 0) {
  1683. sendErrorResponse(res, "Crop coordinates must be non-negative", 400, "INVALID_CROP_AREA", requestId);
  1684. return;
  1685. }
  1686. if (cropX + cropWidth > sourceWidth || cropY + cropHeight > sourceHeight) {
  1687. sendErrorResponse(res, "Crop area exceeds image dimensions", 400, "INVALID_CROP_AREA", requestId);
  1688. return;
  1689. }
  1690. if (cropWidth < 1 || cropHeight < 1) {
  1691. sendErrorResponse(res, "Crop width and height must be at least 1", 400, "INVALID_CROP_AREA", requestId);
  1692. return;
  1693. }
  1694. // Convert image data to stb_image format for processing
  1695. int channels = 3; // Force RGB
  1696. size_t sourceSize = sourceWidth * sourceHeight * channels;
  1697. std::vector<uint8_t> sourcePixels(sourceSize);
  1698. std::memcpy(sourcePixels.data(), imageData.data(), std::min(imageData.size(), sourceSize));
  1699. // Crop the image
  1700. std::vector<uint8_t> croppedPixels(cropWidth * cropHeight * channels);
  1701. for (int y = 0; y < cropHeight; y++) {
  1702. for (int x = 0; x < cropWidth; x++) {
  1703. int sourceX = cropX + x;
  1704. int sourceY = cropY + y;
  1705. for (int c = 0; c < channels; c++) {
  1706. croppedPixels[(y * cropWidth + x) * channels + c] =
  1707. sourcePixels[(sourceY * sourceWidth + sourceX) * channels + c];
  1708. }
  1709. }
  1710. }
  1711. // Convert cropped image to base64
  1712. std::string base64Data = Utils::base64Encode(croppedPixels);
  1713. // Determine MIME type based on input
  1714. std::string mimeType = "image/jpeg"; // default
  1715. if (Utils::startsWith(imageInput, "data:image/png")) {
  1716. mimeType = "image/png";
  1717. } else if (Utils::startsWith(imageInput, "data:image/gif")) {
  1718. mimeType = "image/gif";
  1719. } else if (Utils::startsWith(imageInput, "data:image/webp")) {
  1720. mimeType = "image/webp";
  1721. } else if (Utils::startsWith(imageInput, "data:image/bmp")) {
  1722. mimeType = "image/bmp";
  1723. }
  1724. // Create data URL format
  1725. std::string dataUrl = "data:" + mimeType + ";base64," + base64Data;
  1726. // Build response
  1727. nlohmann::json response = {
  1728. {"success", true},
  1729. {"original_width", sourceWidth},
  1730. {"original_height", sourceHeight},
  1731. {"crop_x", cropX},
  1732. {"crop_y", cropY},
  1733. {"cropped_width", cropWidth},
  1734. {"cropped_height", cropHeight},
  1735. {"mime_type", mimeType},
  1736. {"base64_data", dataUrl},
  1737. {"file_size_bytes", croppedPixels.size()},
  1738. {"request_id", requestId}
  1739. };
  1740. sendJsonResponse(res, response, 200);
  1741. std::cout << "Successfully cropped image from " << sourceWidth << "x" << sourceHeight
  1742. << " to " << cropWidth << "x" << cropHeight
  1743. << " at (" << cropX << "," << cropY << ")"
  1744. << " (" << croppedPixels.size() << " bytes)" << std::endl;
  1745. } catch (const nlohmann::json::parse_error& e) {
  1746. sendErrorResponse(res, std::string("Invalid JSON: ") + e.what(), 400, "JSON_PARSE_ERROR", requestId);
  1747. } catch (const std::exception& e) {
  1748. std::cerr << "Exception in handleImageCrop: " << e.what() << std::endl;
  1749. sendErrorResponse(res, std::string("Failed to crop image: ") + e.what(), 500, "INTERNAL_ERROR", requestId);
  1750. }
  1751. }
  1752. void Server::handleDownloadImageFromUrl(const httplib::Request& req, httplib::Response& res) {
  1753. std::string requestId = generateRequestId();
  1754. try {
  1755. // Parse query parameters
  1756. std::string imageUrl = req.get_param_value("url");
  1757. if (imageUrl.empty()) {
  1758. sendErrorResponse(res, "Missing 'url' parameter", 400, "MISSING_URL", requestId);
  1759. return;
  1760. }
  1761. // Basic URL format validation
  1762. if (!Utils::startsWith(imageUrl, "http://") && !Utils::startsWith(imageUrl, "https://")) {
  1763. sendErrorResponse(res, "Invalid URL format. URL must start with http:// or https://", 400, "INVALID_URL_FORMAT", requestId);
  1764. return;
  1765. }
  1766. // Extract filename from URL for content type detection
  1767. std::string filename = imageUrl;
  1768. size_t lastSlash = imageUrl.find_last_of('/');
  1769. if (lastSlash != std::string::npos) {
  1770. filename = imageUrl.substr(lastSlash + 1);
  1771. }
  1772. // Remove query parameters and fragments
  1773. size_t questionMark = filename.find('?');
  1774. if (questionMark != std::string::npos) {
  1775. filename = filename.substr(0, questionMark);
  1776. }
  1777. size_t hashMark = filename.find('#');
  1778. if (hashMark != std::string::npos) {
  1779. filename = filename.substr(0, hashMark);
  1780. }
  1781. // Check if URL has image extension
  1782. std::string extension;
  1783. size_t lastDot = filename.find_last_of('.');
  1784. if (lastDot != std::string::npos) {
  1785. extension = filename.substr(lastDot + 1);
  1786. std::transform(extension.begin(), extension.end(), extension.begin(), ::tolower);
  1787. }
  1788. // Validate image extension
  1789. const std::vector<std::string> validExtensions = {"jpg", "jpeg", "png", "gif", "webp", "bmp"};
  1790. if (extension.empty() || std::find(validExtensions.begin(), validExtensions.end(), extension) == validExtensions.end()) {
  1791. sendErrorResponse(res, "URL must point to an image file with a valid extension: " +
  1792. std::accumulate(validExtensions.begin(), validExtensions.end(), std::string(),
  1793. [](const std::string& a, const std::string& b) {
  1794. return a.empty() ? b : a + ", " + b;
  1795. }), 400, "INVALID_IMAGE_EXTENSION", requestId);
  1796. return;
  1797. }
  1798. // Load image using existing loadImageFromInput function
  1799. auto [imageData, width, height, channels, success, error] = loadImageFromInput(imageUrl);
  1800. if (!success) {
  1801. sendErrorResponse(res, "Failed to download image from URL: " + error, 400, "IMAGE_DOWNLOAD_FAILED", requestId);
  1802. return;
  1803. }
  1804. // Convert image data to base64
  1805. std::string base64Data = Utils::base64Encode(imageData);
  1806. // Determine MIME type based on extension
  1807. std::string mimeType = "image/jpeg"; // default
  1808. if (extension == "png") {
  1809. mimeType = "image/png";
  1810. } else if (extension == "gif") {
  1811. mimeType = "image/gif";
  1812. } else if (extension == "webp") {
  1813. mimeType = "image/webp";
  1814. } else if (extension == "bmp") {
  1815. mimeType = "image/bmp";
  1816. } else if (extension == "jpg" || extension == "jpeg") {
  1817. mimeType = "image/jpeg";
  1818. }
  1819. // Create data URL format
  1820. std::string dataUrl = "data:" + mimeType + ";base64," + base64Data;
  1821. // Build response
  1822. nlohmann::json response = {
  1823. {"success", true},
  1824. {"url", imageUrl},
  1825. {"filename", filename},
  1826. {"width", width},
  1827. {"height", height},
  1828. {"channels", channels},
  1829. {"mime_type", mimeType},
  1830. {"base64_data", dataUrl},
  1831. {"file_size_bytes", imageData.size()},
  1832. {"request_id", requestId}
  1833. };
  1834. sendJsonResponse(res, response, 200);
  1835. std::cout << "Successfully downloaded and encoded image from URL: " << imageUrl
  1836. << " (" << width << "x" << height << ", " << imageData.size() << " bytes)" << std::endl;
  1837. } catch (const nlohmann::json::parse_error& e) {
  1838. sendErrorResponse(res, std::string("Invalid JSON: ") + e.what(), 400, "JSON_PARSE_ERROR", requestId);
  1839. } catch (const std::exception& e) {
  1840. std::cerr << "Exception in handleDownloadImageFromUrl: " << e.what() << std::endl;
  1841. sendErrorResponse(res, std::string("Failed to download image from URL: ") + e.what(), 500, "INTERNAL_ERROR", requestId);
  1842. }
  1843. }
  1844. void Server::sendJsonResponse(httplib::Response& res, const nlohmann::json& json, int status_code) {
  1845. res.set_header("Content-Type", "application/json");
  1846. res.status = status_code;
  1847. res.body = json.dump();
  1848. }
  1849. void Server::sendErrorResponse(httplib::Response& res, const std::string& message, int status_code,
  1850. const std::string& error_code, const std::string& request_id) {
  1851. nlohmann::json errorResponse = {
  1852. {"error", {
  1853. {"message", message},
  1854. {"status_code", status_code},
  1855. {"error_code", error_code},
  1856. {"request_id", request_id},
  1857. {"timestamp", std::chrono::duration_cast<std::chrono::seconds>(
  1858. std::chrono::system_clock::now().time_since_epoch()).count()}
  1859. }}
  1860. };
  1861. sendJsonResponse(res, errorResponse, status_code);
  1862. }
  1863. std::pair<bool, std::string> Server::validateGenerationParameters(const nlohmann::json& params) {
  1864. // Validate required fields
  1865. if (!params.contains("prompt") || !params["prompt"].is_string()) {
  1866. return {false, "Missing or invalid 'prompt' field"};
  1867. }
  1868. const std::string& prompt = params["prompt"];
  1869. if (prompt.empty()) {
  1870. return {false, "Prompt cannot be empty"};
  1871. }
  1872. if (prompt.length() > m_config.maxPromptLength) {
  1873. return {false, "Prompt too long (max " + std::to_string(m_config.maxPromptLength) + " characters)"};
  1874. }
  1875. // Validate negative prompt if present
  1876. if (params.contains("negative_prompt")) {
  1877. if (!params["negative_prompt"].is_string()) {
  1878. return {false, "Invalid 'negative_prompt' field, must be string"};
  1879. }
  1880. if (params["negative_prompt"].get<std::string>().length() > m_config.maxNegativePromptLength) {
  1881. return {false, "Negative prompt too long (max " + std::to_string(m_config.maxNegativePromptLength) + " characters)"};
  1882. }
  1883. }
  1884. // Validate width
  1885. if (params.contains("width")) {
  1886. if (!params["width"].is_number_integer()) {
  1887. return {false, "Invalid 'width' field, must be integer"};
  1888. }
  1889. int width = params["width"];
  1890. if (width < 64 || width > 2048 || width % 64 != 0) {
  1891. return {false, "Width must be between 64 and 2048 and divisible by 64"};
  1892. }
  1893. }
  1894. // Validate height
  1895. if (params.contains("height")) {
  1896. if (!params["height"].is_number_integer()) {
  1897. return {false, "Invalid 'height' field, must be integer"};
  1898. }
  1899. int height = params["height"];
  1900. if (height < 64 || height > 2048 || height % 64 != 0) {
  1901. return {false, "Height must be between 64 and 2048 and divisible by 64"};
  1902. }
  1903. }
  1904. // Validate batch count
  1905. if (params.contains("batch_count")) {
  1906. if (!params["batch_count"].is_number_integer()) {
  1907. return {false, "Invalid 'batch_count' field, must be integer"};
  1908. }
  1909. int batchCount = params["batch_count"];
  1910. if (batchCount < 1 || batchCount > 100) {
  1911. return {false, "Batch count must be between 1 and 100"};
  1912. }
  1913. }
  1914. // Validate steps
  1915. if (params.contains("steps")) {
  1916. if (!params["steps"].is_number_integer()) {
  1917. return {false, "Invalid 'steps' field, must be integer"};
  1918. }
  1919. int steps = params["steps"];
  1920. if (steps < 1 || steps > 150) {
  1921. return {false, "Steps must be between 1 and 150"};
  1922. }
  1923. }
  1924. // Validate CFG scale
  1925. if (params.contains("cfg_scale")) {
  1926. if (!params["cfg_scale"].is_number()) {
  1927. return {false, "Invalid 'cfg_scale' field, must be number"};
  1928. }
  1929. float cfgScale = params["cfg_scale"];
  1930. if (cfgScale < 1.0f || cfgScale > 30.0f) {
  1931. return {false, "CFG scale must be between 1.0 and 30.0"};
  1932. }
  1933. }
  1934. // Validate seed
  1935. if (params.contains("seed")) {
  1936. if (!params["seed"].is_string() && !params["seed"].is_number_integer()) {
  1937. return {false, "Invalid 'seed' field, must be string or integer"};
  1938. }
  1939. }
  1940. // Validate sampling method
  1941. if (params.contains("sampling_method")) {
  1942. if (!params["sampling_method"].is_string()) {
  1943. return {false, "Invalid 'sampling_method' field, must be string"};
  1944. }
  1945. std::string method = params["sampling_method"];
  1946. std::vector<std::string> validMethods = {
  1947. "euler", "euler_a", "heun", "dpm2", "dpm++2s_a", "dpm++2m",
  1948. "dpm++2mv2", "ipndm", "ipndm_v", "lcm", "ddim_trailing", "tcd", "default"
  1949. };
  1950. if (std::find(validMethods.begin(), validMethods.end(), method) == validMethods.end()) {
  1951. return {false, "Invalid sampling method"};
  1952. }
  1953. }
  1954. // Validate scheduler
  1955. if (params.contains("scheduler")) {
  1956. if (!params["scheduler"].is_string()) {
  1957. return {false, "Invalid 'scheduler' field, must be string"};
  1958. }
  1959. std::string scheduler = params["scheduler"];
  1960. std::vector<std::string> validSchedulers = {
  1961. "discrete", "karras", "exponential", "ays", "gits",
  1962. "smoothstep", "sgm_uniform", "simple", "default"
  1963. };
  1964. if (std::find(validSchedulers.begin(), validSchedulers.end(), scheduler) == validSchedulers.end()) {
  1965. return {false, "Invalid scheduler"};
  1966. }
  1967. }
  1968. // Validate strength
  1969. if (params.contains("strength")) {
  1970. if (!params["strength"].is_number()) {
  1971. return {false, "Invalid 'strength' field, must be number"};
  1972. }
  1973. float strength = params["strength"];
  1974. if (strength < 0.0f || strength > 1.0f) {
  1975. return {false, "Strength must be between 0.0 and 1.0"};
  1976. }
  1977. }
  1978. // Validate control strength
  1979. if (params.contains("control_strength")) {
  1980. if (!params["control_strength"].is_number()) {
  1981. return {false, "Invalid 'control_strength' field, must be number"};
  1982. }
  1983. float controlStrength = params["control_strength"];
  1984. if (controlStrength < 0.0f || controlStrength > 1.0f) {
  1985. return {false, "Control strength must be between 0.0 and 1.0"};
  1986. }
  1987. }
  1988. // Validate clip skip
  1989. if (params.contains("clip_skip")) {
  1990. if (!params["clip_skip"].is_number_integer()) {
  1991. return {false, "Invalid 'clip_skip' field, must be integer"};
  1992. }
  1993. int clipSkip = params["clip_skip"];
  1994. if (clipSkip < -1 || clipSkip > 12) {
  1995. return {false, "Clip skip must be between -1 and 12"};
  1996. }
  1997. }
  1998. // Validate threads
  1999. if (params.contains("threads")) {
  2000. if (!params["threads"].is_number_integer()) {
  2001. return {false, "Invalid 'threads' field, must be integer"};
  2002. }
  2003. int threads = params["threads"];
  2004. if (threads < -1 || threads > 32) {
  2005. return {false, "Threads must be between -1 (auto) and 32"};
  2006. }
  2007. }
  2008. return {true, ""};
  2009. }
  2010. SamplingMethod Server::parseSamplingMethod(const std::string& method) {
  2011. if (method == "euler") return SamplingMethod::EULER;
  2012. else if (method == "euler_a") return SamplingMethod::EULER_A;
  2013. else if (method == "heun") return SamplingMethod::HEUN;
  2014. else if (method == "dpm2") return SamplingMethod::DPM2;
  2015. else if (method == "dpm++2s_a") return SamplingMethod::DPMPP2S_A;
  2016. else if (method == "dpm++2m") return SamplingMethod::DPMPP2M;
  2017. else if (method == "dpm++2mv2") return SamplingMethod::DPMPP2MV2;
  2018. else if (method == "ipndm") return SamplingMethod::IPNDM;
  2019. else if (method == "ipndm_v") return SamplingMethod::IPNDM_V;
  2020. else if (method == "lcm") return SamplingMethod::LCM;
  2021. else if (method == "ddim_trailing") return SamplingMethod::DDIM_TRAILING;
  2022. else if (method == "tcd") return SamplingMethod::TCD;
  2023. else return SamplingMethod::DEFAULT;
  2024. }
  2025. Scheduler Server::parseScheduler(const std::string& scheduler) {
  2026. if (scheduler == "discrete") return Scheduler::DISCRETE;
  2027. else if (scheduler == "karras") return Scheduler::KARRAS;
  2028. else if (scheduler == "exponential") return Scheduler::EXPONENTIAL;
  2029. else if (scheduler == "ays") return Scheduler::AYS;
  2030. else if (scheduler == "gits") return Scheduler::GITS;
  2031. else if (scheduler == "smoothstep") return Scheduler::SMOOTHSTEP;
  2032. else if (scheduler == "sgm_uniform") return Scheduler::SGM_UNIFORM;
  2033. else if (scheduler == "simple") return Scheduler::SIMPLE;
  2034. else return Scheduler::DEFAULT;
  2035. }
  2036. std::string Server::generateRequestId() {
  2037. std::random_device rd;
  2038. std::mt19937 gen(rd());
  2039. std::uniform_int_distribution<> dis(100000, 999999);
  2040. return "req_" + std::to_string(dis(gen));
  2041. }
  2042. std::tuple<std::vector<uint8_t>, int, int, int, bool, std::string>
  2043. Server::loadImageFromInput(const std::string& input) {
  2044. std::vector<uint8_t> imageData;
  2045. int width = 0, height = 0, channels = 0;
  2046. // Auto-detect input source type
  2047. // 1. Check if input is a URL (starts with http:// or https://)
  2048. if (Utils::startsWith(input, "http://") || Utils::startsWith(input, "https://")) {
  2049. // Parse URL to extract host and path
  2050. std::string url = input;
  2051. std::string scheme, host, path;
  2052. int port = 80;
  2053. // Determine scheme and port
  2054. if (Utils::startsWith(url, "https://")) {
  2055. scheme = "https";
  2056. port = 443;
  2057. url = url.substr(8); // Remove "https://"
  2058. } else {
  2059. scheme = "http";
  2060. port = 80;
  2061. url = url.substr(7); // Remove "http://"
  2062. }
  2063. // Extract host and path
  2064. size_t slashPos = url.find('/');
  2065. if (slashPos != std::string::npos) {
  2066. host = url.substr(0, slashPos);
  2067. path = url.substr(slashPos);
  2068. } else {
  2069. host = url;
  2070. path = "/";
  2071. }
  2072. // Check for custom port
  2073. size_t colonPos = host.find(':');
  2074. if (colonPos != std::string::npos) {
  2075. try {
  2076. port = std::stoi(host.substr(colonPos + 1));
  2077. host = host.substr(0, colonPos);
  2078. } catch (...) {
  2079. return {imageData, 0, 0, 0, false, "Invalid port in URL"};
  2080. }
  2081. }
  2082. // Download image using httplib
  2083. try {
  2084. httplib::Result res;
  2085. if (scheme == "https") {
  2086. #ifdef CPPHTTPLIB_OPENSSL_SUPPORT
  2087. httplib::SSLClient client(host, port);
  2088. client.set_follow_location(true);
  2089. client.set_connection_timeout(30, 0); // 30 seconds
  2090. client.set_read_timeout(60, 0); // 60 seconds
  2091. res = client.Get(path.c_str());
  2092. #else
  2093. return {imageData, 0, 0, 0, false, "HTTPS not supported (OpenSSL not available)"};
  2094. #endif
  2095. } else {
  2096. httplib::Client client(host, port);
  2097. client.set_follow_location(true);
  2098. client.set_connection_timeout(30, 0); // 30 seconds
  2099. client.set_read_timeout(60, 0); // 60 seconds
  2100. res = client.Get(path.c_str());
  2101. }
  2102. if (!res) {
  2103. return {imageData, 0, 0, 0, false, "Failed to download image from URL: Connection error"};
  2104. }
  2105. if (res->status != 200) {
  2106. return {imageData, 0, 0, 0, false, "Failed to download image from URL: HTTP " + std::to_string(res->status)};
  2107. }
  2108. // Convert response body to vector
  2109. std::vector<uint8_t> downloadedData(res->body.begin(), res->body.end());
  2110. // Load image from memory
  2111. int w, h, c;
  2112. unsigned char* pixels = stbi_load_from_memory(
  2113. downloadedData.data(),
  2114. downloadedData.size(),
  2115. &w, &h, &c,
  2116. 3 // Force RGB
  2117. );
  2118. if (!pixels) {
  2119. return {imageData, 0, 0, 0, false, "Failed to decode image from URL"};
  2120. }
  2121. width = w;
  2122. height = h;
  2123. channels = 3;
  2124. size_t dataSize = width * height * channels;
  2125. imageData.resize(dataSize);
  2126. std::memcpy(imageData.data(), pixels, dataSize);
  2127. stbi_image_free(pixels);
  2128. } catch (const std::exception& e) {
  2129. return {imageData, 0, 0, 0, false, "Failed to download image from URL: " + std::string(e.what())};
  2130. }
  2131. }
  2132. // 2. Check if input is base64 encoded data URI (starts with "data:image")
  2133. else if (Utils::startsWith(input, "data:image")) {
  2134. // Extract base64 data after the comma
  2135. size_t commaPos = input.find(',');
  2136. if (commaPos == std::string::npos) {
  2137. return {imageData, 0, 0, 0, false, "Invalid data URI format"};
  2138. }
  2139. std::string base64Data = input.substr(commaPos + 1);
  2140. std::vector<uint8_t> decodedData = Utils::base64Decode(base64Data);
  2141. // Load image from memory using stb_image
  2142. int w, h, c;
  2143. unsigned char* pixels = stbi_load_from_memory(
  2144. decodedData.data(),
  2145. decodedData.size(),
  2146. &w, &h, &c,
  2147. 3 // Force RGB
  2148. );
  2149. if (!pixels) {
  2150. return {imageData, 0, 0, 0, false, "Failed to decode image from base64 data URI"};
  2151. }
  2152. width = w;
  2153. height = h;
  2154. channels = 3; // We forced RGB
  2155. // Copy pixel data
  2156. size_t dataSize = width * height * channels;
  2157. imageData.resize(dataSize);
  2158. std::memcpy(imageData.data(), pixels, dataSize);
  2159. stbi_image_free(pixels);
  2160. }
  2161. // 3. Check if input is raw base64 (long string without slashes, likely base64)
  2162. else if (input.length() > 100 && input.find('/') == std::string::npos && input.find('.') == std::string::npos) {
  2163. // Likely raw base64 without data URI prefix
  2164. std::vector<uint8_t> decodedData = Utils::base64Decode(input);
  2165. int w, h, c;
  2166. unsigned char* pixels = stbi_load_from_memory(
  2167. decodedData.data(),
  2168. decodedData.size(),
  2169. &w, &h, &c,
  2170. 3 // Force RGB
  2171. );
  2172. if (!pixels) {
  2173. return {imageData, 0, 0, 0, false, "Failed to decode image from base64"};
  2174. }
  2175. width = w;
  2176. height = h;
  2177. channels = 3;
  2178. size_t dataSize = width * height * channels;
  2179. imageData.resize(dataSize);
  2180. std::memcpy(imageData.data(), pixels, dataSize);
  2181. stbi_image_free(pixels);
  2182. }
  2183. // 4. Treat as local file path
  2184. else {
  2185. int w, h, c;
  2186. unsigned char* pixels = stbi_load(input.c_str(), &w, &h, &c, 3);
  2187. if (!pixels) {
  2188. return {imageData, 0, 0, 0, false, "Failed to load image from file: " + input};
  2189. }
  2190. width = w;
  2191. height = h;
  2192. channels = 3;
  2193. size_t dataSize = width * height * channels;
  2194. imageData.resize(dataSize);
  2195. std::memcpy(imageData.data(), pixels, dataSize);
  2196. stbi_image_free(pixels);
  2197. }
  2198. return {imageData, width, height, channels, true, ""};
  2199. }
  2200. std::string Server::samplingMethodToString(SamplingMethod method) {
  2201. switch (method) {
  2202. case SamplingMethod::EULER: return "euler";
  2203. case SamplingMethod::EULER_A: return "euler_a";
  2204. case SamplingMethod::HEUN: return "heun";
  2205. case SamplingMethod::DPM2: return "dpm2";
  2206. case SamplingMethod::DPMPP2S_A: return "dpm++2s_a";
  2207. case SamplingMethod::DPMPP2M: return "dpm++2m";
  2208. case SamplingMethod::DPMPP2MV2: return "dpm++2mv2";
  2209. case SamplingMethod::IPNDM: return "ipndm";
  2210. case SamplingMethod::IPNDM_V: return "ipndm_v";
  2211. case SamplingMethod::LCM: return "lcm";
  2212. case SamplingMethod::DDIM_TRAILING: return "ddim_trailing";
  2213. case SamplingMethod::TCD: return "tcd";
  2214. default: return "default";
  2215. }
  2216. }
  2217. std::string Server::schedulerToString(Scheduler scheduler) {
  2218. switch (scheduler) {
  2219. case Scheduler::DISCRETE: return "discrete";
  2220. case Scheduler::KARRAS: return "karras";
  2221. case Scheduler::EXPONENTIAL: return "exponential";
  2222. case Scheduler::AYS: return "ays";
  2223. case Scheduler::GITS: return "gits";
  2224. case Scheduler::SMOOTHSTEP: return "smoothstep";
  2225. case Scheduler::SGM_UNIFORM: return "sgm_uniform";
  2226. case Scheduler::SIMPLE: return "simple";
  2227. default: return "default";
  2228. }
  2229. }
  2230. uint64_t Server::estimateGenerationTime(const GenerationRequest& request) {
  2231. // Basic estimation based on parameters
  2232. uint64_t baseTime = 1000; // 1 second base time
  2233. // Factor in steps
  2234. baseTime *= request.steps;
  2235. // Factor in resolution
  2236. double resolutionFactor = (request.width * request.height) / (512.0 * 512.0);
  2237. baseTime = static_cast<uint64_t>(baseTime * resolutionFactor);
  2238. // Factor in batch count
  2239. baseTime *= request.batchCount;
  2240. // Adjust for sampling method (some are faster than others)
  2241. switch (request.samplingMethod) {
  2242. case SamplingMethod::LCM:
  2243. baseTime /= 4; // LCM is much faster
  2244. break;
  2245. case SamplingMethod::EULER:
  2246. case SamplingMethod::EULER_A:
  2247. baseTime *= 0.8; // Euler methods are faster
  2248. break;
  2249. case SamplingMethod::DPM2:
  2250. case SamplingMethod::DPMPP2S_A:
  2251. baseTime *= 1.2; // DPM methods are slower
  2252. break;
  2253. default:
  2254. break;
  2255. }
  2256. return baseTime;
  2257. }
  2258. size_t Server::estimateMemoryUsage(const GenerationRequest& request) {
  2259. // Basic memory estimation in bytes
  2260. size_t baseMemory = 1024 * 1024 * 1024; // 1GB base
  2261. // Factor in resolution
  2262. double resolutionFactor = (request.width * request.height) / (512.0 * 512.0);
  2263. baseMemory = static_cast<size_t>(baseMemory * resolutionFactor);
  2264. // Factor in batch count
  2265. baseMemory *= request.batchCount;
  2266. // Additional memory for certain features
  2267. if (request.diffusionFlashAttn) {
  2268. baseMemory += 512 * 1024 * 1024; // Extra 512MB for flash attention
  2269. }
  2270. if (!request.controlNetPath.empty()) {
  2271. baseMemory += 1024 * 1024 * 1024; // Extra 1GB for ControlNet
  2272. }
  2273. return baseMemory;
  2274. }
  2275. // Specialized generation endpoints
  2276. void Server::handleText2Img(const httplib::Request& req, httplib::Response& res) {
  2277. std::string requestId = generateRequestId();
  2278. try {
  2279. if (!m_generationQueue) {
  2280. sendErrorResponse(res, "Generation queue not available", 500, "QUEUE_UNAVAILABLE", requestId);
  2281. return;
  2282. }
  2283. nlohmann::json requestJson = nlohmann::json::parse(req.body);
  2284. // Validate required fields for text2img
  2285. if (!requestJson.contains("prompt") || !requestJson["prompt"].is_string()) {
  2286. sendErrorResponse(res, "Missing or invalid 'prompt' field", 400, "INVALID_PARAMETERS", requestId);
  2287. return;
  2288. }
  2289. // Validate all parameters
  2290. auto [isValid, errorMessage] = validateGenerationParameters(requestJson);
  2291. if (!isValid) {
  2292. sendErrorResponse(res, errorMessage, 400, "INVALID_PARAMETERS", requestId);
  2293. return;
  2294. }
  2295. // Check if any model is loaded
  2296. if (!m_modelManager) {
  2297. sendErrorResponse(res, "Model manager not available", 500, "MODEL_MANAGER_UNAVAILABLE", requestId);
  2298. return;
  2299. }
  2300. // Get currently loaded checkpoint model
  2301. auto allModels = m_modelManager->getAllModels();
  2302. std::string loadedModelName;
  2303. for (const auto& [modelName, modelInfo] : allModels) {
  2304. if (modelInfo.type == ModelType::CHECKPOINT && modelInfo.isLoaded) {
  2305. loadedModelName = modelName;
  2306. break;
  2307. }
  2308. }
  2309. if (loadedModelName.empty()) {
  2310. sendErrorResponse(res, "No checkpoint model loaded. Please load a checkpoint model first using POST /api/models/{hash}/load", 400, "NO_CHECKPOINT_LOADED", requestId);
  2311. return;
  2312. }
  2313. // Create generation request specifically for text2img
  2314. GenerationRequest genRequest;
  2315. genRequest.id = requestId;
  2316. genRequest.modelName = loadedModelName; // Use the currently loaded model
  2317. genRequest.prompt = requestJson["prompt"];
  2318. genRequest.negativePrompt = requestJson.value("negative_prompt", "");
  2319. genRequest.width = requestJson.value("width", 512);
  2320. genRequest.height = requestJson.value("height", 512);
  2321. genRequest.batchCount = requestJson.value("batch_count", 1);
  2322. genRequest.steps = requestJson.value("steps", 20);
  2323. genRequest.cfgScale = requestJson.value("cfg_scale", 7.5f);
  2324. genRequest.seed = requestJson.value("seed", "random");
  2325. // Parse optional parameters
  2326. if (requestJson.contains("sampling_method")) {
  2327. genRequest.samplingMethod = parseSamplingMethod(requestJson["sampling_method"]);
  2328. }
  2329. if (requestJson.contains("scheduler")) {
  2330. genRequest.scheduler = parseScheduler(requestJson["scheduler"]);
  2331. }
  2332. // Set text2img specific defaults
  2333. genRequest.strength = 1.0f; // Full strength for text2img
  2334. // Optional VAE model
  2335. if (requestJson.contains("vae_model") && requestJson["vae_model"].is_string()) {
  2336. std::string vaeModelId = requestJson["vae_model"];
  2337. if (!vaeModelId.empty()) {
  2338. auto vaeInfo = m_modelManager->getModelInfo(vaeModelId);
  2339. if (!vaeInfo.name.empty() && vaeInfo.type == ModelType::VAE) {
  2340. genRequest.vaePath = vaeInfo.path;
  2341. } else {
  2342. sendErrorResponse(res, "VAE model not found or invalid: " + vaeModelId, 400, "INVALID_VAE_MODEL", requestId);
  2343. return;
  2344. }
  2345. }
  2346. }
  2347. // Optional TAESD model
  2348. if (requestJson.contains("taesd_model") && requestJson["taesd_model"].is_string()) {
  2349. std::string taesdModelId = requestJson["taesd_model"];
  2350. if (!taesdModelId.empty()) {
  2351. auto taesdInfo = m_modelManager->getModelInfo(taesdModelId);
  2352. if (!taesdInfo.name.empty() && taesdInfo.type == ModelType::TAESD) {
  2353. genRequest.taesdPath = taesdInfo.path;
  2354. } else {
  2355. sendErrorResponse(res, "TAESD model not found or invalid: " + taesdModelId, 400, "INVALID_TAESD_MODEL", requestId);
  2356. return;
  2357. }
  2358. }
  2359. }
  2360. // Enqueue request
  2361. auto future = m_generationQueue->enqueueRequest(genRequest);
  2362. nlohmann::json params = {
  2363. {"prompt", genRequest.prompt},
  2364. {"negative_prompt", genRequest.negativePrompt},
  2365. {"model", genRequest.modelName},
  2366. {"width", genRequest.width},
  2367. {"height", genRequest.height},
  2368. {"batch_count", genRequest.batchCount},
  2369. {"steps", genRequest.steps},
  2370. {"cfg_scale", genRequest.cfgScale},
  2371. {"seed", genRequest.seed},
  2372. {"sampling_method", samplingMethodToString(genRequest.samplingMethod)},
  2373. {"scheduler", schedulerToString(genRequest.scheduler)}
  2374. };
  2375. // Add VAE/TAESD if specified
  2376. if (!genRequest.vaePath.empty()) {
  2377. params["vae_model"] = requestJson.value("vae_model", "");
  2378. }
  2379. if (!genRequest.taesdPath.empty()) {
  2380. params["taesd_model"] = requestJson.value("taesd_model", "");
  2381. }
  2382. nlohmann::json response = {
  2383. {"request_id", requestId},
  2384. {"status", "queued"},
  2385. {"message", "Text-to-image generation request queued successfully"},
  2386. {"queue_position", m_generationQueue->getQueueSize()},
  2387. {"estimated_time_seconds", estimateGenerationTime(genRequest) / 1000},
  2388. {"estimated_memory_mb", estimateMemoryUsage(genRequest) / (1024 * 1024)},
  2389. {"type", "text2img"},
  2390. {"parameters", params}
  2391. };
  2392. sendJsonResponse(res, response, 202);
  2393. } catch (const nlohmann::json::parse_error& e) {
  2394. sendErrorResponse(res, std::string("Invalid JSON: ") + e.what(), 400, "JSON_PARSE_ERROR", requestId);
  2395. } catch (const std::exception& e) {
  2396. sendErrorResponse(res, std::string("Text-to-image request failed: ") + e.what(), 500, "INTERNAL_ERROR", requestId);
  2397. }
  2398. }
  2399. void Server::handleImg2Img(const httplib::Request& req, httplib::Response& res) {
  2400. std::string requestId = generateRequestId();
  2401. try {
  2402. if (!m_generationQueue) {
  2403. sendErrorResponse(res, "Generation queue not available", 500, "QUEUE_UNAVAILABLE", requestId);
  2404. return;
  2405. }
  2406. nlohmann::json requestJson = nlohmann::json::parse(req.body);
  2407. // Validate required fields for img2img
  2408. if (!requestJson.contains("prompt") || !requestJson["prompt"].is_string()) {
  2409. sendErrorResponse(res, "Missing or invalid 'prompt' field", 400, "INVALID_PARAMETERS", requestId);
  2410. return;
  2411. }
  2412. if (!requestJson.contains("init_image") || !requestJson["init_image"].is_string()) {
  2413. sendErrorResponse(res, "Missing or invalid 'init_image' field", 400, "INVALID_PARAMETERS", requestId);
  2414. return;
  2415. }
  2416. // Validate all parameters
  2417. auto [isValid, errorMessage] = validateGenerationParameters(requestJson);
  2418. if (!isValid) {
  2419. sendErrorResponse(res, errorMessage, 400, "INVALID_PARAMETERS", requestId);
  2420. return;
  2421. }
  2422. // Check if any model is loaded
  2423. if (!m_modelManager) {
  2424. sendErrorResponse(res, "Model manager not available", 500, "MODEL_MANAGER_UNAVAILABLE", requestId);
  2425. return;
  2426. }
  2427. // Get currently loaded checkpoint model
  2428. auto allModels = m_modelManager->getAllModels();
  2429. std::string loadedModelName;
  2430. for (const auto& [modelName, modelInfo] : allModels) {
  2431. if (modelInfo.type == ModelType::CHECKPOINT && modelInfo.isLoaded) {
  2432. loadedModelName = modelName;
  2433. break;
  2434. }
  2435. }
  2436. if (loadedModelName.empty()) {
  2437. sendErrorResponse(res, "No checkpoint model loaded. Please load a checkpoint model first using POST /api/models/{hash}/load", 400, "NO_CHECKPOINT_LOADED", requestId);
  2438. return;
  2439. }
  2440. // Load the init image
  2441. std::string initImageInput = requestJson["init_image"];
  2442. auto [imageData, imgWidth, imgHeight, imgChannels, success, loadError] = loadImageFromInput(initImageInput);
  2443. if (!success) {
  2444. sendErrorResponse(res, "Failed to load init image: " + loadError, 400, "IMAGE_LOAD_ERROR", requestId);
  2445. return;
  2446. }
  2447. // Create generation request specifically for img2img
  2448. GenerationRequest genRequest;
  2449. genRequest.id = requestId;
  2450. genRequest.requestType = GenerationRequest::RequestType::IMG2IMG;
  2451. genRequest.modelName = loadedModelName; // Use the currently loaded model
  2452. genRequest.prompt = requestJson["prompt"];
  2453. genRequest.negativePrompt = requestJson.value("negative_prompt", "");
  2454. genRequest.width = requestJson.value("width", imgWidth); // Default to input image dimensions
  2455. genRequest.height = requestJson.value("height", imgHeight);
  2456. genRequest.batchCount = requestJson.value("batch_count", 1);
  2457. genRequest.steps = requestJson.value("steps", 20);
  2458. genRequest.cfgScale = requestJson.value("cfg_scale", 7.5f);
  2459. genRequest.seed = requestJson.value("seed", "random");
  2460. genRequest.strength = requestJson.value("strength", 0.75f);
  2461. // Set init image data
  2462. genRequest.initImageData = imageData;
  2463. genRequest.initImageWidth = imgWidth;
  2464. genRequest.initImageHeight = imgHeight;
  2465. genRequest.initImageChannels = imgChannels;
  2466. // Parse optional parameters
  2467. if (requestJson.contains("sampling_method")) {
  2468. genRequest.samplingMethod = parseSamplingMethod(requestJson["sampling_method"]);
  2469. }
  2470. if (requestJson.contains("scheduler")) {
  2471. genRequest.scheduler = parseScheduler(requestJson["scheduler"]);
  2472. }
  2473. // Optional VAE model
  2474. if (requestJson.contains("vae_model") && requestJson["vae_model"].is_string()) {
  2475. std::string vaeModelId = requestJson["vae_model"];
  2476. if (!vaeModelId.empty()) {
  2477. auto vaeInfo = m_modelManager->getModelInfo(vaeModelId);
  2478. if (!vaeInfo.name.empty() && vaeInfo.type == ModelType::VAE) {
  2479. genRequest.vaePath = vaeInfo.path;
  2480. } else {
  2481. sendErrorResponse(res, "VAE model not found or invalid: " + vaeModelId, 400, "INVALID_VAE_MODEL", requestId);
  2482. return;
  2483. }
  2484. }
  2485. }
  2486. // Optional TAESD model
  2487. if (requestJson.contains("taesd_model") && requestJson["taesd_model"].is_string()) {
  2488. std::string taesdModelId = requestJson["taesd_model"];
  2489. if (!taesdModelId.empty()) {
  2490. auto taesdInfo = m_modelManager->getModelInfo(taesdModelId);
  2491. if (!taesdInfo.name.empty() && taesdInfo.type == ModelType::TAESD) {
  2492. genRequest.taesdPath = taesdInfo.path;
  2493. } else {
  2494. sendErrorResponse(res, "TAESD model not found or invalid: " + taesdModelId, 400, "INVALID_TAESD_MODEL", requestId);
  2495. return;
  2496. }
  2497. }
  2498. }
  2499. // Enqueue request
  2500. auto future = m_generationQueue->enqueueRequest(genRequest);
  2501. nlohmann::json params = {
  2502. {"prompt", genRequest.prompt},
  2503. {"negative_prompt", genRequest.negativePrompt},
  2504. {"init_image", requestJson["init_image"]},
  2505. {"model", genRequest.modelName},
  2506. {"width", genRequest.width},
  2507. {"height", genRequest.height},
  2508. {"batch_count", genRequest.batchCount},
  2509. {"steps", genRequest.steps},
  2510. {"cfg_scale", genRequest.cfgScale},
  2511. {"seed", genRequest.seed},
  2512. {"strength", genRequest.strength},
  2513. {"sampling_method", samplingMethodToString(genRequest.samplingMethod)},
  2514. {"scheduler", schedulerToString(genRequest.scheduler)}
  2515. };
  2516. // Add VAE/TAESD if specified
  2517. if (!genRequest.vaePath.empty()) {
  2518. params["vae_model"] = requestJson.value("vae_model", "");
  2519. }
  2520. if (!genRequest.taesdPath.empty()) {
  2521. params["taesd_model"] = requestJson.value("taesd_model", "");
  2522. }
  2523. nlohmann::json response = {
  2524. {"request_id", requestId},
  2525. {"status", "queued"},
  2526. {"message", "Image-to-image generation request queued successfully"},
  2527. {"queue_position", m_generationQueue->getQueueSize()},
  2528. {"estimated_time_seconds", estimateGenerationTime(genRequest) / 1000},
  2529. {"estimated_memory_mb", estimateMemoryUsage(genRequest) / (1024 * 1024)},
  2530. {"type", "img2img"},
  2531. {"parameters", params}
  2532. };
  2533. sendJsonResponse(res, response, 202);
  2534. } catch (const nlohmann::json::parse_error& e) {
  2535. sendErrorResponse(res, std::string("Invalid JSON: ") + e.what(), 400, "JSON_PARSE_ERROR", requestId);
  2536. } catch (const std::exception& e) {
  2537. sendErrorResponse(res, std::string("Image-to-image request failed: ") + e.what(), 500, "INTERNAL_ERROR", requestId);
  2538. }
  2539. }
  2540. void Server::handleControlNet(const httplib::Request& req, httplib::Response& res) {
  2541. std::string requestId = generateRequestId();
  2542. try {
  2543. if (!m_generationQueue) {
  2544. sendErrorResponse(res, "Generation queue not available", 500, "QUEUE_UNAVAILABLE", requestId);
  2545. return;
  2546. }
  2547. nlohmann::json requestJson = nlohmann::json::parse(req.body);
  2548. // Validate required fields for ControlNet
  2549. if (!requestJson.contains("prompt") || !requestJson["prompt"].is_string()) {
  2550. sendErrorResponse(res, "Missing or invalid 'prompt' field", 400, "INVALID_PARAMETERS", requestId);
  2551. return;
  2552. }
  2553. if (!requestJson.contains("control_image") || !requestJson["control_image"].is_string()) {
  2554. sendErrorResponse(res, "Missing or invalid 'control_image' field", 400, "INVALID_PARAMETERS", requestId);
  2555. return;
  2556. }
  2557. // Validate all parameters
  2558. auto [isValid, errorMessage] = validateGenerationParameters(requestJson);
  2559. if (!isValid) {
  2560. sendErrorResponse(res, errorMessage, 400, "INVALID_PARAMETERS", requestId);
  2561. return;
  2562. }
  2563. // Check if any model is loaded
  2564. if (!m_modelManager) {
  2565. sendErrorResponse(res, "Model manager not available", 500, "MODEL_MANAGER_UNAVAILABLE", requestId);
  2566. return;
  2567. }
  2568. // Get currently loaded checkpoint model
  2569. auto allModels = m_modelManager->getAllModels();
  2570. std::string loadedModelName;
  2571. for (const auto& [modelName, modelInfo] : allModels) {
  2572. if (modelInfo.type == ModelType::CHECKPOINT && modelInfo.isLoaded) {
  2573. loadedModelName = modelName;
  2574. break;
  2575. }
  2576. }
  2577. if (loadedModelName.empty()) {
  2578. sendErrorResponse(res, "No checkpoint model loaded. Please load a checkpoint model first using POST /api/models/{hash}/load", 400, "NO_CHECKPOINT_LOADED", requestId);
  2579. return;
  2580. }
  2581. // Create generation request specifically for ControlNet
  2582. GenerationRequest genRequest;
  2583. genRequest.id = requestId;
  2584. genRequest.modelName = loadedModelName; // Use the currently loaded model
  2585. genRequest.prompt = requestJson["prompt"];
  2586. genRequest.negativePrompt = requestJson.value("negative_prompt", "");
  2587. genRequest.width = requestJson.value("width", 512);
  2588. genRequest.height = requestJson.value("height", 512);
  2589. genRequest.batchCount = requestJson.value("batch_count", 1);
  2590. genRequest.steps = requestJson.value("steps", 20);
  2591. genRequest.cfgScale = requestJson.value("cfg_scale", 7.5f);
  2592. genRequest.seed = requestJson.value("seed", "random");
  2593. genRequest.controlStrength = requestJson.value("control_strength", 0.9f);
  2594. genRequest.controlNetPath = requestJson.value("control_net_model", "");
  2595. // Parse optional parameters
  2596. if (requestJson.contains("sampling_method")) {
  2597. genRequest.samplingMethod = parseSamplingMethod(requestJson["sampling_method"]);
  2598. }
  2599. if (requestJson.contains("scheduler")) {
  2600. genRequest.scheduler = parseScheduler(requestJson["scheduler"]);
  2601. }
  2602. // Optional VAE model
  2603. if (requestJson.contains("vae_model") && requestJson["vae_model"].is_string()) {
  2604. std::string vaeModelId = requestJson["vae_model"];
  2605. if (!vaeModelId.empty()) {
  2606. auto vaeInfo = m_modelManager->getModelInfo(vaeModelId);
  2607. if (!vaeInfo.name.empty() && vaeInfo.type == ModelType::VAE) {
  2608. genRequest.vaePath = vaeInfo.path;
  2609. } else {
  2610. sendErrorResponse(res, "VAE model not found or invalid: " + vaeModelId, 400, "INVALID_VAE_MODEL", requestId);
  2611. return;
  2612. }
  2613. }
  2614. }
  2615. // Optional TAESD model
  2616. if (requestJson.contains("taesd_model") && requestJson["taesd_model"].is_string()) {
  2617. std::string taesdModelId = requestJson["taesd_model"];
  2618. if (!taesdModelId.empty()) {
  2619. auto taesdInfo = m_modelManager->getModelInfo(taesdModelId);
  2620. if (!taesdInfo.name.empty() && taesdInfo.type == ModelType::TAESD) {
  2621. genRequest.taesdPath = taesdInfo.path;
  2622. } else {
  2623. sendErrorResponse(res, "TAESD model not found or invalid: " + taesdModelId, 400, "INVALID_TAESD_MODEL", requestId);
  2624. return;
  2625. }
  2626. }
  2627. }
  2628. // Store control image path (would be handled in actual implementation)
  2629. genRequest.outputPath = requestJson.value("control_image", "");
  2630. // Enqueue request
  2631. auto future = m_generationQueue->enqueueRequest(genRequest);
  2632. nlohmann::json params = {
  2633. {"prompt", genRequest.prompt},
  2634. {"negative_prompt", genRequest.negativePrompt},
  2635. {"control_image", requestJson["control_image"]},
  2636. {"control_net_model", genRequest.controlNetPath},
  2637. {"model", genRequest.modelName},
  2638. {"width", genRequest.width},
  2639. {"height", genRequest.height},
  2640. {"batch_count", genRequest.batchCount},
  2641. {"steps", genRequest.steps},
  2642. {"cfg_scale", genRequest.cfgScale},
  2643. {"seed", genRequest.seed},
  2644. {"control_strength", genRequest.controlStrength},
  2645. {"sampling_method", samplingMethodToString(genRequest.samplingMethod)},
  2646. {"scheduler", schedulerToString(genRequest.scheduler)}
  2647. };
  2648. // Add VAE/TAESD if specified
  2649. if (!genRequest.vaePath.empty()) {
  2650. params["vae_model"] = requestJson.value("vae_model", "");
  2651. }
  2652. if (!genRequest.taesdPath.empty()) {
  2653. params["taesd_model"] = requestJson.value("taesd_model", "");
  2654. }
  2655. nlohmann::json response = {
  2656. {"request_id", requestId},
  2657. {"status", "queued"},
  2658. {"message", "ControlNet generation request queued successfully"},
  2659. {"queue_position", m_generationQueue->getQueueSize()},
  2660. {"estimated_time_seconds", estimateGenerationTime(genRequest) / 1000},
  2661. {"estimated_memory_mb", estimateMemoryUsage(genRequest) / (1024 * 1024)},
  2662. {"type", "controlnet"},
  2663. {"parameters", params}
  2664. };
  2665. sendJsonResponse(res, response, 202);
  2666. } catch (const nlohmann::json::parse_error& e) {
  2667. sendErrorResponse(res, std::string("Invalid JSON: ") + e.what(), 400, "JSON_PARSE_ERROR", requestId);
  2668. } catch (const std::exception& e) {
  2669. sendErrorResponse(res, std::string("ControlNet request failed: ") + e.what(), 500, "INTERNAL_ERROR", requestId);
  2670. }
  2671. }
  2672. void Server::handleUpscale(const httplib::Request& req, httplib::Response& res) {
  2673. std::string requestId = generateRequestId();
  2674. try {
  2675. if (!m_generationQueue) {
  2676. sendErrorResponse(res, "Generation queue not available", 500, "QUEUE_UNAVAILABLE", requestId);
  2677. return;
  2678. }
  2679. nlohmann::json requestJson = nlohmann::json::parse(req.body);
  2680. // Validate required fields for upscaler
  2681. if (!requestJson.contains("image") || !requestJson["image"].is_string()) {
  2682. sendErrorResponse(res, "Missing or invalid 'image' field", 400, "INVALID_PARAMETERS", requestId);
  2683. return;
  2684. }
  2685. if (!requestJson.contains("esrgan_model") || !requestJson["esrgan_model"].is_string()) {
  2686. sendErrorResponse(res, "Missing or invalid 'esrgan_model' field (model hash or name)", 400, "INVALID_PARAMETERS", requestId);
  2687. return;
  2688. }
  2689. // Check if model manager is available
  2690. if (!m_modelManager) {
  2691. sendErrorResponse(res, "Model manager not available", 500, "MODEL_MANAGER_UNAVAILABLE", requestId);
  2692. return;
  2693. }
  2694. // Get the ESRGAN/upscaler model
  2695. std::string esrganModelId = requestJson["esrgan_model"];
  2696. auto modelInfo = m_modelManager->getModelInfo(esrganModelId);
  2697. if (modelInfo.name.empty()) {
  2698. sendErrorResponse(res, "ESRGAN model not found: " + esrganModelId, 404, "MODEL_NOT_FOUND", requestId);
  2699. return;
  2700. }
  2701. if (modelInfo.type != ModelType::ESRGAN && modelInfo.type != ModelType::UPSCALER) {
  2702. sendErrorResponse(res, "Model is not an ESRGAN/upscaler model", 400, "INVALID_MODEL_TYPE", requestId);
  2703. return;
  2704. }
  2705. // Load the input image
  2706. std::string imageInput = requestJson["image"];
  2707. auto [imageData, imgWidth, imgHeight, imgChannels, success, loadError] = loadImageFromInput(imageInput);
  2708. if (!success) {
  2709. sendErrorResponse(res, "Failed to load image: " + loadError, 400, "IMAGE_LOAD_ERROR", requestId);
  2710. return;
  2711. }
  2712. // Create upscaler request
  2713. GenerationRequest genRequest;
  2714. genRequest.id = requestId;
  2715. genRequest.requestType = GenerationRequest::RequestType::UPSCALER;
  2716. genRequest.esrganPath = modelInfo.path;
  2717. genRequest.upscaleFactor = requestJson.value("upscale_factor", 4);
  2718. genRequest.nThreads = requestJson.value("threads", -1);
  2719. genRequest.offloadParamsToCpu = requestJson.value("offload_to_cpu", false);
  2720. genRequest.diffusionConvDirect = requestJson.value("direct", false);
  2721. // Set input image data
  2722. genRequest.initImageData = imageData;
  2723. genRequest.initImageWidth = imgWidth;
  2724. genRequest.initImageHeight = imgHeight;
  2725. genRequest.initImageChannels = imgChannels;
  2726. // Enqueue request
  2727. auto future = m_generationQueue->enqueueRequest(genRequest);
  2728. nlohmann::json response = {
  2729. {"request_id", requestId},
  2730. {"status", "queued"},
  2731. {"message", "Upscale request queued successfully"},
  2732. {"queue_position", m_generationQueue->getQueueSize()},
  2733. {"type", "upscale"},
  2734. {"parameters", {
  2735. {"esrgan_model", esrganModelId},
  2736. {"upscale_factor", genRequest.upscaleFactor},
  2737. {"input_width", imgWidth},
  2738. {"input_height", imgHeight},
  2739. {"output_width", imgWidth * genRequest.upscaleFactor},
  2740. {"output_height", imgHeight * genRequest.upscaleFactor}
  2741. }}
  2742. };
  2743. sendJsonResponse(res, response, 202);
  2744. } catch (const nlohmann::json::parse_error& e) {
  2745. sendErrorResponse(res, std::string("Invalid JSON: ") + e.what(), 400, "JSON_PARSE_ERROR", requestId);
  2746. } catch (const std::exception& e) {
  2747. sendErrorResponse(res, std::string("Upscale request failed: ") + e.what(), 500, "INTERNAL_ERROR", requestId);
  2748. }
  2749. }
  2750. void Server::handleInpainting(const httplib::Request& req, httplib::Response& res) {
  2751. std::string requestId = generateRequestId();
  2752. try {
  2753. if (!m_generationQueue) {
  2754. sendErrorResponse(res, "Generation queue not available", 500, "QUEUE_UNAVAILABLE", requestId);
  2755. return;
  2756. }
  2757. nlohmann::json requestJson = nlohmann::json::parse(req.body);
  2758. // Validate required fields for inpainting
  2759. if (!requestJson.contains("prompt") || !requestJson["prompt"].is_string()) {
  2760. sendErrorResponse(res, "Missing or invalid 'prompt' field", 400, "INVALID_PARAMETERS", requestId);
  2761. return;
  2762. }
  2763. if (!requestJson.contains("source_image") || !requestJson["source_image"].is_string()) {
  2764. sendErrorResponse(res, "Missing or invalid 'source_image' field", 400, "INVALID_PARAMETERS", requestId);
  2765. return;
  2766. }
  2767. if (!requestJson.contains("mask_image") || !requestJson["mask_image"].is_string()) {
  2768. sendErrorResponse(res, "Missing or invalid 'mask_image' field", 400, "INVALID_PARAMETERS", requestId);
  2769. return;
  2770. }
  2771. // Validate all parameters
  2772. auto [isValid, errorMessage] = validateGenerationParameters(requestJson);
  2773. if (!isValid) {
  2774. sendErrorResponse(res, errorMessage, 400, "INVALID_PARAMETERS", requestId);
  2775. return;
  2776. }
  2777. // Check if any model is loaded
  2778. if (!m_modelManager) {
  2779. sendErrorResponse(res, "Model manager not available", 500, "MODEL_MANAGER_UNAVAILABLE", requestId);
  2780. return;
  2781. }
  2782. // Get currently loaded checkpoint model
  2783. auto allModels = m_modelManager->getAllModels();
  2784. std::string loadedModelName;
  2785. for (const auto& [modelName, modelInfo] : allModels) {
  2786. if (modelInfo.type == ModelType::CHECKPOINT && modelInfo.isLoaded) {
  2787. loadedModelName = modelName;
  2788. break;
  2789. }
  2790. }
  2791. if (loadedModelName.empty()) {
  2792. sendErrorResponse(res, "No checkpoint model loaded. Please load a checkpoint model first using POST /api/models/{hash}/load", 400, "NO_CHECKPOINT_LOADED", requestId);
  2793. return;
  2794. }
  2795. // Load the source image
  2796. std::string sourceImageInput = requestJson["source_image"];
  2797. auto [sourceImageData, sourceImgWidth, sourceImgHeight, sourceImgChannels, sourceSuccess, sourceLoadError] = loadImageFromInput(sourceImageInput);
  2798. if (!sourceSuccess) {
  2799. sendErrorResponse(res, "Failed to load source image: " + sourceLoadError, 400, "IMAGE_LOAD_ERROR", requestId);
  2800. return;
  2801. }
  2802. // Load the mask image
  2803. std::string maskImageInput = requestJson["mask_image"];
  2804. auto [maskImageData, maskImgWidth, maskImgHeight, maskImgChannels, maskSuccess, maskLoadError] = loadImageFromInput(maskImageInput);
  2805. if (!maskSuccess) {
  2806. sendErrorResponse(res, "Failed to load mask image: " + maskLoadError, 400, "MASK_LOAD_ERROR", requestId);
  2807. return;
  2808. }
  2809. // Validate that source and mask images have compatible dimensions
  2810. if (sourceImgWidth != maskImgWidth || sourceImgHeight != maskImgHeight) {
  2811. sendErrorResponse(res, "Source and mask images must have the same dimensions", 400, "DIMENSION_MISMATCH", requestId);
  2812. return;
  2813. }
  2814. // Create generation request specifically for inpainting
  2815. GenerationRequest genRequest;
  2816. genRequest.id = requestId;
  2817. genRequest.requestType = GenerationRequest::RequestType::INPAINTING;
  2818. genRequest.modelName = loadedModelName; // Use the currently loaded model
  2819. genRequest.prompt = requestJson["prompt"];
  2820. genRequest.negativePrompt = requestJson.value("negative_prompt", "");
  2821. genRequest.width = requestJson.value("width", sourceImgWidth); // Default to input image dimensions
  2822. genRequest.height = requestJson.value("height", sourceImgHeight);
  2823. genRequest.batchCount = requestJson.value("batch_count", 1);
  2824. genRequest.steps = requestJson.value("steps", 20);
  2825. genRequest.cfgScale = requestJson.value("cfg_scale", 7.5f);
  2826. genRequest.seed = requestJson.value("seed", "random");
  2827. genRequest.strength = requestJson.value("strength", 0.75f);
  2828. // Set source image data
  2829. genRequest.initImageData = sourceImageData;
  2830. genRequest.initImageWidth = sourceImgWidth;
  2831. genRequest.initImageHeight = sourceImgHeight;
  2832. genRequest.initImageChannels = sourceImgChannels;
  2833. // Set mask image data
  2834. genRequest.maskImageData = maskImageData;
  2835. genRequest.maskImageWidth = maskImgWidth;
  2836. genRequest.maskImageHeight = maskImgHeight;
  2837. genRequest.maskImageChannels = maskImgChannels;
  2838. // Parse optional parameters
  2839. if (requestJson.contains("sampling_method")) {
  2840. genRequest.samplingMethod = parseSamplingMethod(requestJson["sampling_method"]);
  2841. }
  2842. if (requestJson.contains("scheduler")) {
  2843. genRequest.scheduler = parseScheduler(requestJson["scheduler"]);
  2844. }
  2845. // Optional VAE model
  2846. if (requestJson.contains("vae_model") && requestJson["vae_model"].is_string()) {
  2847. std::string vaeModelId = requestJson["vae_model"];
  2848. if (!vaeModelId.empty()) {
  2849. auto vaeInfo = m_modelManager->getModelInfo(vaeModelId);
  2850. if (!vaeInfo.name.empty() && vaeInfo.type == ModelType::VAE) {
  2851. genRequest.vaePath = vaeInfo.path;
  2852. } else {
  2853. sendErrorResponse(res, "VAE model not found or invalid: " + vaeModelId, 400, "INVALID_VAE_MODEL", requestId);
  2854. return;
  2855. }
  2856. }
  2857. }
  2858. // Optional TAESD model
  2859. if (requestJson.contains("taesd_model") && requestJson["taesd_model"].is_string()) {
  2860. std::string taesdModelId = requestJson["taesd_model"];
  2861. if (!taesdModelId.empty()) {
  2862. auto taesdInfo = m_modelManager->getModelInfo(taesdModelId);
  2863. if (!taesdInfo.name.empty() && taesdInfo.type == ModelType::TAESD) {
  2864. genRequest.taesdPath = taesdInfo.path;
  2865. } else {
  2866. sendErrorResponse(res, "TAESD model not found or invalid: " + taesdModelId, 400, "INVALID_TAESD_MODEL", requestId);
  2867. return;
  2868. }
  2869. }
  2870. }
  2871. // Enqueue request
  2872. auto future = m_generationQueue->enqueueRequest(genRequest);
  2873. nlohmann::json params = {
  2874. {"prompt", genRequest.prompt},
  2875. {"negative_prompt", genRequest.negativePrompt},
  2876. {"source_image", requestJson["source_image"]},
  2877. {"mask_image", requestJson["mask_image"]},
  2878. {"model", genRequest.modelName},
  2879. {"width", genRequest.width},
  2880. {"height", genRequest.height},
  2881. {"batch_count", genRequest.batchCount},
  2882. {"steps", genRequest.steps},
  2883. {"cfg_scale", genRequest.cfgScale},
  2884. {"seed", genRequest.seed},
  2885. {"strength", genRequest.strength},
  2886. {"sampling_method", samplingMethodToString(genRequest.samplingMethod)},
  2887. {"scheduler", schedulerToString(genRequest.scheduler)}
  2888. };
  2889. // Add VAE/TAESD if specified
  2890. if (!genRequest.vaePath.empty()) {
  2891. params["vae_model"] = requestJson.value("vae_model", "");
  2892. }
  2893. if (!genRequest.taesdPath.empty()) {
  2894. params["taesd_model"] = requestJson.value("taesd_model", "");
  2895. }
  2896. nlohmann::json response = {
  2897. {"request_id", requestId},
  2898. {"status", "queued"},
  2899. {"message", "Inpainting generation request queued successfully"},
  2900. {"queue_position", m_generationQueue->getQueueSize()},
  2901. {"estimated_time_seconds", estimateGenerationTime(genRequest) / 1000},
  2902. {"estimated_memory_mb", estimateMemoryUsage(genRequest) / (1024 * 1024)},
  2903. {"type", "inpainting"},
  2904. {"parameters", params}
  2905. };
  2906. sendJsonResponse(res, response, 202);
  2907. } catch (const nlohmann::json::parse_error& e) {
  2908. sendErrorResponse(res, std::string("Invalid JSON: ") + e.what(), 400, "JSON_PARSE_ERROR", requestId);
  2909. } catch (const std::exception& e) {
  2910. sendErrorResponse(res, std::string("Inpainting request failed: ") + e.what(), 500, "INTERNAL_ERROR", requestId);
  2911. }
  2912. }
  2913. // Utility endpoints
  2914. void Server::handleSamplers(const httplib::Request& /*req*/, httplib::Response& res) {
  2915. try {
  2916. nlohmann::json samplers = {
  2917. {"samplers", {
  2918. {
  2919. {"name", "euler"},
  2920. {"description", "Euler sampler - fast and simple"},
  2921. {"recommended_steps", 20}
  2922. },
  2923. {
  2924. {"name", "euler_a"},
  2925. {"description", "Euler ancestral sampler - adds randomness"},
  2926. {"recommended_steps", 20}
  2927. },
  2928. {
  2929. {"name", "heun"},
  2930. {"description", "Heun sampler - more accurate but slower"},
  2931. {"recommended_steps", 20}
  2932. },
  2933. {
  2934. {"name", "dpm2"},
  2935. {"description", "DPM2 sampler - second-order DPM"},
  2936. {"recommended_steps", 20}
  2937. },
  2938. {
  2939. {"name", "dpm++2s_a"},
  2940. {"description", "DPM++ 2s ancestral sampler"},
  2941. {"recommended_steps", 20}
  2942. },
  2943. {
  2944. {"name", "dpm++2m"},
  2945. {"description", "DPM++ 2m sampler - multistep"},
  2946. {"recommended_steps", 20}
  2947. },
  2948. {
  2949. {"name", "dpm++2mv2"},
  2950. {"description", "DPM++ 2m v2 sampler - improved multistep"},
  2951. {"recommended_steps", 20}
  2952. },
  2953. {
  2954. {"name", "ipndm"},
  2955. {"description", "IPNDM sampler - improved noise prediction"},
  2956. {"recommended_steps", 20}
  2957. },
  2958. {
  2959. {"name", "ipndm_v"},
  2960. {"description", "IPNDM v sampler - variant of IPNDM"},
  2961. {"recommended_steps", 20}
  2962. },
  2963. {
  2964. {"name", "lcm"},
  2965. {"description", "LCM sampler - Latent Consistency Model, very fast"},
  2966. {"recommended_steps", 4}
  2967. },
  2968. {
  2969. {"name", "ddim_trailing"},
  2970. {"description", "DDIM trailing sampler - deterministic"},
  2971. {"recommended_steps", 20}
  2972. },
  2973. {
  2974. {"name", "tcd"},
  2975. {"description", "TCD sampler - Trajectory Consistency Distillation"},
  2976. {"recommended_steps", 8}
  2977. },
  2978. {
  2979. {"name", "default"},
  2980. {"description", "Use model's default sampler"},
  2981. {"recommended_steps", 20}
  2982. }
  2983. }}
  2984. };
  2985. sendJsonResponse(res, samplers);
  2986. } catch (const std::exception& e) {
  2987. sendErrorResponse(res, std::string("Failed to get samplers: ") + e.what(), 500);
  2988. }
  2989. }
  2990. void Server::handleSchedulers(const httplib::Request& /*req*/, httplib::Response& res) {
  2991. try {
  2992. nlohmann::json schedulers = {
  2993. {"schedulers", {
  2994. {
  2995. {"name", "discrete"},
  2996. {"description", "Discrete scheduler - standard noise schedule"}
  2997. },
  2998. {
  2999. {"name", "karras"},
  3000. {"description", "Karras scheduler - improved noise schedule"}
  3001. },
  3002. {
  3003. {"name", "exponential"},
  3004. {"description", "Exponential scheduler - exponential noise decay"}
  3005. },
  3006. {
  3007. {"name", "ays"},
  3008. {"description", "AYS scheduler - Adaptive Your Scheduler"}
  3009. },
  3010. {
  3011. {"name", "gits"},
  3012. {"description", "GITS scheduler - Generalized Iterative Time Steps"}
  3013. },
  3014. {
  3015. {"name", "smoothstep"},
  3016. {"description", "Smoothstep scheduler - smooth transition function"}
  3017. },
  3018. {
  3019. {"name", "sgm_uniform"},
  3020. {"description", "SGM uniform scheduler - uniform noise schedule"}
  3021. },
  3022. {
  3023. {"name", "simple"},
  3024. {"description", "Simple scheduler - basic linear schedule"}
  3025. },
  3026. {
  3027. {"name", "default"},
  3028. {"description", "Use model's default scheduler"}
  3029. }
  3030. }}
  3031. };
  3032. sendJsonResponse(res, schedulers);
  3033. } catch (const std::exception& e) {
  3034. sendErrorResponse(res, std::string("Failed to get schedulers: ") + e.what(), 500);
  3035. }
  3036. }
  3037. void Server::handleParameters(const httplib::Request& /*req*/, httplib::Response& res) {
  3038. try {
  3039. nlohmann::json parameters = {
  3040. {"parameters", {
  3041. {
  3042. {"name", "prompt"},
  3043. {"type", "string"},
  3044. {"required", true},
  3045. {"description", "Text prompt for image generation"},
  3046. {"min_length", 1},
  3047. {"max_length", 10000},
  3048. {"example", "a beautiful landscape with mountains"}
  3049. },
  3050. {
  3051. {"name", "negative_prompt"},
  3052. {"type", "string"},
  3053. {"required", false},
  3054. {"description", "Negative prompt to guide generation away from"},
  3055. {"min_length", 0},
  3056. {"max_length", 10000},
  3057. {"example", "blurry, low quality, distorted"}
  3058. },
  3059. {
  3060. {"name", "width"},
  3061. {"type", "integer"},
  3062. {"required", false},
  3063. {"description", "Image width in pixels"},
  3064. {"min", 64},
  3065. {"max", 2048},
  3066. {"multiple_of", 64},
  3067. {"default", 512}
  3068. },
  3069. {
  3070. {"name", "height"},
  3071. {"type", "integer"},
  3072. {"required", false},
  3073. {"description", "Image height in pixels"},
  3074. {"min", 64},
  3075. {"max", 2048},
  3076. {"multiple_of", 64},
  3077. {"default", 512}
  3078. },
  3079. {
  3080. {"name", "steps"},
  3081. {"type", "integer"},
  3082. {"required", false},
  3083. {"description", "Number of diffusion steps"},
  3084. {"min", 1},
  3085. {"max", 150},
  3086. {"default", 20}
  3087. },
  3088. {
  3089. {"name", "cfg_scale"},
  3090. {"type", "number"},
  3091. {"required", false},
  3092. {"description", "Classifier-Free Guidance scale"},
  3093. {"min", 1.0},
  3094. {"max", 30.0},
  3095. {"default", 7.5}
  3096. },
  3097. {
  3098. {"name", "seed"},
  3099. {"type", "string|integer"},
  3100. {"required", false},
  3101. {"description", "Seed for generation (use 'random' for random seed)"},
  3102. {"example", "42"}
  3103. },
  3104. {
  3105. {"name", "sampling_method"},
  3106. {"type", "string"},
  3107. {"required", false},
  3108. {"description", "Sampling method to use"},
  3109. {"enum", {"euler", "euler_a", "heun", "dpm2", "dpm++2s_a", "dpm++2m", "dpm++2mv2", "ipndm", "ipndm_v", "lcm", "ddim_trailing", "tcd", "default"}},
  3110. {"default", "default"}
  3111. },
  3112. {
  3113. {"name", "scheduler"},
  3114. {"type", "string"},
  3115. {"required", false},
  3116. {"description", "Scheduler to use"},
  3117. {"enum", {"discrete", "karras", "exponential", "ays", "gits", "smoothstep", "sgm_uniform", "simple", "default"}},
  3118. {"default", "default"}
  3119. },
  3120. {
  3121. {"name", "batch_count"},
  3122. {"type", "integer"},
  3123. {"required", false},
  3124. {"description", "Number of images to generate"},
  3125. {"min", 1},
  3126. {"max", 100},
  3127. {"default", 1}
  3128. },
  3129. {
  3130. {"name", "strength"},
  3131. {"type", "number"},
  3132. {"required", false},
  3133. {"description", "Strength for img2img (0.0-1.0)"},
  3134. {"min", 0.0},
  3135. {"max", 1.0},
  3136. {"default", 0.75}
  3137. },
  3138. {
  3139. {"name", "control_strength"},
  3140. {"type", "number"},
  3141. {"required", false},
  3142. {"description", "ControlNet strength (0.0-1.0)"},
  3143. {"min", 0.0},
  3144. {"max", 1.0},
  3145. {"default", 0.9}
  3146. }
  3147. }},
  3148. {"openapi", {
  3149. {"version", "3.0.0"},
  3150. {"info", {
  3151. {"title", "Stable Diffusion REST API"},
  3152. {"version", "1.0.0"},
  3153. {"description", "Comprehensive REST API for stable-diffusion.cpp functionality"}
  3154. }},
  3155. {"components", {
  3156. {"schemas", {
  3157. {"GenerationRequest", {
  3158. {"type", "object"},
  3159. {"required", {"prompt"}},
  3160. {"properties", {
  3161. {"prompt", {{"type", "string"}, {"description", "Text prompt for generation"}}},
  3162. {"negative_prompt", {{"type", "string"}, {"description", "Negative prompt"}}},
  3163. {"width", {{"type", "integer"}, {"minimum", 64}, {"maximum", 2048}, {"default", 512}}},
  3164. {"height", {{"type", "integer"}, {"minimum", 64}, {"maximum", 2048}, {"default", 512}}},
  3165. {"steps", {{"type", "integer"}, {"minimum", 1}, {"maximum", 150}, {"default", 20}}},
  3166. {"cfg_scale", {{"type", "number"}, {"minimum", 1.0}, {"maximum", 30.0}, {"default", 7.5}}}
  3167. }}
  3168. }}
  3169. }}
  3170. }}
  3171. }}
  3172. };
  3173. sendJsonResponse(res, parameters);
  3174. } catch (const std::exception& e) {
  3175. sendErrorResponse(res, std::string("Failed to get parameters: ") + e.what(), 500);
  3176. }
  3177. }
  3178. void Server::handleValidate(const httplib::Request& req, httplib::Response& res) {
  3179. std::string requestId = generateRequestId();
  3180. try {
  3181. nlohmann::json requestJson = nlohmann::json::parse(req.body);
  3182. // Validate parameters
  3183. auto [isValid, errorMessage] = validateGenerationParameters(requestJson);
  3184. nlohmann::json response = {
  3185. {"request_id", requestId},
  3186. {"valid", isValid},
  3187. {"message", isValid ? "Parameters are valid" : errorMessage},
  3188. {"errors", isValid ? nlohmann::json::array() : nlohmann::json::array({errorMessage})}
  3189. };
  3190. sendJsonResponse(res, response, isValid ? 200 : 400);
  3191. } catch (const nlohmann::json::parse_error& e) {
  3192. sendErrorResponse(res, std::string("Invalid JSON: ") + e.what(), 400, "JSON_PARSE_ERROR", requestId);
  3193. } catch (const std::exception& e) {
  3194. sendErrorResponse(res, std::string("Validation failed: ") + e.what(), 500, "INTERNAL_ERROR", requestId);
  3195. }
  3196. }
  3197. void Server::handleEstimate(const httplib::Request& req, httplib::Response& res) {
  3198. std::string requestId = generateRequestId();
  3199. try {
  3200. nlohmann::json requestJson = nlohmann::json::parse(req.body);
  3201. // Validate parameters first
  3202. auto [isValid, errorMessage] = validateGenerationParameters(requestJson);
  3203. if (!isValid) {
  3204. sendErrorResponse(res, errorMessage, 400, "INVALID_PARAMETERS", requestId);
  3205. return;
  3206. }
  3207. // Create a temporary request to estimate
  3208. GenerationRequest genRequest;
  3209. genRequest.prompt = requestJson["prompt"];
  3210. genRequest.width = requestJson.value("width", 512);
  3211. genRequest.height = requestJson.value("height", 512);
  3212. genRequest.batchCount = requestJson.value("batch_count", 1);
  3213. genRequest.steps = requestJson.value("steps", 20);
  3214. genRequest.diffusionFlashAttn = requestJson.value("diffusion_flash_attn", false);
  3215. genRequest.controlNetPath = requestJson.value("control_net_path", "");
  3216. if (requestJson.contains("sampling_method")) {
  3217. genRequest.samplingMethod = parseSamplingMethod(requestJson["sampling_method"]);
  3218. }
  3219. // Calculate estimates
  3220. uint64_t estimatedTime = estimateGenerationTime(genRequest);
  3221. size_t estimatedMemory = estimateMemoryUsage(genRequest);
  3222. nlohmann::json response = {
  3223. {"request_id", requestId},
  3224. {"estimated_time_seconds", estimatedTime / 1000},
  3225. {"estimated_memory_mb", estimatedMemory / (1024 * 1024)},
  3226. {"parameters", {
  3227. {"resolution", std::to_string(genRequest.width) + "x" + std::to_string(genRequest.height)},
  3228. {"steps", genRequest.steps},
  3229. {"batch_count", genRequest.batchCount},
  3230. {"sampling_method", samplingMethodToString(genRequest.samplingMethod)}
  3231. }}
  3232. };
  3233. sendJsonResponse(res, response);
  3234. } catch (const nlohmann::json::parse_error& e) {
  3235. sendErrorResponse(res, std::string("Invalid JSON: ") + e.what(), 400, "JSON_PARSE_ERROR", requestId);
  3236. } catch (const std::exception& e) {
  3237. sendErrorResponse(res, std::string("Estimation failed: ") + e.what(), 500, "INTERNAL_ERROR", requestId);
  3238. }
  3239. }
  3240. void Server::handleConfig(const httplib::Request& /*req*/, httplib::Response& res) {
  3241. std::string requestId = generateRequestId();
  3242. try {
  3243. // Get current configuration
  3244. nlohmann::json config = {
  3245. {"request_id", requestId},
  3246. {"config", {
  3247. {"server", {
  3248. {"host", m_host},
  3249. {"port", m_port},
  3250. {"max_concurrent_generations", 1}
  3251. }},
  3252. {"generation", {
  3253. {"default_width", 512},
  3254. {"default_height", 512},
  3255. {"default_steps", 20},
  3256. {"default_cfg_scale", 7.5},
  3257. {"max_batch_count", 100},
  3258. {"max_steps", 150},
  3259. {"max_resolution", 2048}
  3260. }},
  3261. {"rate_limiting", {
  3262. {"requests_per_minute", 60},
  3263. {"enabled", true}
  3264. }}
  3265. }}
  3266. };
  3267. sendJsonResponse(res, config);
  3268. } catch (const std::exception& e) {
  3269. sendErrorResponse(res, std::string("Config operation failed: ") + e.what(), 500, "INTERNAL_ERROR", requestId);
  3270. }
  3271. }
  3272. void Server::handleSystem(const httplib::Request& /*req*/, httplib::Response& res) {
  3273. try {
  3274. nlohmann::json system = {
  3275. {"system", {
  3276. {"version", "1.0.0"},
  3277. {"build", "stable-diffusion.cpp-rest"},
  3278. {"uptime", std::chrono::duration_cast<std::chrono::seconds>(
  3279. std::chrono::steady_clock::now().time_since_epoch()).count()},
  3280. {"capabilities", {
  3281. {"text2img", true},
  3282. {"img2img", true},
  3283. {"controlnet", true},
  3284. {"batch_generation", true},
  3285. {"parameter_validation", true},
  3286. {"estimation", true}
  3287. }},
  3288. {"supported_formats", {
  3289. {"input", {"png", "jpg", "jpeg", "webp"}},
  3290. {"output", {"png", "jpg", "jpeg", "webp"}}
  3291. }},
  3292. {"limits", {
  3293. {"max_resolution", 2048},
  3294. {"max_steps", 150},
  3295. {"max_batch_count", 100},
  3296. {"max_prompt_length", 10000}
  3297. }}
  3298. }},
  3299. {"hardware", {
  3300. {"cpu_threads", std::thread::hardware_concurrency()}
  3301. }}
  3302. };
  3303. sendJsonResponse(res, system);
  3304. } catch (const std::exception& e) {
  3305. sendErrorResponse(res, std::string("System info failed: ") + e.what(), 500);
  3306. }
  3307. }
  3308. void Server::handleSystemRestart(const httplib::Request& /*req*/, httplib::Response& res) {
  3309. try {
  3310. nlohmann::json response = {
  3311. {"message", "Server restart initiated. The server will shut down gracefully and exit. Please use a process manager to automatically restart it."},
  3312. {"status", "restarting"}
  3313. };
  3314. sendJsonResponse(res, response);
  3315. // Schedule server stop after response is sent
  3316. // Using a separate thread to allow the response to be sent first
  3317. std::thread([this]() {
  3318. std::this_thread::sleep_for(std::chrono::seconds(1));
  3319. this->stop();
  3320. // Exit with code 42 to signal restart intent to process manager
  3321. std::exit(42);
  3322. }).detach();
  3323. } catch (const std::exception& e) {
  3324. sendErrorResponse(res, std::string("Restart failed: ") + e.what(), 500);
  3325. }
  3326. }
  3327. // Helper methods for model management
  3328. nlohmann::json Server::getModelCapabilities(ModelType type) {
  3329. nlohmann::json capabilities = nlohmann::json::object();
  3330. switch (type) {
  3331. case ModelType::CHECKPOINT:
  3332. capabilities = {
  3333. {"text2img", true},
  3334. {"img2img", true},
  3335. {"inpainting", true},
  3336. {"outpainting", true},
  3337. {"controlnet", true},
  3338. {"lora", true},
  3339. {"vae", true},
  3340. {"sampling_methods", {"euler", "euler_a", "heun", "dpm2", "dpm++2s_a", "dpm++2m", "dpm++2mv2", "ipndm", "ipndm_v", "lcm", "ddim_trailing", "tcd"}},
  3341. {"schedulers", {"discrete", "karras", "exponential", "ays", "gits", "smoothstep", "sgm_uniform", "simple"}},
  3342. {"recommended_resolution", "512x512"},
  3343. {"max_resolution", "2048x2048"},
  3344. {"supports_batch", true}
  3345. };
  3346. break;
  3347. case ModelType::LORA:
  3348. capabilities = {
  3349. {"text2img", true},
  3350. {"img2img", true},
  3351. {"inpainting", true},
  3352. {"controlnet", false},
  3353. {"lora", true},
  3354. {"vae", false},
  3355. {"requires_checkpoint", true},
  3356. {"strength_range", {0.0, 2.0}},
  3357. {"recommended_strength", 1.0}
  3358. };
  3359. break;
  3360. case ModelType::CONTROLNET:
  3361. capabilities = {
  3362. {"text2img", false},
  3363. {"img2img", true},
  3364. {"inpainting", true},
  3365. {"controlnet", true},
  3366. {"requires_checkpoint", true},
  3367. {"control_modes", {"canny", "depth", "pose", "scribble", "hed", "mlsd", "normal", "seg"}},
  3368. {"strength_range", {0.0, 1.0}},
  3369. {"recommended_strength", 0.9}
  3370. };
  3371. break;
  3372. case ModelType::VAE:
  3373. capabilities = {
  3374. {"text2img", false},
  3375. {"img2img", false},
  3376. {"inpainting", false},
  3377. {"vae", true},
  3378. {"requires_checkpoint", true},
  3379. {"encoding", true},
  3380. {"decoding", true},
  3381. {"precision", {"fp16", "fp32"}}
  3382. };
  3383. break;
  3384. case ModelType::EMBEDDING:
  3385. capabilities = {
  3386. {"text2img", true},
  3387. {"img2img", true},
  3388. {"inpainting", true},
  3389. {"embedding", true},
  3390. {"requires_checkpoint", true},
  3391. {"token_count", 1},
  3392. {"compatible_with", {"checkpoint", "lora"}}
  3393. };
  3394. break;
  3395. case ModelType::TAESD:
  3396. capabilities = {
  3397. {"text2img", false},
  3398. {"img2img", false},
  3399. {"inpainting", false},
  3400. {"vae", true},
  3401. {"requires_checkpoint", true},
  3402. {"fast_decoding", true},
  3403. {"real_time", true},
  3404. {"precision", {"fp16", "fp32"}}
  3405. };
  3406. break;
  3407. case ModelType::ESRGAN:
  3408. capabilities = {
  3409. {"text2img", false},
  3410. {"img2img", false},
  3411. {"inpainting", false},
  3412. {"upscaling", true},
  3413. {"scale_factors", {2, 4}},
  3414. {"models", {"ESRGAN", "RealESRGAN", "SwinIR"}},
  3415. {"supports_alpha", false}
  3416. };
  3417. break;
  3418. default:
  3419. capabilities = {
  3420. {"text2img", false},
  3421. {"img2img", false},
  3422. {"inpainting", false},
  3423. {"capabilities", {}}
  3424. };
  3425. break;
  3426. }
  3427. return capabilities;
  3428. }
  3429. nlohmann::json Server::getModelTypeStatistics() {
  3430. if (!m_modelManager) return nlohmann::json::object();
  3431. nlohmann::json stats = nlohmann::json::object();
  3432. auto allModels = m_modelManager->getAllModels();
  3433. // Initialize counters for each type
  3434. std::map<ModelType, int> typeCounts;
  3435. std::map<ModelType, int> loadedCounts;
  3436. std::map<ModelType, size_t> sizeByType;
  3437. for (const auto& pair : allModels) {
  3438. ModelType type = pair.second.type;
  3439. typeCounts[type]++;
  3440. if (pair.second.isLoaded) {
  3441. loadedCounts[type]++;
  3442. }
  3443. sizeByType[type] += pair.second.fileSize;
  3444. }
  3445. // Build statistics JSON
  3446. for (const auto& count : typeCounts) {
  3447. std::string typeName = ModelManager::modelTypeToString(count.first);
  3448. stats[typeName] = {
  3449. {"total_count", count.second},
  3450. {"loaded_count", loadedCounts[count.first]},
  3451. {"total_size_bytes", sizeByType[count.first]},
  3452. {"total_size_mb", sizeByType[count.first] / (1024.0 * 1024.0)},
  3453. {"average_size_mb", count.second > 0 ? (sizeByType[count.first] / (1024.0 * 1024.0)) / count.second : 0.0}
  3454. };
  3455. }
  3456. return stats;
  3457. }
  3458. // Additional helper methods for model management
  3459. nlohmann::json Server::getModelCompatibility(const ModelManager::ModelInfo& modelInfo) {
  3460. nlohmann::json compatibility = {
  3461. {"is_compatible", true},
  3462. {"compatibility_score", 100},
  3463. {"issues", nlohmann::json::array()},
  3464. {"warnings", nlohmann::json::array()},
  3465. {"requirements", {
  3466. {"min_memory_mb", 1024},
  3467. {"recommended_memory_mb", 2048},
  3468. {"supported_formats", {"safetensors", "ckpt", "gguf"}},
  3469. {"required_dependencies", {}}
  3470. }}
  3471. };
  3472. // Check for specific compatibility issues based on model type
  3473. if (modelInfo.type == ModelType::LORA) {
  3474. compatibility["requirements"]["required_dependencies"] = {"checkpoint"};
  3475. } else if (modelInfo.type == ModelType::CONTROLNET) {
  3476. compatibility["requirements"]["required_dependencies"] = {"checkpoint"};
  3477. } else if (modelInfo.type == ModelType::VAE) {
  3478. compatibility["requirements"]["required_dependencies"] = {"checkpoint"};
  3479. }
  3480. return compatibility;
  3481. }
  3482. nlohmann::json Server::getModelRequirements(ModelType type) {
  3483. nlohmann::json requirements = {
  3484. {"min_memory_mb", 1024},
  3485. {"recommended_memory_mb", 2048},
  3486. {"min_disk_space_mb", 1024},
  3487. {"supported_formats", {"safetensors", "ckpt", "gguf"}},
  3488. {"required_dependencies", nlohmann::json::array()},
  3489. {"optional_dependencies", nlohmann::json::array()},
  3490. {"system_requirements", {
  3491. {"cpu_cores", 4},
  3492. {"cpu_architecture", "x86_64"},
  3493. {"os", "Linux/Windows/macOS"},
  3494. {"gpu_memory_mb", 2048},
  3495. {"gpu_compute_capability", "3.5+"}
  3496. }}
  3497. };
  3498. switch (type) {
  3499. case ModelType::CHECKPOINT:
  3500. requirements["min_memory_mb"] = 2048;
  3501. requirements["recommended_memory_mb"] = 4096;
  3502. requirements["min_disk_space_mb"] = 2048;
  3503. requirements["supported_formats"] = {"safetensors", "ckpt", "gguf"};
  3504. break;
  3505. case ModelType::LORA:
  3506. requirements["min_memory_mb"] = 512;
  3507. requirements["recommended_memory_mb"] = 1024;
  3508. requirements["min_disk_space_mb"] = 100;
  3509. requirements["supported_formats"] = {"safetensors", "ckpt"};
  3510. requirements["required_dependencies"] = {"checkpoint"};
  3511. break;
  3512. case ModelType::CONTROLNET:
  3513. requirements["min_memory_mb"] = 1024;
  3514. requirements["recommended_memory_mb"] = 2048;
  3515. requirements["min_disk_space_mb"] = 500;
  3516. requirements["supported_formats"] = {"safetensors", "pth"};
  3517. requirements["required_dependencies"] = {"checkpoint"};
  3518. break;
  3519. case ModelType::VAE:
  3520. requirements["min_memory_mb"] = 512;
  3521. requirements["recommended_memory_mb"] = 1024;
  3522. requirements["min_disk_space_mb"] = 200;
  3523. requirements["supported_formats"] = {"safetensors", "pt", "ckpt", "gguf"};
  3524. requirements["required_dependencies"] = {"checkpoint"};
  3525. break;
  3526. case ModelType::EMBEDDING:
  3527. requirements["min_memory_mb"] = 64;
  3528. requirements["recommended_memory_mb"] = 256;
  3529. requirements["min_disk_space_mb"] = 10;
  3530. requirements["supported_formats"] = {"safetensors", "pt"};
  3531. requirements["required_dependencies"] = {"checkpoint"};
  3532. break;
  3533. case ModelType::TAESD:
  3534. requirements["min_memory_mb"] = 256;
  3535. requirements["recommended_memory_mb"] = 512;
  3536. requirements["min_disk_space_mb"] = 100;
  3537. requirements["supported_formats"] = {"safetensors", "pth", "gguf"};
  3538. requirements["required_dependencies"] = {"checkpoint"};
  3539. break;
  3540. case ModelType::ESRGAN:
  3541. requirements["min_memory_mb"] = 1024;
  3542. requirements["recommended_memory_mb"] = 2048;
  3543. requirements["min_disk_space_mb"] = 500;
  3544. requirements["supported_formats"] = {"pth", "pt"};
  3545. requirements["optional_dependencies"] = {"checkpoint"};
  3546. break;
  3547. default:
  3548. break;
  3549. }
  3550. return requirements;
  3551. }
  3552. nlohmann::json Server::getRecommendedUsage(ModelType type) {
  3553. nlohmann::json usage = {
  3554. {"text2img", false},
  3555. {"img2img", false},
  3556. {"inpainting", false},
  3557. {"controlnet", false},
  3558. {"lora", false},
  3559. {"vae", false},
  3560. {"recommended_resolution", "512x512"},
  3561. {"recommended_steps", 20},
  3562. {"recommended_cfg_scale", 7.5},
  3563. {"recommended_batch_size", 1}
  3564. };
  3565. switch (type) {
  3566. case ModelType::CHECKPOINT:
  3567. usage = {
  3568. {"text2img", true},
  3569. {"img2img", true},
  3570. {"inpainting", true},
  3571. {"controlnet", true},
  3572. {"lora", true},
  3573. {"vae", true},
  3574. {"recommended_resolution", "512x512"},
  3575. {"recommended_steps", 20},
  3576. {"recommended_cfg_scale", 7.5},
  3577. {"recommended_batch_size", 1}
  3578. };
  3579. break;
  3580. case ModelType::LORA:
  3581. usage = {
  3582. {"text2img", true},
  3583. {"img2img", true},
  3584. {"inpainting", true},
  3585. {"controlnet", false},
  3586. {"lora", true},
  3587. {"vae", false},
  3588. {"recommended_strength", 1.0},
  3589. {"recommended_usage", "Style transfer, character customization"}
  3590. };
  3591. break;
  3592. case ModelType::CONTROLNET:
  3593. usage = {
  3594. {"text2img", false},
  3595. {"img2img", true},
  3596. {"inpainting", true},
  3597. {"controlnet", true},
  3598. {"lora", false},
  3599. {"vae", false},
  3600. {"recommended_strength", 0.9},
  3601. {"recommended_usage", "Precise control over output"}
  3602. };
  3603. break;
  3604. case ModelType::VAE:
  3605. usage = {
  3606. {"text2img", false},
  3607. {"img2img", false},
  3608. {"inpainting", false},
  3609. {"controlnet", false},
  3610. {"lora", false},
  3611. {"vae", true},
  3612. {"recommended_usage", "Improved encoding/decoding quality"}
  3613. };
  3614. break;
  3615. case ModelType::EMBEDDING:
  3616. usage = {
  3617. {"text2img", true},
  3618. {"img2img", true},
  3619. {"inpainting", true},
  3620. {"controlnet", false},
  3621. {"lora", false},
  3622. {"vae", false},
  3623. {"embedding", true},
  3624. {"recommended_usage", "Concept control, style words"}
  3625. };
  3626. break;
  3627. case ModelType::TAESD:
  3628. usage = {
  3629. {"text2img", false},
  3630. {"img2img", false},
  3631. {"inpainting", false},
  3632. {"controlnet", false},
  3633. {"lora", false},
  3634. {"vae", true},
  3635. {"recommended_usage", "Real-time decoding"}
  3636. };
  3637. break;
  3638. case ModelType::ESRGAN:
  3639. usage = {
  3640. {"text2img", false},
  3641. {"img2img", false},
  3642. {"inpainting", false},
  3643. {"controlnet", false},
  3644. {"lora", false},
  3645. {"vae", false},
  3646. {"upscaling", true},
  3647. {"recommended_usage", "Image upscaling and quality enhancement"}
  3648. };
  3649. break;
  3650. default:
  3651. break;
  3652. }
  3653. return usage;
  3654. }
  3655. std::string Server::getModelTypeFromDirectoryName(const std::string& dirName) {
  3656. if (dirName == "stable-diffusion" || dirName == "checkpoints") {
  3657. return "checkpoint";
  3658. } else if (dirName == "lora") {
  3659. return "lora";
  3660. } else if (dirName == "controlnet") {
  3661. return "controlnet";
  3662. } else if (dirName == "vae") {
  3663. return "vae";
  3664. } else if (dirName == "taesd") {
  3665. return "taesd";
  3666. } else if (dirName == "esrgan" || dirName == "upscaler") {
  3667. return "esrgan";
  3668. } else if (dirName == "embeddings" || dirName == "textual-inversion") {
  3669. return "embedding";
  3670. } else {
  3671. return "unknown";
  3672. }
  3673. }
  3674. std::string Server::getDirectoryDescription(const std::string& dirName) {
  3675. if (dirName == "stable-diffusion" || dirName == "checkpoints") {
  3676. return "Main stable diffusion model files";
  3677. } else if (dirName == "lora") {
  3678. return "LoRA adapter models for style transfer";
  3679. } else if (dirName == "controlnet") {
  3680. return "ControlNet models for precise control";
  3681. } else if (dirName == "vae") {
  3682. return "VAE models for improved encoding/decoding";
  3683. } else if (dirName == "taesd") {
  3684. return "TAESD models for real-time decoding";
  3685. } else if (dirName == "esrgan" || dirName == "upscaler") {
  3686. return "ESRGAN models for image upscaling";
  3687. } else if (dirName == "embeddings" || dirName == "textual-inversion") {
  3688. return "Text embeddings for concept control";
  3689. } else {
  3690. return "Unknown model directory";
  3691. }
  3692. }
  3693. nlohmann::json Server::getDirectoryContents(const std::string& dirPath) {
  3694. nlohmann::json contents = nlohmann::json::array();
  3695. try {
  3696. if (std::filesystem::exists(dirPath) && std::filesystem::is_directory(dirPath)) {
  3697. for (const auto& entry : std::filesystem::directory_iterator(dirPath)) {
  3698. if (entry.is_regular_file()) {
  3699. nlohmann::json file = {
  3700. {"name", entry.path().filename().string()},
  3701. {"path", entry.path().string()},
  3702. {"size", std::filesystem::file_size(entry.path())},
  3703. {"size_mb", std::filesystem::file_size(entry.path()) / (1024.0 * 1024.0)},
  3704. {"last_modified", std::chrono::duration_cast<std::chrono::seconds>(
  3705. std::filesystem::last_write_time(entry.path()).time_since_epoch()).count()}
  3706. };
  3707. contents.push_back(file);
  3708. }
  3709. }
  3710. }
  3711. } catch (const std::exception& e) {
  3712. // Return empty array if directory access fails
  3713. }
  3714. return contents;
  3715. }
  3716. nlohmann::json Server::getLargestModel(const std::map<std::string, ModelManager::ModelInfo>& allModels) {
  3717. nlohmann::json largest = nlohmann::json::object();
  3718. size_t maxSize = 0;
  3719. std::string largestName;
  3720. for (const auto& pair : allModels) {
  3721. if (pair.second.fileSize > maxSize) {
  3722. maxSize = pair.second.fileSize;
  3723. largestName = pair.second.name;
  3724. }
  3725. }
  3726. if (!largestName.empty()) {
  3727. largest = {
  3728. {"name", largestName},
  3729. {"size", maxSize},
  3730. {"size_mb", maxSize / (1024.0 * 1024.0)},
  3731. {"type", ModelManager::modelTypeToString(allModels.at(largestName).type)}
  3732. };
  3733. }
  3734. return largest;
  3735. }
  3736. nlohmann::json Server::getSmallestModel(const std::map<std::string, ModelManager::ModelInfo>& allModels) {
  3737. nlohmann::json smallest = nlohmann::json::object();
  3738. size_t minSize = SIZE_MAX;
  3739. std::string smallestName;
  3740. for (const auto& pair : allModels) {
  3741. if (pair.second.fileSize < minSize) {
  3742. minSize = pair.second.fileSize;
  3743. smallestName = pair.second.name;
  3744. }
  3745. }
  3746. if (!smallestName.empty()) {
  3747. smallest = {
  3748. {"name", smallestName},
  3749. {"size", minSize},
  3750. {"size_mb", minSize / (1024.0 * 1024.0)},
  3751. {"type", ModelManager::modelTypeToString(allModels.at(smallestName).type)}
  3752. };
  3753. }
  3754. return smallest;
  3755. }
  3756. nlohmann::json Server::validateModelFile(const std::string& modelPath, const std::string& modelType) {
  3757. nlohmann::json validation = {
  3758. {"is_valid", false},
  3759. {"errors", nlohmann::json::array()},
  3760. {"warnings", nlohmann::json::array()},
  3761. {"file_info", nlohmann::json::object()},
  3762. {"compatibility", nlohmann::json::object()},
  3763. {"recommendations", nlohmann::json::array()}
  3764. };
  3765. try {
  3766. if (!std::filesystem::exists(modelPath)) {
  3767. validation["errors"].push_back("File does not exist");
  3768. return validation;
  3769. }
  3770. if (!std::filesystem::is_regular_file(modelPath)) {
  3771. validation["errors"].push_back("Path is not a regular file");
  3772. return validation;
  3773. }
  3774. // Check file extension
  3775. std::string extension = std::filesystem::path(modelPath).extension().string();
  3776. if (extension.empty()) {
  3777. validation["errors"].push_back("Missing file extension");
  3778. return validation;
  3779. }
  3780. // Remove dot and convert to lowercase
  3781. if (extension[0] == '.') {
  3782. extension = extension.substr(1);
  3783. }
  3784. std::transform(extension.begin(), extension.end(), extension.begin(), ::tolower);
  3785. // Validate extension based on model type
  3786. ModelType type = ModelManager::stringToModelType(modelType);
  3787. bool validExtension = false;
  3788. switch (type) {
  3789. case ModelType::CHECKPOINT:
  3790. validExtension = (extension == "safetensors" || extension == "ckpt" || extension == "gguf");
  3791. break;
  3792. case ModelType::LORA:
  3793. validExtension = (extension == "safetensors" || extension == "ckpt");
  3794. break;
  3795. case ModelType::CONTROLNET:
  3796. validExtension = (extension == "safetensors" || extension == "pth");
  3797. break;
  3798. case ModelType::VAE:
  3799. validExtension = (extension == "safetensors" || extension == "pt" || extension == "ckpt" || extension == "gguf");
  3800. break;
  3801. case ModelType::EMBEDDING:
  3802. validExtension = (extension == "safetensors" || extension == "pt");
  3803. break;
  3804. case ModelType::TAESD:
  3805. validExtension = (extension == "safetensors" || extension == "pth" || extension == "gguf");
  3806. break;
  3807. case ModelType::ESRGAN:
  3808. validExtension = (extension == "pth" || extension == "pt");
  3809. break;
  3810. default:
  3811. break;
  3812. }
  3813. if (!validExtension) {
  3814. validation["errors"].push_back("Invalid file extension for model type: " + extension);
  3815. }
  3816. // Check file size
  3817. size_t fileSize = std::filesystem::file_size(modelPath);
  3818. if (fileSize == 0) {
  3819. validation["errors"].push_back("File is empty");
  3820. } else if (fileSize > 8ULL * 1024 * 1024 * 1024) { // 8GB
  3821. validation["warnings"].push_back("Very large file may cause performance issues");
  3822. }
  3823. // Build file info
  3824. validation["file_info"] = {
  3825. {"path", modelPath},
  3826. {"size", fileSize},
  3827. {"size_mb", fileSize / (1024.0 * 1024.0)},
  3828. {"extension", extension},
  3829. {"last_modified", std::chrono::duration_cast<std::chrono::seconds>(
  3830. std::filesystem::last_write_time(modelPath).time_since_epoch()).count()}
  3831. };
  3832. // Check compatibility
  3833. validation["compatibility"] = {
  3834. {"extension_valid", validExtension},
  3835. {"size_appropriate", fileSize <= 4ULL * 1024 * 1024 * 1024}, // 4GB
  3836. {"recommended_format", "safetensors"}
  3837. };
  3838. // Add recommendations
  3839. if (!validExtension) {
  3840. validation["recommendations"].push_back("Convert to SafeTensors format for better security and performance");
  3841. }
  3842. if (fileSize > 2ULL * 1024 * 1024 * 1024) { // 2GB
  3843. validation["recommendations"].push_back("Consider using a smaller model for better performance");
  3844. }
  3845. // If no errors found, mark as valid
  3846. if (validation["errors"].empty()) {
  3847. validation["is_valid"] = true;
  3848. }
  3849. } catch (const std::exception& e) {
  3850. validation["errors"].push_back("Validation failed: " + std::string(e.what()));
  3851. }
  3852. return validation;
  3853. }
  3854. nlohmann::json Server::checkModelCompatibility(const ModelManager::ModelInfo& modelInfo, const std::string& systemInfo) {
  3855. nlohmann::json compatibility = {
  3856. {"is_compatible", true},
  3857. {"compatibility_score", 100},
  3858. {"issues", nlohmann::json::array()},
  3859. {"warnings", nlohmann::json::array()},
  3860. {"requirements", nlohmann::json::object()},
  3861. {"recommendations", nlohmann::json::array()},
  3862. {"system_info", nlohmann::json::object()}
  3863. };
  3864. // Check system compatibility
  3865. if (systemInfo == "auto") {
  3866. compatibility["system_info"] = {
  3867. {"cpu_cores", std::thread::hardware_concurrency()}
  3868. };
  3869. }
  3870. // Check model-specific compatibility issues
  3871. if (modelInfo.type == ModelType::CHECKPOINT) {
  3872. if (modelInfo.fileSize > 4ULL * 1024 * 1024 * 1024) { // 4GB
  3873. compatibility["warnings"].push_back("Large checkpoint model may require significant memory");
  3874. compatibility["compatibility_score"] = 80;
  3875. }
  3876. if (modelInfo.fileSize < 500 * 1024 * 1024) { // 500MB
  3877. compatibility["warnings"].push_back("Small checkpoint model may have limited capabilities");
  3878. compatibility["compatibility_score"] = 85;
  3879. }
  3880. } else if (modelInfo.type == ModelType::LORA) {
  3881. if (modelInfo.fileSize > 500 * 1024 * 1024) { // 500MB
  3882. compatibility["warnings"].push_back("Large LoRA may impact performance");
  3883. compatibility["compatibility_score"] = 75;
  3884. }
  3885. }
  3886. return compatibility;
  3887. }
  3888. nlohmann::json Server::calculateSpecificRequirements(const std::string& modelType, const std::string& resolution, const std::string& batchSize) {
  3889. (void)modelType; // Suppress unused parameter warning
  3890. nlohmann::json specific = {
  3891. {"memory_requirements", nlohmann::json::object()},
  3892. {"performance_impact", nlohmann::json::object()},
  3893. {"quality_expectations", nlohmann::json::object()}
  3894. };
  3895. // Parse resolution
  3896. int width = 512, height = 512;
  3897. try {
  3898. size_t xPos = resolution.find('x');
  3899. if (xPos != std::string::npos) {
  3900. width = std::stoi(resolution.substr(0, xPos));
  3901. height = std::stoi(resolution.substr(xPos + 1));
  3902. }
  3903. } catch (...) {
  3904. // Use defaults if parsing fails
  3905. }
  3906. // Parse batch size
  3907. int batch = 1;
  3908. try {
  3909. batch = std::stoi(batchSize);
  3910. } catch (...) {
  3911. // Use default if parsing fails
  3912. }
  3913. // Calculate memory requirements based on resolution and batch
  3914. size_t pixels = width * height;
  3915. size_t baseMemory = 1024 * 1024 * 1024; // 1GB base
  3916. size_t resolutionMemory = (pixels * 4) / (512 * 512); // Scale based on 512x512
  3917. size_t batchMemory = (batch - 1) * baseMemory * 0.5; // Additional memory for batch
  3918. specific["memory_requirements"] = {
  3919. {"base_memory_mb", baseMemory / (1024 * 1024)},
  3920. {"resolution_memory_mb", resolutionMemory / (1024 * 1024)},
  3921. {"batch_memory_mb", batchMemory / (1024 * 1024)},
  3922. {"total_memory_mb", (baseMemory + resolutionMemory + batchMemory) / (1024 * 1024)}
  3923. };
  3924. // Calculate performance impact
  3925. double performanceFactor = 1.0;
  3926. if (pixels > 512 * 512) {
  3927. performanceFactor = 1.5;
  3928. }
  3929. if (batch > 1) {
  3930. performanceFactor *= 1.2;
  3931. }
  3932. specific["performance_impact"] = {
  3933. {"resolution_factor", pixels > 512 * 512 ? 1.5 : 1.0},
  3934. {"batch_factor", batch > 1 ? 1.2 : 1.0},
  3935. {"overall_factor", performanceFactor}
  3936. };
  3937. return specific;
  3938. }
  3939. // Enhanced model management endpoint implementations
  3940. void Server::handleModelInfo(const httplib::Request& req, httplib::Response& res) {
  3941. std::string requestId = generateRequestId();
  3942. try {
  3943. if (!m_modelManager) {
  3944. sendErrorResponse(res, "Model manager not available", 500, "MODEL_MANAGER_UNAVAILABLE", requestId);
  3945. return;
  3946. }
  3947. // Extract model ID from URL path
  3948. std::string modelId = req.matches[1].str();
  3949. if (modelId.empty()) {
  3950. sendErrorResponse(res, "Missing model ID", 400, "MISSING_MODEL_ID", requestId);
  3951. return;
  3952. }
  3953. // Get model information
  3954. auto modelInfo = m_modelManager->getModelInfo(modelId);
  3955. if (modelInfo.name.empty()) {
  3956. sendErrorResponse(res, "Model not found", 404, "MODEL_NOT_FOUND", requestId);
  3957. return;
  3958. }
  3959. // Build comprehensive model information
  3960. nlohmann::json response = {
  3961. {"model", {
  3962. {"name", modelInfo.name},
  3963. {"path", modelInfo.path},
  3964. {"type", ModelManager::modelTypeToString(modelInfo.type)},
  3965. {"is_loaded", modelInfo.isLoaded},
  3966. {"file_size", modelInfo.fileSize},
  3967. {"file_size_mb", modelInfo.fileSize / (1024.0 * 1024.0)},
  3968. {"description", modelInfo.description},
  3969. {"metadata", modelInfo.metadata},
  3970. {"capabilities", getModelCapabilities(modelInfo.type)},
  3971. {"compatibility", getModelCompatibility(modelInfo)},
  3972. {"requirements", getModelRequirements(modelInfo.type)},
  3973. {"recommended_usage", getRecommendedUsage(modelInfo.type)},
  3974. {"last_modified", std::chrono::duration_cast<std::chrono::seconds>(
  3975. modelInfo.modifiedAt.time_since_epoch()).count()}
  3976. }},
  3977. {"request_id", requestId}
  3978. };
  3979. sendJsonResponse(res, response);
  3980. } catch (const std::exception& e) {
  3981. sendErrorResponse(res, std::string("Failed to get model info: ") + e.what(), 500, "MODEL_INFO_ERROR", requestId);
  3982. }
  3983. }
  3984. void Server::handleLoadModelById(const httplib::Request& req, httplib::Response& res) {
  3985. std::string requestId = generateRequestId();
  3986. try {
  3987. if (!m_modelManager) {
  3988. sendErrorResponse(res, "Model manager not available", 500, "MODEL_MANAGER_UNAVAILABLE", requestId);
  3989. return;
  3990. }
  3991. // Extract model ID from URL path (could be hash or name)
  3992. std::string modelIdentifier = req.matches[1].str();
  3993. if (modelIdentifier.empty()) {
  3994. sendErrorResponse(res, "Missing model identifier", 400, "MISSING_MODEL_ID", requestId);
  3995. return;
  3996. }
  3997. // Try to find by hash first (if it looks like a hash - 10+ hex chars)
  3998. std::string modelId = modelIdentifier;
  3999. if (modelIdentifier.length() >= 10 &&
  4000. std::all_of(modelIdentifier.begin(), modelIdentifier.end(),
  4001. [](char c) { return std::isxdigit(c); })) {
  4002. std::string foundName = m_modelManager->findModelByHash(modelIdentifier);
  4003. if (!foundName.empty()) {
  4004. modelId = foundName;
  4005. std::cout << "Resolved hash " << modelIdentifier << " to model: " << modelId << std::endl;
  4006. }
  4007. }
  4008. // Parse optional parameters from request body
  4009. nlohmann::json requestJson;
  4010. if (!req.body.empty()) {
  4011. try {
  4012. requestJson = nlohmann::json::parse(req.body);
  4013. } catch (const nlohmann::json::parse_error& e) {
  4014. sendErrorResponse(res, std::string("Invalid JSON: ") + e.what(), 400, "JSON_PARSE_ERROR", requestId);
  4015. return;
  4016. }
  4017. }
  4018. // Unload previous model if one is loaded
  4019. std::string previousModel;
  4020. {
  4021. std::lock_guard<std::mutex> lock(m_currentModelMutex);
  4022. previousModel = m_currentlyLoadedModel;
  4023. }
  4024. if (!previousModel.empty() && previousModel != modelId) {
  4025. std::cout << "Unloading previous model: " << previousModel << std::endl;
  4026. m_modelManager->unloadModel(previousModel);
  4027. }
  4028. // Load model
  4029. bool success = m_modelManager->loadModel(modelId);
  4030. if (success) {
  4031. // Update currently loaded model
  4032. {
  4033. std::lock_guard<std::mutex> lock(m_currentModelMutex);
  4034. m_currentlyLoadedModel = modelId;
  4035. }
  4036. auto modelInfo = m_modelManager->getModelInfo(modelId);
  4037. nlohmann::json response = {
  4038. {"status", "success"},
  4039. {"model", {
  4040. {"name", modelInfo.name},
  4041. {"path", modelInfo.path},
  4042. {"type", ModelManager::modelTypeToString(modelInfo.type)},
  4043. {"is_loaded", modelInfo.isLoaded}
  4044. }},
  4045. {"request_id", requestId}
  4046. };
  4047. sendJsonResponse(res, response);
  4048. } else {
  4049. sendErrorResponse(res, "Failed to load model", 400, "MODEL_LOAD_FAILED", requestId);
  4050. }
  4051. } catch (const std::exception& e) {
  4052. sendErrorResponse(res, std::string("Model load failed: ") + e.what(), 500, "MODEL_LOAD_ERROR", requestId);
  4053. }
  4054. }
  4055. void Server::handleUnloadModelById(const httplib::Request& req, httplib::Response& res) {
  4056. std::string requestId = generateRequestId();
  4057. try {
  4058. if (!m_modelManager) {
  4059. sendErrorResponse(res, "Model manager not available", 500, "MODEL_MANAGER_UNAVAILABLE", requestId);
  4060. return;
  4061. }
  4062. // Extract model ID from URL path
  4063. std::string modelId = req.matches[1].str();
  4064. if (modelId.empty()) {
  4065. sendErrorResponse(res, "Missing model ID", 400, "MISSING_MODEL_ID", requestId);
  4066. return;
  4067. }
  4068. // Unload model
  4069. bool success = m_modelManager->unloadModel(modelId);
  4070. if (success) {
  4071. // Clear currently loaded model if it matches
  4072. {
  4073. std::lock_guard<std::mutex> lock(m_currentModelMutex);
  4074. if (m_currentlyLoadedModel == modelId) {
  4075. m_currentlyLoadedModel = "";
  4076. }
  4077. }
  4078. nlohmann::json response = {
  4079. {"status", "success"},
  4080. {"model", {
  4081. {"name", modelId},
  4082. {"is_loaded", false}
  4083. }},
  4084. {"request_id", requestId}
  4085. };
  4086. sendJsonResponse(res, response);
  4087. } else {
  4088. sendErrorResponse(res, "Failed to unload model or model not found", 404, "MODEL_UNLOAD_FAILED", requestId);
  4089. }
  4090. } catch (const std::exception& e) {
  4091. sendErrorResponse(res, std::string("Model unload failed: ") + e.what(), 500, "MODEL_UNLOAD_ERROR", requestId);
  4092. }
  4093. }
  4094. void Server::handleModelTypes(const httplib::Request& /*req*/, httplib::Response& res) {
  4095. std::string requestId = generateRequestId();
  4096. try {
  4097. nlohmann::json types = {
  4098. {"model_types", {
  4099. {
  4100. {"type", "checkpoint"},
  4101. {"description", "Main stable diffusion model files for text-to-image, image-to-image, and inpainting"},
  4102. {"extensions", {"safetensors", "ckpt", "gguf"}},
  4103. {"capabilities", {"text2img", "img2img", "inpainting", "controlnet", "lora", "vae"}},
  4104. {"recommended_for", "General purpose image generation"}
  4105. },
  4106. {
  4107. {"type", "lora"},
  4108. {"description", "LoRA adapter models for style transfer and character customization"},
  4109. {"extensions", {"safetensors", "ckpt"}},
  4110. {"capabilities", {"style_transfer", "character_customization"}},
  4111. {"requires", {"checkpoint"}},
  4112. {"recommended_for", "Style modification and character-specific generation"}
  4113. },
  4114. {
  4115. {"type", "controlnet"},
  4116. {"description", "ControlNet models for precise control over output composition"},
  4117. {"extensions", {"safetensors", "pth"}},
  4118. {"capabilities", {"precise_control", "composition_control"}},
  4119. {"requires", {"checkpoint"}},
  4120. {"recommended_for", "Precise control over image generation"}
  4121. },
  4122. {
  4123. {"type", "vae"},
  4124. {"description", "VAE models for improved encoding and decoding quality"},
  4125. {"extensions", {"safetensors", "pt", "ckpt", "gguf"}},
  4126. {"capabilities", {"encoding", "decoding", "quality_improvement"}},
  4127. {"requires", {"checkpoint"}},
  4128. {"recommended_for", "Improved image quality and encoding"}
  4129. },
  4130. {
  4131. {"type", "embedding"},
  4132. {"description", "Text embeddings for concept control and style words"},
  4133. {"extensions", {"safetensors", "pt"}},
  4134. {"capabilities", {"concept_control", "style_words"}},
  4135. {"requires", {"checkpoint"}},
  4136. {"recommended_for", "Concept control and specific styles"}
  4137. },
  4138. {
  4139. {"type", "taesd"},
  4140. {"description", "TAESD models for real-time decoding"},
  4141. {"extensions", {"safetensors", "pth", "gguf"}},
  4142. {"capabilities", {"real_time_decoding", "fast_preview"}},
  4143. {"requires", {"checkpoint"}},
  4144. {"recommended_for", "Real-time applications and fast previews"}
  4145. },
  4146. {
  4147. {"type", "esrgan"},
  4148. {"description", "ESRGAN models for image upscaling and enhancement"},
  4149. {"extensions", {"pth", "pt"}},
  4150. {"capabilities", {"upscaling", "enhancement", "quality_improvement"}},
  4151. {"recommended_for", "Image upscaling and quality enhancement"}
  4152. }
  4153. }},
  4154. {"request_id", requestId}
  4155. };
  4156. sendJsonResponse(res, types);
  4157. } catch (const std::exception& e) {
  4158. sendErrorResponse(res, std::string("Failed to get model types: ") + e.what(), 500, "MODEL_TYPES_ERROR", requestId);
  4159. }
  4160. }
  4161. void Server::handleModelDirectories(const httplib::Request& /*req*/, httplib::Response& res) {
  4162. std::string requestId = generateRequestId();
  4163. try {
  4164. if (!m_modelManager) {
  4165. sendErrorResponse(res, "Model manager not available", 500, "MODEL_MANAGER_UNAVAILABLE", requestId);
  4166. return;
  4167. }
  4168. std::string modelsDir = m_modelManager->getModelsDirectory();
  4169. nlohmann::json directories = nlohmann::json::array();
  4170. // Define expected model directories
  4171. std::vector<std::string> modelDirs = {
  4172. "stable-diffusion", "checkpoints", "lora", "controlnet",
  4173. "vae", "taesd", "esrgan", "embeddings"
  4174. };
  4175. for (const auto& dirName : modelDirs) {
  4176. std::string dirPath = modelsDir + "/" + dirName;
  4177. std::string type = getModelTypeFromDirectoryName(dirName);
  4178. std::string description = getDirectoryDescription(dirName);
  4179. nlohmann::json dirInfo = {
  4180. {"name", dirName},
  4181. {"path", dirPath},
  4182. {"type", type},
  4183. {"description", description},
  4184. {"exists", std::filesystem::exists(dirPath) && std::filesystem::is_directory(dirPath)},
  4185. {"contents", getDirectoryContents(dirPath)}
  4186. };
  4187. directories.push_back(dirInfo);
  4188. }
  4189. nlohmann::json response = {
  4190. {"models_directory", modelsDir},
  4191. {"directories", directories},
  4192. {"request_id", requestId}
  4193. };
  4194. sendJsonResponse(res, response);
  4195. } catch (const std::exception& e) {
  4196. sendErrorResponse(res, std::string("Failed to get model directories: ") + e.what(), 500, "MODEL_DIRECTORIES_ERROR", requestId);
  4197. }
  4198. }
  4199. void Server::handleRefreshModels(const httplib::Request& /*req*/, httplib::Response& res) {
  4200. std::string requestId = generateRequestId();
  4201. try {
  4202. if (!m_modelManager) {
  4203. sendErrorResponse(res, "Model manager not available", 500, "MODEL_MANAGER_UNAVAILABLE", requestId);
  4204. return;
  4205. }
  4206. // Force refresh of model cache
  4207. bool success = m_modelManager->scanModelsDirectory();
  4208. if (success) {
  4209. nlohmann::json response = {
  4210. {"status", "success"},
  4211. {"message", "Model cache refreshed successfully"},
  4212. {"models_found", m_modelManager->getAvailableModelsCount()},
  4213. {"models_loaded", m_modelManager->getLoadedModelsCount()},
  4214. {"models_directory", m_modelManager->getModelsDirectory()},
  4215. {"request_id", requestId}
  4216. };
  4217. sendJsonResponse(res, response);
  4218. } else {
  4219. sendErrorResponse(res, "Failed to refresh model cache", 500, "MODEL_REFRESH_FAILED", requestId);
  4220. }
  4221. } catch (const std::exception& e) {
  4222. sendErrorResponse(res, std::string("Model refresh failed: ") + e.what(), 500, "MODEL_REFRESH_ERROR", requestId);
  4223. }
  4224. }
  4225. void Server::handleHashModels(const httplib::Request& req, httplib::Response& res) {
  4226. std::string requestId = generateRequestId();
  4227. try {
  4228. if (!m_generationQueue || !m_modelManager) {
  4229. sendErrorResponse(res, "Services not available", 500, "SERVICE_UNAVAILABLE", requestId);
  4230. return;
  4231. }
  4232. // Parse request body
  4233. nlohmann::json requestJson;
  4234. if (!req.body.empty()) {
  4235. requestJson = nlohmann::json::parse(req.body);
  4236. }
  4237. HashRequest hashReq;
  4238. hashReq.id = requestId;
  4239. hashReq.forceRehash = requestJson.value("force_rehash", false);
  4240. if (requestJson.contains("models") && requestJson["models"].is_array()) {
  4241. for (const auto& model : requestJson["models"]) {
  4242. hashReq.modelNames.push_back(model.get<std::string>());
  4243. }
  4244. }
  4245. // Enqueue hash request
  4246. auto future = m_generationQueue->enqueueHashRequest(hashReq);
  4247. nlohmann::json response = {
  4248. {"request_id", requestId},
  4249. {"status", "queued"},
  4250. {"message", "Hash job queued successfully"},
  4251. {"models_to_hash", hashReq.modelNames.empty() ? "all_unhashed" : std::to_string(hashReq.modelNames.size())}
  4252. };
  4253. sendJsonResponse(res, response, 202);
  4254. } catch (const nlohmann::json::parse_error& e) {
  4255. sendErrorResponse(res, std::string("Invalid JSON: ") + e.what(), 400, "JSON_PARSE_ERROR", requestId);
  4256. } catch (const std::exception& e) {
  4257. sendErrorResponse(res, std::string("Hash request failed: ") + e.what(), 500, "INTERNAL_ERROR", requestId);
  4258. }
  4259. }
  4260. void Server::handleConvertModel(const httplib::Request& req, httplib::Response& res) {
  4261. std::string requestId = generateRequestId();
  4262. try {
  4263. if (!m_generationQueue || !m_modelManager) {
  4264. sendErrorResponse(res, "Services not available", 500, "SERVICE_UNAVAILABLE", requestId);
  4265. return;
  4266. }
  4267. // Parse request body
  4268. nlohmann::json requestJson;
  4269. try {
  4270. requestJson = nlohmann::json::parse(req.body);
  4271. } catch (const nlohmann::json::parse_error& e) {
  4272. sendErrorResponse(res, std::string("Invalid JSON: ") + e.what(), 400, "JSON_PARSE_ERROR", requestId);
  4273. return;
  4274. }
  4275. // Validate required fields
  4276. if (!requestJson.contains("model_name")) {
  4277. sendErrorResponse(res, "Missing required field: model_name", 400, "MISSING_FIELD", requestId);
  4278. return;
  4279. }
  4280. if (!requestJson.contains("quantization_type")) {
  4281. sendErrorResponse(res, "Missing required field: quantization_type", 400, "MISSING_FIELD", requestId);
  4282. return;
  4283. }
  4284. std::string modelName = requestJson["model_name"].get<std::string>();
  4285. std::string quantizationType = requestJson["quantization_type"].get<std::string>();
  4286. // Validate quantization type
  4287. const std::vector<std::string> validTypes = {"f32", "f16", "q4_0", "q4_1", "q5_0", "q5_1", "q8_0", "q2_K", "q3_K", "q4_K"};
  4288. if (std::find(validTypes.begin(), validTypes.end(), quantizationType) == validTypes.end()) {
  4289. sendErrorResponse(res, "Invalid quantization_type. Valid types: f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0, q2_K, q3_K, q4_K",
  4290. 400, "INVALID_QUANTIZATION_TYPE", requestId);
  4291. return;
  4292. }
  4293. // Get model info to find the full path
  4294. auto modelInfo = m_modelManager->getModelInfo(modelName);
  4295. if (modelInfo.name.empty()) {
  4296. sendErrorResponse(res, "Model not found: " + modelName, 404, "MODEL_NOT_FOUND", requestId);
  4297. return;
  4298. }
  4299. // Check if model is already GGUF
  4300. if (modelInfo.fullPath.find(".gguf") != std::string::npos) {
  4301. sendErrorResponse(res, "Model is already in GGUF format. Cannot convert GGUF to GGUF.",
  4302. 400, "ALREADY_GGUF", requestId);
  4303. return;
  4304. }
  4305. // Build output path
  4306. std::string outputPath = requestJson.value("output_path", "");
  4307. if (outputPath.empty()) {
  4308. // Generate default output path: model_name_quantization.gguf
  4309. namespace fs = std::filesystem;
  4310. fs::path inputPath(modelInfo.fullPath);
  4311. std::string baseName = inputPath.stem().string();
  4312. std::string outputDir = inputPath.parent_path().string();
  4313. outputPath = outputDir + "/" + baseName + "_" + quantizationType + ".gguf";
  4314. }
  4315. // Create conversion request
  4316. ConversionRequest convReq;
  4317. convReq.id = requestId;
  4318. convReq.modelName = modelName;
  4319. convReq.modelPath = modelInfo.fullPath;
  4320. convReq.outputPath = outputPath;
  4321. convReq.quantizationType = quantizationType;
  4322. // Enqueue conversion request
  4323. auto future = m_generationQueue->enqueueConversionRequest(convReq);
  4324. nlohmann::json response = {
  4325. {"request_id", requestId},
  4326. {"status", "queued"},
  4327. {"message", "Model conversion queued successfully"},
  4328. {"model_name", modelName},
  4329. {"input_path", modelInfo.fullPath},
  4330. {"output_path", outputPath},
  4331. {"quantization_type", quantizationType}
  4332. };
  4333. sendJsonResponse(res, response, 202);
  4334. } catch (const std::exception& e) {
  4335. sendErrorResponse(res, std::string("Conversion request failed: ") + e.what(), 500, "INTERNAL_ERROR", requestId);
  4336. }
  4337. }
  4338. void Server::handleModelStats(const httplib::Request& /*req*/, httplib::Response& res) {
  4339. std::string requestId = generateRequestId();
  4340. try {
  4341. if (!m_modelManager) {
  4342. sendErrorResponse(res, "Model manager not available", 500, "MODEL_MANAGER_UNAVAILABLE", requestId);
  4343. return;
  4344. }
  4345. auto allModels = m_modelManager->getAllModels();
  4346. nlohmann::json response = {
  4347. {"statistics", {
  4348. {"total_models", allModels.size()},
  4349. {"loaded_models", m_modelManager->getLoadedModelsCount()},
  4350. {"available_models", m_modelManager->getAvailableModelsCount()},
  4351. {"model_types", getModelTypeStatistics()},
  4352. {"largest_model", getLargestModel(allModels)},
  4353. {"smallest_model", getSmallestModel(allModels)}
  4354. }},
  4355. {"request_id", requestId}
  4356. };
  4357. sendJsonResponse(res, response);
  4358. } catch (const std::exception& e) {
  4359. sendErrorResponse(res, std::string("Failed to get model stats: ") + e.what(), 500, "MODEL_STATS_ERROR", requestId);
  4360. }
  4361. }
  4362. void Server::handleBatchModels(const httplib::Request& req, httplib::Response& res) {
  4363. std::string requestId = generateRequestId();
  4364. try {
  4365. if (!m_modelManager) {
  4366. sendErrorResponse(res, "Model manager not available", 500, "MODEL_MANAGER_UNAVAILABLE", requestId);
  4367. return;
  4368. }
  4369. // Parse JSON request body
  4370. nlohmann::json requestJson = nlohmann::json::parse(req.body);
  4371. if (!requestJson.contains("operation") || !requestJson["operation"].is_string()) {
  4372. sendErrorResponse(res, "Missing or invalid 'operation' field", 400, "INVALID_OPERATION", requestId);
  4373. return;
  4374. }
  4375. if (!requestJson.contains("models") || !requestJson["models"].is_array()) {
  4376. sendErrorResponse(res, "Missing or invalid 'models' field", 400, "INVALID_MODELS", requestId);
  4377. return;
  4378. }
  4379. std::string operation = requestJson["operation"];
  4380. nlohmann::json models = requestJson["models"];
  4381. nlohmann::json results = nlohmann::json::array();
  4382. for (const auto& model : models) {
  4383. if (!model.is_string()) {
  4384. results.push_back({
  4385. {"model", model},
  4386. {"success", false},
  4387. {"error", "Invalid model name"}
  4388. });
  4389. continue;
  4390. }
  4391. std::string modelName = model;
  4392. bool success = false;
  4393. std::string error = "";
  4394. if (operation == "load") {
  4395. success = m_modelManager->loadModel(modelName);
  4396. if (!success) error = "Failed to load model";
  4397. } else if (operation == "unload") {
  4398. success = m_modelManager->unloadModel(modelName);
  4399. if (!success) error = "Failed to unload model";
  4400. } else {
  4401. error = "Unsupported operation";
  4402. }
  4403. results.push_back({
  4404. {"model", modelName},
  4405. {"success", success},
  4406. {"error", error.empty() ? nlohmann::json(nullptr) : nlohmann::json(error)}
  4407. });
  4408. }
  4409. nlohmann::json response = {
  4410. {"operation", operation},
  4411. {"results", results},
  4412. {"successful_count", std::count_if(results.begin(), results.end(),
  4413. [](const nlohmann::json& result) { return result["success"].get<bool>(); })},
  4414. {"failed_count", std::count_if(results.begin(), results.end(),
  4415. [](const nlohmann::json& result) { return !result["success"].get<bool>(); })},
  4416. {"request_id", requestId}
  4417. };
  4418. sendJsonResponse(res, response);
  4419. } catch (const nlohmann::json::parse_error& e) {
  4420. sendErrorResponse(res, std::string("Invalid JSON: ") + e.what(), 400, "JSON_PARSE_ERROR", requestId);
  4421. } catch (const std::exception& e) {
  4422. sendErrorResponse(res, std::string("Batch operation failed: ") + e.what(), 500, "BATCH_OPERATION_ERROR", requestId);
  4423. }
  4424. }
  4425. void Server::handleValidateModel(const httplib::Request& req, httplib::Response& res) {
  4426. std::string requestId = generateRequestId();
  4427. try {
  4428. // Parse JSON request body
  4429. nlohmann::json requestJson = nlohmann::json::parse(req.body);
  4430. if (!requestJson.contains("model_path") || !requestJson["model_path"].is_string()) {
  4431. sendErrorResponse(res, "Missing or invalid 'model_path' field", 400, "INVALID_MODEL_PATH", requestId);
  4432. return;
  4433. }
  4434. std::string modelPath = requestJson["model_path"];
  4435. std::string modelType = requestJson.value("model_type", "checkpoint");
  4436. // Validate model file
  4437. nlohmann::json validation = validateModelFile(modelPath, modelType);
  4438. nlohmann::json response = {
  4439. {"validation", validation},
  4440. {"request_id", requestId}
  4441. };
  4442. sendJsonResponse(res, response);
  4443. } catch (const nlohmann::json::parse_error& e) {
  4444. sendErrorResponse(res, std::string("Invalid JSON: ") + e.what(), 400, "JSON_PARSE_ERROR", requestId);
  4445. } catch (const std::exception& e) {
  4446. sendErrorResponse(res, std::string("Model validation failed: ") + e.what(), 500, "MODEL_VALIDATION_ERROR", requestId);
  4447. }
  4448. }
  4449. void Server::handleCheckCompatibility(const httplib::Request& req, httplib::Response& res) {
  4450. std::string requestId = generateRequestId();
  4451. try {
  4452. if (!m_modelManager) {
  4453. sendErrorResponse(res, "Model manager not available", 500, "MODEL_MANAGER_UNAVAILABLE", requestId);
  4454. return;
  4455. }
  4456. // Parse JSON request body
  4457. nlohmann::json requestJson = nlohmann::json::parse(req.body);
  4458. if (!requestJson.contains("model_name") || !requestJson["model_name"].is_string()) {
  4459. sendErrorResponse(res, "Missing or invalid 'model_name' field", 400, "INVALID_MODEL_NAME", requestId);
  4460. return;
  4461. }
  4462. std::string modelName = requestJson["model_name"];
  4463. std::string systemInfo = requestJson.value("system_info", "auto");
  4464. // Get model information
  4465. auto modelInfo = m_modelManager->getModelInfo(modelName);
  4466. if (modelInfo.name.empty()) {
  4467. sendErrorResponse(res, "Model not found", 404, "MODEL_NOT_FOUND", requestId);
  4468. return;
  4469. }
  4470. // Check compatibility
  4471. nlohmann::json compatibility = checkModelCompatibility(modelInfo, systemInfo);
  4472. nlohmann::json response = {
  4473. {"model", modelName},
  4474. {"compatibility", compatibility},
  4475. {"request_id", requestId}
  4476. };
  4477. sendJsonResponse(res, response);
  4478. } catch (const nlohmann::json::parse_error& e) {
  4479. sendErrorResponse(res, std::string("Invalid JSON: ") + e.what(), 400, "JSON_PARSE_ERROR", requestId);
  4480. } catch (const std::exception& e) {
  4481. sendErrorResponse(res, std::string("Compatibility check failed: ") + e.what(), 500, "COMPATIBILITY_CHECK_ERROR", requestId);
  4482. }
  4483. }
  4484. void Server::handleModelRequirements(const httplib::Request& req, httplib::Response& res) {
  4485. std::string requestId = generateRequestId();
  4486. try {
  4487. // Parse JSON request body
  4488. nlohmann::json requestJson = nlohmann::json::parse(req.body);
  4489. std::string modelType = requestJson.value("model_type", "checkpoint");
  4490. std::string resolution = requestJson.value("resolution", "512x512");
  4491. std::string batchSize = requestJson.value("batch_size", "1");
  4492. // Calculate specific requirements
  4493. nlohmann::json requirements = calculateSpecificRequirements(modelType, resolution, batchSize);
  4494. // Get general requirements for model type
  4495. ModelType type = ModelManager::stringToModelType(modelType);
  4496. nlohmann::json generalRequirements = getModelRequirements(type);
  4497. nlohmann::json response = {
  4498. {"model_type", modelType},
  4499. {"configuration", {
  4500. {"resolution", resolution},
  4501. {"batch_size", batchSize}
  4502. }},
  4503. {"specific_requirements", requirements},
  4504. {"general_requirements", generalRequirements},
  4505. {"request_id", requestId}
  4506. };
  4507. sendJsonResponse(res, response);
  4508. } catch (const nlohmann::json::parse_error& e) {
  4509. sendErrorResponse(res, std::string("Invalid JSON: ") + e.what(), 400, "JSON_PARSE_ERROR", requestId);
  4510. } catch (const std::exception& e) {
  4511. sendErrorResponse(res, std::string("Requirements calculation failed: ") + e.what(), 500, "REQUIREMENTS_ERROR", requestId);
  4512. }
  4513. }
  4514. void Server::serverThreadFunction(const std::string& host, int port) {
  4515. try {
  4516. std::cout << "Server thread starting, attempting to bind to " << host << ":" << port << std::endl;
  4517. // Check if port is available before attempting to bind
  4518. std::cout << "Checking if port " << port << " is available..." << std::endl;
  4519. // Try to create a test socket to check if port is in use
  4520. int test_socket = socket(AF_INET, SOCK_STREAM, 0);
  4521. if (test_socket >= 0) {
  4522. // Set SO_REUSEADDR to avoid TIME_WAIT issues
  4523. int opt = 1;
  4524. if (setsockopt(test_socket, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt)) < 0) {
  4525. std::cerr << "Warning: Failed to set SO_REUSEADDR on test socket: " << strerror(errno) << std::endl;
  4526. }
  4527. // Also set SO_REUSEPORT if available (for better concurrent binding handling)
  4528. #ifdef SO_REUSEPORT
  4529. int reuseport = 1;
  4530. if (setsockopt(test_socket, SOL_SOCKET, SO_REUSEPORT, &reuseport, sizeof(reuseport)) < 0) {
  4531. std::cerr << "Warning: Failed to set SO_REUSEPORT on test socket: " << strerror(errno) << std::endl;
  4532. }
  4533. #endif
  4534. struct sockaddr_in addr;
  4535. addr.sin_family = AF_INET;
  4536. addr.sin_port = htons(port);
  4537. addr.sin_addr.s_addr = INADDR_ANY;
  4538. // Try to bind to the port
  4539. if (bind(test_socket, (struct sockaddr*)&addr, sizeof(addr)) < 0) {
  4540. close(test_socket);
  4541. std::cerr << "ERROR: Port " << port << " is already in use! Cannot start server." << std::endl;
  4542. std::cerr << "This could be due to:" << std::endl;
  4543. std::cerr << "1. Another instance is already running on this port" << std::endl;
  4544. std::cerr << "2. A previous instance crashed and the socket is in TIME_WAIT state" << std::endl;
  4545. std::cerr << "3. The port is being used by another application" << std::endl;
  4546. std::cerr << std::endl;
  4547. std::cerr << "Solutions:" << std::endl;
  4548. std::cerr << "- Wait 30-60 seconds for TIME_WAIT to expire (if server crashed)" << std::endl;
  4549. std::cerr << "- Kill any existing processes: sudo lsof -ti:" << port << " | xargs kill -9" << std::endl;
  4550. std::cerr << "- Use a different port with -p <port>" << std::endl;
  4551. m_isRunning.store(false);
  4552. m_startupFailed.store(true);
  4553. return;
  4554. }
  4555. close(test_socket);
  4556. }
  4557. std::cout << "Port " << port << " is available, proceeding with server startup..." << std::endl;
  4558. std::cout << "Calling listen()..." << std::endl;
  4559. // We need to set m_isRunning after successful bind but before blocking
  4560. // cpp-httplib doesn't provide a callback, so we set it optimistically
  4561. // and clear it if listen() returns false
  4562. m_isRunning.store(true);
  4563. bool listenResult = m_httpServer->listen(host.c_str(), port);
  4564. std::cout << "listen() returned: " << (listenResult ? "true" : "false") << std::endl;
  4565. // If we reach here, server has stopped (either normally or due to error)
  4566. m_isRunning.store(false);
  4567. if (!listenResult) {
  4568. std::cerr << "Server listen failed! This usually means port is in use or permission denied." << std::endl;
  4569. }
  4570. } catch (const std::exception& e) {
  4571. std::cerr << "Exception in server thread: " << e.what() << std::endl;
  4572. m_isRunning.store(false);
  4573. }
  4574. }