server.cpp 146 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761277127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502150315041505150615071508150915101511151215131514151515161517151815191520152115221523152415251526152715281529153015311532153315341535153615371538153915401541154215431544154515461547154815491550155115521553155415551556155715581559156015611562156315641565156615671568156915701571157215731574157515761577157815791580158115821583158415851586158715881589159015911592159315941595159615971598159916001601160216031604160516061607160816091610161116121613161416151616161716181619162016211622162316241625162616271628162916301631163216331634163516361637163816391640164116421643164416451646164716481649165016511652165316541655165616571658165916601661166216631664166516661667166816691670167116721673167416751676167716781679168016811682168316841685168616871688168916901691169216931694169516961697169816991700170117021703170417051706170717081709171017111712171317141715171617171718171917201721172217231724172517261727172817291730173117321733173417351736173717381739174017411742174317441745174617471748174917501751175217531754175517561757175817591760176117621763176417651766176717681769177017711772177317741775177617771778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227722782279228022812282228322842285228622872288228922902291229222932294229522962297229822992300230123022303230423052306230723082309231023112312231323142315231623172318231923202321232223232324232523262327232823292330233123322333233423352336233723382339234023412342234323442345234623472348234923502351235223532354235523562357235823592360236123622363236423652366236723682369237023712372237323742375237623772378237923802381238223832384238523862387238823892390239123922393239423952396239723982399240024012402240324042405240624072408240924102411241224132414241524162417241824192420242124222423242424252426242724282429243024312432243324342435243624372438243924402441244224432444244524462447244824492450245124522453245424552456245724582459246024612462246324642465246624672468246924702471247224732474247524762477247824792480248124822483248424852486248724882489249024912492249324942495249624972498249925002501250225032504250525062507250825092510251125122513251425152516251725182519252025212522252325242525252625272528252925302531253225332534253525362537253825392540254125422543254425452546254725482549255025512552255325542555255625572558255925602561256225632564256525662567256825692570257125722573257425752576257725782579258025812582258325842585258625872588258925902591259225932594259525962597259825992600260126022603260426052606260726082609261026112612261326142615261626172618261926202621262226232624262526262627262826292630263126322633263426352636263726382639264026412642264326442645264626472648264926502651265226532654265526562657265826592660266126622663266426652666266726682669267026712672267326742675267626772678267926802681268226832684268526862687268826892690269126922693269426952696269726982699270027012702270327042705270627072708270927102711271227132714271527162717271827192720272127222723272427252726272727282729273027312732273327342735273627372738273927402741274227432744274527462747274827492750275127522753275427552756275727582759276027612762276327642765276627672768276927702771277227732774277527762777277827792780278127822783278427852786278727882789279027912792279327942795279627972798279928002801280228032804280528062807280828092810281128122813281428152816281728182819282028212822282328242825282628272828282928302831283228332834283528362837283828392840284128422843284428452846284728482849285028512852285328542855285628572858285928602861286228632864286528662867286828692870287128722873287428752876287728782879288028812882288328842885288628872888288928902891289228932894289528962897289828992900290129022903290429052906290729082909291029112912291329142915291629172918291929202921292229232924292529262927292829292930293129322933293429352936293729382939294029412942294329442945294629472948294929502951295229532954295529562957295829592960296129622963296429652966296729682969297029712972297329742975297629772978297929802981298229832984298529862987298829892990299129922993299429952996299729982999300030013002300330043005300630073008300930103011301230133014301530163017301830193020302130223023302430253026302730283029303030313032303330343035303630373038303930403041304230433044304530463047304830493050305130523053305430553056305730583059306030613062306330643065306630673068306930703071307230733074307530763077307830793080308130823083308430853086308730883089309030913092309330943095309630973098309931003101310231033104310531063107310831093110311131123113311431153116311731183119312031213122312331243125312631273128312931303131313231333134313531363137313831393140314131423143314431453146314731483149315031513152315331543155315631573158315931603161316231633164316531663167316831693170317131723173317431753176317731783179318031813182318331843185318631873188318931903191319231933194319531963197319831993200320132023203320432053206320732083209321032113212321332143215321632173218321932203221322232233224322532263227322832293230323132323233323432353236323732383239324032413242324332443245324632473248324932503251325232533254325532563257325832593260326132623263326432653266326732683269327032713272327332743275327632773278327932803281328232833284328532863287328832893290329132923293329432953296329732983299330033013302330333043305330633073308330933103311331233133314331533163317331833193320332133223323332433253326332733283329333033313332333333343335333633373338333933403341334233433344334533463347334833493350335133523353335433553356335733583359336033613362336333643365336633673368336933703371337233733374337533763377337833793380338133823383338433853386338733883389339033913392339333943395339633973398339934003401340234033404340534063407340834093410341134123413341434153416341734183419342034213422342334243425342634273428342934303431343234333434343534363437343834393440344134423443344434453446344734483449345034513452345334543455345634573458345934603461346234633464346534663467346834693470347134723473347434753476347734783479348034813482348334843485348634873488348934903491349234933494349534963497349834993500350135023503350435053506350735083509351035113512351335143515351635173518351935203521352235233524352535263527352835293530353135323533353435353536353735383539354035413542354335443545354635473548354935503551355235533554355535563557355835593560356135623563356435653566356735683569357035713572357335743575357635773578357935803581358235833584358535863587358835893590359135923593359435953596359735983599360036013602360336043605360636073608360936103611361236133614361536163617361836193620362136223623362436253626362736283629363036313632363336343635363636373638363936403641364236433644364536463647364836493650365136523653365436553656365736583659366036613662366336643665366636673668366936703671367236733674367536763677367836793680368136823683368436853686368736883689369036913692369336943695369636973698369937003701370237033704370537063707370837093710371137123713371437153716371737183719372037213722372337243725372637273728372937303731373237333734373537363737373837393740374137423743374437453746374737483749375037513752375337543755375637573758375937603761376237633764376537663767376837693770377137723773377437753776377737783779378037813782378337843785378637873788378937903791379237933794379537963797379837993800380138023803380438053806380738083809381038113812381338143815381638173818381938203821
  1. #include "server.h"
  2. #include "model_manager.h"
  3. #include "generation_queue.h"
  4. #include "utils.h"
  5. #include <httplib.h>
  6. #include <nlohmann/json.hpp>
  7. #include <iostream>
  8. #include <sstream>
  9. #include <fstream>
  10. #include <chrono>
  11. #include <random>
  12. #include <iomanip>
  13. #include <algorithm>
  14. #include <thread>
  15. #include <filesystem>
  16. // Include stb_image for loading images (implementation is in generation_queue.cpp)
  17. #include "../stable-diffusion.cpp-src/thirdparty/stb_image.h"
  18. #include <sys/socket.h>
  19. #include <netinet/in.h>
  20. #include <unistd.h>
  21. #include <arpa/inet.h>
  22. using json = nlohmann::json;
  23. Server::Server(ModelManager* modelManager, GenerationQueue* generationQueue, const std::string& outputDir, const std::string& uiDir)
  24. : m_modelManager(modelManager)
  25. , m_generationQueue(generationQueue)
  26. , m_isRunning(false)
  27. , m_startupFailed(false)
  28. , m_port(8080)
  29. , m_outputDir(outputDir)
  30. , m_uiDir(uiDir)
  31. {
  32. m_httpServer = std::make_unique<httplib::Server>();
  33. }
  34. Server::~Server() {
  35. stop();
  36. }
  37. bool Server::start(const std::string& host, int port) {
  38. if (m_isRunning.load()) {
  39. return false;
  40. }
  41. m_host = host;
  42. m_port = port;
  43. // Validate host and port
  44. if (host.empty() || (port < 1 || port > 65535)) {
  45. return false;
  46. }
  47. // Set up CORS headers
  48. setupCORS();
  49. // Register API endpoints
  50. registerEndpoints();
  51. // Reset startup flags
  52. m_startupFailed.store(false);
  53. // Start server in a separate thread
  54. m_serverThread = std::thread(&Server::serverThreadFunction, this, host, port);
  55. // Wait for server to actually start and bind to the port
  56. // Give more time for server to actually start and bind
  57. for (int i = 0; i < 100; i++) { // Wait up to 10 seconds
  58. std::this_thread::sleep_for(std::chrono::milliseconds(100));
  59. // Check if startup failed early
  60. if (m_startupFailed.load()) {
  61. if (m_serverThread.joinable()) {
  62. m_serverThread.join();
  63. }
  64. return false;
  65. }
  66. if (m_isRunning.load()) {
  67. // Give it a moment more to ensure server is fully started
  68. std::this_thread::sleep_for(std::chrono::milliseconds(500));
  69. if (m_isRunning.load()) {
  70. return true;
  71. }
  72. }
  73. }
  74. if (m_isRunning.load()) {
  75. return true;
  76. } else {
  77. if (m_serverThread.joinable()) {
  78. m_serverThread.join();
  79. }
  80. return false;
  81. }
  82. }
  83. void Server::stop() {
  84. // Use atomic check to ensure thread safety
  85. bool wasRunning = m_isRunning.exchange(false);
  86. if (!wasRunning) {
  87. return; // Already stopped
  88. }
  89. if (m_httpServer) {
  90. m_httpServer->stop();
  91. // Give the server a moment to stop the blocking listen call
  92. std::this_thread::sleep_for(std::chrono::milliseconds(100));
  93. // If server thread is still running, try to force unblock the listen call
  94. // by making a quick connection to the server port
  95. if (m_serverThread.joinable()) {
  96. try {
  97. // Create a quick connection to interrupt the blocking listen
  98. httplib::Client client("127.0.0.1", m_port);
  99. client.set_connection_timeout(0, 500000); // 0.5 seconds
  100. client.set_read_timeout(0, 500000); // 0.5 seconds
  101. client.set_write_timeout(0, 500000); // 0.5 seconds
  102. auto res = client.Get("/api/health");
  103. // We don't care about the response, just trying to unblock
  104. } catch (...) {
  105. // Ignore any connection errors - we're just trying to unblock
  106. }
  107. }
  108. }
  109. if (m_serverThread.joinable()) {
  110. m_serverThread.join();
  111. }
  112. }
  113. bool Server::isRunning() const {
  114. return m_isRunning.load();
  115. }
  116. void Server::waitForStop() {
  117. if (m_serverThread.joinable()) {
  118. m_serverThread.join();
  119. }
  120. }
  121. void Server::registerEndpoints() {
  122. // Health check endpoint
  123. m_httpServer->Get("/api/health", [this](const httplib::Request& req, httplib::Response& res) {
  124. handleHealthCheck(req, res);
  125. });
  126. // API status endpoint
  127. m_httpServer->Get("/api/status", [this](const httplib::Request& req, httplib::Response& res) {
  128. handleApiStatus(req, res);
  129. });
  130. // Specialized generation endpoints
  131. m_httpServer->Post("/api/generate/text2img", [this](const httplib::Request& req, httplib::Response& res) {
  132. handleText2Img(req, res);
  133. });
  134. m_httpServer->Post("/api/generate/img2img", [this](const httplib::Request& req, httplib::Response& res) {
  135. handleImg2Img(req, res);
  136. });
  137. m_httpServer->Post("/api/generate/controlnet", [this](const httplib::Request& req, httplib::Response& res) {
  138. handleControlNet(req, res);
  139. });
  140. m_httpServer->Post("/api/generate/upscale", [this](const httplib::Request& req, httplib::Response& res) {
  141. handleUpscale(req, res);
  142. });
  143. // Utility endpoints
  144. m_httpServer->Get("/api/samplers", [this](const httplib::Request& req, httplib::Response& res) {
  145. handleSamplers(req, res);
  146. });
  147. m_httpServer->Get("/api/schedulers", [this](const httplib::Request& req, httplib::Response& res) {
  148. handleSchedulers(req, res);
  149. });
  150. m_httpServer->Get("/api/parameters", [this](const httplib::Request& req, httplib::Response& res) {
  151. handleParameters(req, res);
  152. });
  153. m_httpServer->Post("/api/validate", [this](const httplib::Request& req, httplib::Response& res) {
  154. handleValidate(req, res);
  155. });
  156. m_httpServer->Post("/api/estimate", [this](const httplib::Request& req, httplib::Response& res) {
  157. handleEstimate(req, res);
  158. });
  159. m_httpServer->Get("/api/config", [this](const httplib::Request& req, httplib::Response& res) {
  160. handleConfig(req, res);
  161. });
  162. m_httpServer->Get("/api/system", [this](const httplib::Request& req, httplib::Response& res) {
  163. handleSystem(req, res);
  164. });
  165. m_httpServer->Post("/api/system/restart", [this](const httplib::Request& req, httplib::Response& res) {
  166. handleSystemRestart(req, res);
  167. });
  168. // Models list endpoint
  169. m_httpServer->Get("/api/models", [this](const httplib::Request& req, httplib::Response& res) {
  170. handleModelsList(req, res);
  171. });
  172. // Model-specific endpoints
  173. m_httpServer->Get("/api/models/(.*)", [this](const httplib::Request& req, httplib::Response& res) {
  174. handleModelInfo(req, res);
  175. });
  176. m_httpServer->Post("/api/models/(.*)/load", [this](const httplib::Request& req, httplib::Response& res) {
  177. handleLoadModelById(req, res);
  178. });
  179. m_httpServer->Post("/api/models/(.*)/unload", [this](const httplib::Request& req, httplib::Response& res) {
  180. handleUnloadModelById(req, res);
  181. });
  182. // Model management endpoints
  183. m_httpServer->Get("/api/models/types", [this](const httplib::Request& req, httplib::Response& res) {
  184. handleModelTypes(req, res);
  185. });
  186. m_httpServer->Get("/api/models/directories", [this](const httplib::Request& req, httplib::Response& res) {
  187. handleModelDirectories(req, res);
  188. });
  189. m_httpServer->Post("/api/models/refresh", [this](const httplib::Request& req, httplib::Response& res) {
  190. handleRefreshModels(req, res);
  191. });
  192. m_httpServer->Post("/api/models/hash", [this](const httplib::Request& req, httplib::Response& res) {
  193. handleHashModels(req, res);
  194. });
  195. m_httpServer->Post("/api/models/convert", [this](const httplib::Request& req, httplib::Response& res) {
  196. handleConvertModel(req, res);
  197. });
  198. m_httpServer->Get("/api/models/stats", [this](const httplib::Request& req, httplib::Response& res) {
  199. handleModelStats(req, res);
  200. });
  201. m_httpServer->Post("/api/models/batch", [this](const httplib::Request& req, httplib::Response& res) {
  202. handleBatchModels(req, res);
  203. });
  204. // Model validation endpoints
  205. m_httpServer->Post("/api/models/validate", [this](const httplib::Request& req, httplib::Response& res) {
  206. handleValidateModel(req, res);
  207. });
  208. m_httpServer->Post("/api/models/compatible", [this](const httplib::Request& req, httplib::Response& res) {
  209. handleCheckCompatibility(req, res);
  210. });
  211. m_httpServer->Post("/api/models/requirements", [this](const httplib::Request& req, httplib::Response& res) {
  212. handleModelRequirements(req, res);
  213. });
  214. // Queue status endpoint
  215. m_httpServer->Get("/api/queue/status", [this](const httplib::Request& req, httplib::Response& res) {
  216. handleQueueStatus(req, res);
  217. });
  218. // Download job output file endpoint (must be before job status endpoint to match more specific pattern first)
  219. m_httpServer->Get("/api/queue/job/(.*)/output/(.*)", [this](const httplib::Request& req, httplib::Response& res) {
  220. handleDownloadOutput(req, res);
  221. });
  222. // Job status endpoint
  223. m_httpServer->Get("/api/queue/job/(.*)", [this](const httplib::Request& req, httplib::Response& res) {
  224. handleJobStatus(req, res);
  225. });
  226. // Cancel job endpoint
  227. m_httpServer->Post("/api/queue/cancel", [this](const httplib::Request& req, httplib::Response& res) {
  228. handleCancelJob(req, res);
  229. });
  230. // Clear queue endpoint
  231. m_httpServer->Post("/api/queue/clear", [this](const httplib::Request& req, httplib::Response& res) {
  232. handleClearQueue(req, res);
  233. });
  234. // Serve static web UI files if uiDir is configured
  235. if (!m_uiDir.empty() && std::filesystem::exists(m_uiDir)) {
  236. std::cout << "Serving static UI files from: " << m_uiDir << " at /ui" << std::endl;
  237. // Read UI version from version.json if available
  238. std::string uiVersion = "unknown";
  239. std::string versionFilePath = m_uiDir + "/version.json";
  240. if (std::filesystem::exists(versionFilePath)) {
  241. try {
  242. std::ifstream versionFile(versionFilePath);
  243. if (versionFile.is_open()) {
  244. nlohmann::json versionData = nlohmann::json::parse(versionFile);
  245. if (versionData.contains("version")) {
  246. uiVersion = versionData["version"].get<std::string>();
  247. }
  248. versionFile.close();
  249. }
  250. } catch (const std::exception& e) {
  251. std::cerr << "Failed to read UI version: " << e.what() << std::endl;
  252. }
  253. }
  254. std::cout << "UI version: " << uiVersion << std::endl;
  255. // Serve dynamic config.js that provides runtime configuration to the web UI
  256. m_httpServer->Get("/ui/config.js", [this, uiVersion](const httplib::Request& req, httplib::Response& res) {
  257. // Generate JavaScript configuration with current server settings
  258. std::ostringstream configJs;
  259. configJs << "// Auto-generated configuration\n"
  260. << "window.__SERVER_CONFIG__ = {\n"
  261. << " apiUrl: 'http://" << m_host << ":" << m_port << "',\n"
  262. << " apiBasePath: '/api',\n"
  263. << " host: '" << m_host << "',\n"
  264. << " port: " << m_port << ",\n"
  265. << " uiVersion: '" << uiVersion << "'\n"
  266. << "};\n";
  267. // No cache for config.js - always fetch fresh
  268. res.set_header("Cache-Control", "no-cache, no-store, must-revalidate");
  269. res.set_header("Pragma", "no-cache");
  270. res.set_header("Expires", "0");
  271. res.set_content(configJs.str(), "application/javascript");
  272. });
  273. // Set up file request handler for caching static assets
  274. m_httpServer->set_file_request_handler([uiVersion](const httplib::Request& req, httplib::Response& res) {
  275. // Add cache headers based on file type and version
  276. std::string path = req.path;
  277. // For versioned static assets (.js, .css, images), use long cache
  278. if (path.find("/_next/") != std::string::npos ||
  279. path.find(".js") != std::string::npos ||
  280. path.find(".css") != std::string::npos ||
  281. path.find(".png") != std::string::npos ||
  282. path.find(".jpg") != std::string::npos ||
  283. path.find(".svg") != std::string::npos ||
  284. path.find(".ico") != std::string::npos ||
  285. path.find(".woff") != std::string::npos ||
  286. path.find(".woff2") != std::string::npos ||
  287. path.find(".ttf") != std::string::npos) {
  288. // Long cache (1 year) for static assets
  289. res.set_header("Cache-Control", "public, max-age=31536000, immutable");
  290. // Add ETag based on UI version for cache validation
  291. res.set_header("ETag", "\"" + uiVersion + "\"");
  292. // Check If-None-Match for conditional requests
  293. if (req.has_header("If-None-Match")) {
  294. std::string clientETag = req.get_header_value("If-None-Match");
  295. if (clientETag == "\"" + uiVersion + "\"") {
  296. res.status = 304; // Not Modified
  297. return;
  298. }
  299. }
  300. } else if (path.find(".html") != std::string::npos || path == "/ui/" || path == "/ui") {
  301. // HTML files should revalidate but can be cached briefly
  302. res.set_header("Cache-Control", "public, max-age=0, must-revalidate");
  303. res.set_header("ETag", "\"" + uiVersion + "\"");
  304. }
  305. });
  306. // Mount the static file directory at /ui
  307. if (!m_httpServer->set_mount_point("/ui", m_uiDir)) {
  308. std::cerr << "Failed to mount UI directory: " << m_uiDir << std::endl;
  309. }
  310. // Redirect /ui to /ui/ to ensure proper routing
  311. m_httpServer->Get("/ui", [](const httplib::Request& req, httplib::Response& res) {
  312. res.set_redirect("/ui/");
  313. });
  314. }
  315. }
  316. void Server::setupCORS() {
  317. // Use post-routing handler to set CORS headers after the response is generated
  318. // This ensures we don't duplicate headers that may be set by other handlers
  319. m_httpServer->set_post_routing_handler([](const httplib::Request& req, httplib::Response& res) {
  320. // Only add CORS headers if they haven't been set already
  321. if (!res.has_header("Access-Control-Allow-Origin")) {
  322. res.set_header("Access-Control-Allow-Origin", "*");
  323. }
  324. if (!res.has_header("Access-Control-Allow-Methods")) {
  325. res.set_header("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS");
  326. }
  327. if (!res.has_header("Access-Control-Allow-Headers")) {
  328. res.set_header("Access-Control-Allow-Headers", "Content-Type, Authorization");
  329. }
  330. });
  331. // Handle OPTIONS requests for CORS preflight (API endpoints only)
  332. m_httpServer->Options("/api/.*", [](const httplib::Request&, httplib::Response& res) {
  333. res.set_header("Access-Control-Allow-Origin", "*");
  334. res.set_header("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS");
  335. res.set_header("Access-Control-Allow-Headers", "Content-Type, Authorization");
  336. res.status = 200;
  337. });
  338. }
  339. void Server::handleHealthCheck(const httplib::Request& req, httplib::Response& res) {
  340. try {
  341. json response = {
  342. {"status", "healthy"},
  343. {"timestamp", std::chrono::duration_cast<std::chrono::seconds>(
  344. std::chrono::system_clock::now().time_since_epoch()).count()},
  345. {"version", "1.0.0"}
  346. };
  347. sendJsonResponse(res, response);
  348. } catch (const std::exception& e) {
  349. sendErrorResponse(res, std::string("Health check failed: ") + e.what(), 500);
  350. }
  351. }
  352. void Server::handleApiStatus(const httplib::Request& req, httplib::Response& res) {
  353. try {
  354. json response = {
  355. {"server", {
  356. {"running", m_isRunning.load()},
  357. {"host", m_host},
  358. {"port", m_port}
  359. }},
  360. {"generation_queue", {
  361. {"running", m_generationQueue ? m_generationQueue->isRunning() : false},
  362. {"queue_size", m_generationQueue ? m_generationQueue->getQueueSize() : 0},
  363. {"active_generations", m_generationQueue ? m_generationQueue->getActiveGenerations() : 0}
  364. }},
  365. {"models", {
  366. {"loaded_count", m_modelManager ? m_modelManager->getLoadedModelsCount() : 0},
  367. {"available_count", m_modelManager ? m_modelManager->getAvailableModelsCount() : 0}
  368. }}
  369. };
  370. sendJsonResponse(res, response);
  371. } catch (const std::exception& e) {
  372. sendErrorResponse(res, std::string("Status check failed: ") + e.what(), 500);
  373. }
  374. }
  375. void Server::handleModelsList(const httplib::Request& req, httplib::Response& res) {
  376. std::string requestId = generateRequestId();
  377. try {
  378. if (!m_modelManager) {
  379. sendErrorResponse(res, "Model manager not available", 500, "MODEL_MANAGER_UNAVAILABLE", requestId);
  380. return;
  381. }
  382. // Parse query parameters for enhanced filtering
  383. std::string typeFilter = req.get_param_value("type");
  384. std::string searchQuery = req.get_param_value("search");
  385. std::string sortBy = req.get_param_value("sort_by");
  386. std::string sortOrder = req.get_param_value("sort_order");
  387. std::string dateFilter = req.get_param_value("date");
  388. std::string sizeFilter = req.get_param_value("size");
  389. // Pagination parameters
  390. int page = 1;
  391. int limit = 50;
  392. try {
  393. if (!req.get_param_value("page").empty()) {
  394. page = std::stoi(req.get_param_value("page"));
  395. if (page < 1) page = 1;
  396. }
  397. if (!req.get_param_value("limit").empty()) {
  398. limit = std::stoi(req.get_param_value("limit"));
  399. if (limit < 1) limit = 1;
  400. if (limit > 200) limit = 200; // Max limit to prevent performance issues
  401. }
  402. } catch (const std::exception& e) {
  403. sendErrorResponse(res, "Invalid pagination parameters", 400, "INVALID_PAGINATION", requestId);
  404. return;
  405. }
  406. // Filter parameters
  407. bool includeLoaded = req.get_param_value("loaded") == "true";
  408. bool includeUnloaded = req.get_param_value("unloaded") == "true";
  409. bool includeMetadata = req.get_param_value("include_metadata") == "true";
  410. bool includeThumbnails = req.get_param_value("include_thumbnails") == "true";
  411. // Get all models
  412. auto allModels = m_modelManager->getAllModels();
  413. json models = json::array();
  414. // Apply filters and build response
  415. for (const auto& pair : allModels) {
  416. const auto& modelInfo = pair.second;
  417. // Apply type filter
  418. if (!typeFilter.empty()) {
  419. ModelType filterType = ModelManager::stringToModelType(typeFilter);
  420. if (modelInfo.type != filterType) continue;
  421. }
  422. // Apply loaded/unloaded filters
  423. if (includeLoaded && !modelInfo.isLoaded) continue;
  424. if (includeUnloaded && modelInfo.isLoaded) continue;
  425. // Apply search filter (case-insensitive search in name and description)
  426. if (!searchQuery.empty()) {
  427. std::string searchLower = searchQuery;
  428. std::transform(searchLower.begin(), searchLower.end(), searchLower.begin(), ::tolower);
  429. std::string nameLower = modelInfo.name;
  430. std::transform(nameLower.begin(), nameLower.end(), nameLower.begin(), ::tolower);
  431. std::string descLower = modelInfo.description;
  432. std::transform(descLower.begin(), descLower.end(), descLower.begin(), ::tolower);
  433. if (nameLower.find(searchLower) == std::string::npos &&
  434. descLower.find(searchLower) == std::string::npos) {
  435. continue;
  436. }
  437. }
  438. // Apply date filter (simplified - expects "recent", "old", or YYYY-MM-DD)
  439. if (!dateFilter.empty()) {
  440. auto now = std::filesystem::file_time_type::clock::now();
  441. auto modelTime = modelInfo.modifiedAt;
  442. auto duration = std::chrono::duration_cast<std::chrono::hours>(now - modelTime).count();
  443. if (dateFilter == "recent" && duration > 24 * 7) continue; // Older than 1 week
  444. if (dateFilter == "old" && duration < 24 * 30) continue; // Newer than 1 month
  445. }
  446. // Apply size filter (expects "small", "medium", "large", or size in MB)
  447. if (!sizeFilter.empty()) {
  448. double sizeMB = modelInfo.fileSize / (1024.0 * 1024.0);
  449. if (sizeFilter == "small" && sizeMB > 1024) continue; // > 1GB
  450. if (sizeFilter == "medium" && (sizeMB < 1024 || sizeMB > 4096)) continue; // < 1GB or > 4GB
  451. if (sizeFilter == "large" && sizeMB < 4096) continue; // < 4GB
  452. // Try to parse as specific size in MB
  453. try {
  454. double maxSizeMB = std::stod(sizeFilter);
  455. if (sizeMB > maxSizeMB) continue;
  456. } catch (...) {
  457. // Ignore if parsing fails
  458. }
  459. }
  460. // Build model JSON with only essential information
  461. json modelJson = {
  462. {"name", modelInfo.name},
  463. {"type", ModelManager::modelTypeToString(modelInfo.type)},
  464. {"file_size", modelInfo.fileSize},
  465. {"file_size_mb", modelInfo.fileSize / (1024.0 * 1024.0)},
  466. {"sha256", modelInfo.sha256.empty() ? nullptr : json(modelInfo.sha256)},
  467. {"sha256_short", (modelInfo.sha256.empty() || modelInfo.sha256.length() < 10) ? nullptr : json(modelInfo.sha256.substr(0, 10))}
  468. };
  469. // Add architecture information if available (checkpoints only)
  470. if (!modelInfo.architecture.empty()) {
  471. modelJson["architecture"] = modelInfo.architecture;
  472. modelJson["recommended_vae"] = modelInfo.recommendedVAE.empty() ? nullptr : json(modelInfo.recommendedVAE);
  473. if (modelInfo.recommendedWidth > 0) {
  474. modelJson["recommended_width"] = modelInfo.recommendedWidth;
  475. }
  476. if (modelInfo.recommendedHeight > 0) {
  477. modelJson["recommended_height"] = modelInfo.recommendedHeight;
  478. }
  479. if (modelInfo.recommendedSteps > 0) {
  480. modelJson["recommended_steps"] = modelInfo.recommendedSteps;
  481. }
  482. if (!modelInfo.recommendedSampler.empty()) {
  483. modelJson["recommended_sampler"] = modelInfo.recommendedSampler;
  484. }
  485. if (!modelInfo.requiredModels.empty()) {
  486. modelJson["required_models"] = modelInfo.requiredModels;
  487. }
  488. if (!modelInfo.missingModels.empty()) {
  489. modelJson["missing_models"] = modelInfo.missingModels;
  490. modelJson["has_missing_dependencies"] = true;
  491. } else {
  492. modelJson["has_missing_dependencies"] = false;
  493. }
  494. }
  495. models.push_back(modelJson);
  496. }
  497. // Apply sorting
  498. if (!sortBy.empty()) {
  499. std::sort(models.begin(), models.end(), [&sortBy, &sortOrder](const json& a, const json& b) {
  500. bool ascending = sortOrder != "desc";
  501. if (sortBy == "name") {
  502. return ascending ? a["name"] < b["name"] : a["name"] > b["name"];
  503. } else if (sortBy == "size") {
  504. return ascending ? a["file_size"] < b["file_size"] : a["file_size"] > b["file_size"];
  505. } else if (sortBy == "date") {
  506. return ascending ? a["last_modified"] < b["last_modified"] : a["last_modified"] > b["last_modified"];
  507. } else if (sortBy == "type") {
  508. return ascending ? a["type"] < b["type"] : a["type"] > b["type"];
  509. } else if (sortBy == "loaded") {
  510. return ascending ? a["is_loaded"] < b["is_loaded"] : a["is_loaded"] > b["is_loaded"];
  511. }
  512. return false;
  513. });
  514. }
  515. // Apply pagination
  516. int totalCount = models.size();
  517. int totalPages = (totalCount + limit - 1) / limit;
  518. int startIndex = (page - 1) * limit;
  519. int endIndex = std::min(startIndex + limit, totalCount);
  520. json paginatedModels = json::array();
  521. for (int i = startIndex; i < endIndex; ++i) {
  522. paginatedModels.push_back(models[i]);
  523. }
  524. // Build comprehensive response
  525. json response = {
  526. {"models", paginatedModels},
  527. {"pagination", {
  528. {"page", page},
  529. {"limit", limit},
  530. {"total_count", totalCount},
  531. {"total_pages", totalPages},
  532. {"has_next", page < totalPages},
  533. {"has_prev", page > 1}
  534. }},
  535. {"filters_applied", {
  536. {"type", typeFilter.empty() ? json(nullptr) : json(typeFilter)},
  537. {"search", searchQuery.empty() ? json(nullptr) : json(searchQuery)},
  538. {"date", dateFilter.empty() ? json(nullptr) : json(dateFilter)},
  539. {"size", sizeFilter.empty() ? json(nullptr) : json(sizeFilter)},
  540. {"loaded", includeLoaded ? json(true) : json(nullptr)},
  541. {"unloaded", includeUnloaded ? json(true) : json(nullptr)}
  542. }},
  543. {"sorting", {
  544. {"sort_by", sortBy.empty() ? "name" : json(sortBy)},
  545. {"sort_order", sortOrder.empty() ? "asc" : json(sortOrder)}
  546. }},
  547. {"statistics", {
  548. {"loaded_count", m_modelManager->getLoadedModelsCount()},
  549. {"available_count", m_modelManager->getAvailableModelsCount()}
  550. }},
  551. {"request_id", requestId}
  552. };
  553. sendJsonResponse(res, response);
  554. } catch (const std::exception& e) {
  555. sendErrorResponse(res, std::string("Failed to list models: ") + e.what(), 500, "MODEL_LIST_ERROR", requestId);
  556. }
  557. }
  558. void Server::handleQueueStatus(const httplib::Request& req, httplib::Response& res) {
  559. try {
  560. if (!m_generationQueue) {
  561. sendErrorResponse(res, "Generation queue not available", 500);
  562. return;
  563. }
  564. // Get detailed queue status
  565. auto jobs = m_generationQueue->getQueueStatus();
  566. // Convert jobs to JSON
  567. json jobsJson = json::array();
  568. for (const auto& job : jobs) {
  569. std::string statusStr;
  570. switch (job.status) {
  571. case GenerationStatus::QUEUED: statusStr = "queued"; break;
  572. case GenerationStatus::PROCESSING: statusStr = "processing"; break;
  573. case GenerationStatus::COMPLETED: statusStr = "completed"; break;
  574. case GenerationStatus::FAILED: statusStr = "failed"; break;
  575. }
  576. // Convert time points to timestamps
  577. auto queuedTime = std::chrono::duration_cast<std::chrono::milliseconds>(
  578. job.queuedTime.time_since_epoch()).count();
  579. auto startTime = std::chrono::duration_cast<std::chrono::milliseconds>(
  580. job.startTime.time_since_epoch()).count();
  581. auto endTime = std::chrono::duration_cast<std::chrono::milliseconds>(
  582. job.endTime.time_since_epoch()).count();
  583. jobsJson.push_back({
  584. {"id", job.id},
  585. {"status", statusStr},
  586. {"prompt", job.prompt},
  587. {"queued_time", queuedTime},
  588. {"start_time", startTime > 0 ? json(startTime) : json(nullptr)},
  589. {"end_time", endTime > 0 ? json(endTime) : json(nullptr)},
  590. {"position", job.position},
  591. {"progress", job.progress}
  592. });
  593. }
  594. json response = {
  595. {"queue", {
  596. {"size", m_generationQueue->getQueueSize()},
  597. {"active_generations", m_generationQueue->getActiveGenerations()},
  598. {"running", m_generationQueue->isRunning()},
  599. {"jobs", jobsJson}
  600. }}
  601. };
  602. sendJsonResponse(res, response);
  603. } catch (const std::exception& e) {
  604. sendErrorResponse(res, std::string("Queue status check failed: ") + e.what(), 500);
  605. }
  606. }
  607. void Server::handleJobStatus(const httplib::Request& req, httplib::Response& res) {
  608. try {
  609. if (!m_generationQueue) {
  610. sendErrorResponse(res, "Generation queue not available", 500);
  611. return;
  612. }
  613. // Extract job ID from URL path
  614. std::string jobId = req.matches[1].str();
  615. if (jobId.empty()) {
  616. sendErrorResponse(res, "Missing job ID", 400);
  617. return;
  618. }
  619. // Get job information
  620. auto jobInfo = m_generationQueue->getJobInfo(jobId);
  621. if (jobInfo.id.empty()) {
  622. sendErrorResponse(res, "Job not found", 404);
  623. return;
  624. }
  625. // Convert status to string
  626. std::string statusStr;
  627. switch (jobInfo.status) {
  628. case GenerationStatus::QUEUED: statusStr = "queued"; break;
  629. case GenerationStatus::PROCESSING: statusStr = "processing"; break;
  630. case GenerationStatus::COMPLETED: statusStr = "completed"; break;
  631. case GenerationStatus::FAILED: statusStr = "failed"; break;
  632. }
  633. // Convert time points to timestamps
  634. auto queuedTime = std::chrono::duration_cast<std::chrono::milliseconds>(
  635. jobInfo.queuedTime.time_since_epoch()).count();
  636. auto startTime = std::chrono::duration_cast<std::chrono::milliseconds>(
  637. jobInfo.startTime.time_since_epoch()).count();
  638. auto endTime = std::chrono::duration_cast<std::chrono::milliseconds>(
  639. jobInfo.endTime.time_since_epoch()).count();
  640. // Create download URLs for output files
  641. json outputUrls = json::array();
  642. for (const auto& filePath : jobInfo.outputFiles) {
  643. // Extract filename from full path
  644. std::filesystem::path p(filePath);
  645. std::string filename = p.filename().string();
  646. // Create download URL
  647. std::string url = "/api/queue/job/" + jobInfo.id + "/output/" + filename;
  648. json fileInfo = {
  649. {"filename", filename},
  650. {"url", url},
  651. {"path", filePath}
  652. };
  653. outputUrls.push_back(fileInfo);
  654. }
  655. json response = {
  656. {"job", {
  657. {"id", jobInfo.id},
  658. {"status", statusStr},
  659. {"prompt", jobInfo.prompt},
  660. {"queued_time", queuedTime},
  661. {"start_time", startTime > 0 ? json(startTime) : json(nullptr)},
  662. {"end_time", endTime > 0 ? json(endTime) : json(nullptr)},
  663. {"position", jobInfo.position},
  664. {"outputs", outputUrls},
  665. {"error_message", jobInfo.errorMessage},
  666. {"progress", jobInfo.progress}
  667. }}
  668. };
  669. sendJsonResponse(res, response);
  670. } catch (const std::exception& e) {
  671. sendErrorResponse(res, std::string("Job status check failed: ") + e.what(), 500);
  672. }
  673. }
  674. void Server::handleCancelJob(const httplib::Request& req, httplib::Response& res) {
  675. try {
  676. if (!m_generationQueue) {
  677. sendErrorResponse(res, "Generation queue not available", 500);
  678. return;
  679. }
  680. // Parse JSON request body
  681. json requestJson = json::parse(req.body);
  682. // Validate required fields
  683. if (!requestJson.contains("job_id") || !requestJson["job_id"].is_string()) {
  684. sendErrorResponse(res, "Missing or invalid 'job_id' field", 400);
  685. return;
  686. }
  687. std::string jobId = requestJson["job_id"];
  688. // Try to cancel the job
  689. bool cancelled = m_generationQueue->cancelJob(jobId);
  690. if (cancelled) {
  691. json response = {
  692. {"status", "success"},
  693. {"message", "Job cancelled successfully"},
  694. {"job_id", jobId}
  695. };
  696. sendJsonResponse(res, response);
  697. } else {
  698. json response = {
  699. {"status", "error"},
  700. {"message", "Job not found or already processing"},
  701. {"job_id", jobId}
  702. };
  703. sendJsonResponse(res, response, 404);
  704. }
  705. } catch (const json::parse_error& e) {
  706. sendErrorResponse(res, std::string("Invalid JSON: ") + e.what(), 400);
  707. } catch (const std::exception& e) {
  708. sendErrorResponse(res, std::string("Job cancellation failed: ") + e.what(), 500);
  709. }
  710. }
  711. void Server::handleClearQueue(const httplib::Request& req, httplib::Response& res) {
  712. try {
  713. if (!m_generationQueue) {
  714. sendErrorResponse(res, "Generation queue not available", 500);
  715. return;
  716. }
  717. // Clear the queue
  718. m_generationQueue->clearQueue();
  719. json response = {
  720. {"status", "success"},
  721. {"message", "Queue cleared successfully"}
  722. };
  723. sendJsonResponse(res, response);
  724. } catch (const std::exception& e) {
  725. sendErrorResponse(res, std::string("Queue clear failed: ") + e.what(), 500);
  726. }
  727. }
  728. void Server::handleDownloadOutput(const httplib::Request& req, httplib::Response& res) {
  729. try {
  730. // Extract job ID and filename from URL path
  731. if (req.matches.size() < 3) {
  732. sendErrorResponse(res, "Invalid request: job ID and filename required", 400);
  733. return;
  734. }
  735. std::string jobId = req.matches[1];
  736. std::string filename = req.matches[2];
  737. // Construct file path using the same logic as when saving:
  738. // {outputDir}/{jobId}/{filename}
  739. std::string fullPath = m_outputDir + "/" + jobId + "/" + filename;
  740. // Check if file exists
  741. if (!std::filesystem::exists(fullPath)) {
  742. sendErrorResponse(res, "Output file not found: " + fullPath, 404);
  743. return;
  744. }
  745. // Check if file exists on filesystem
  746. std::ifstream file(fullPath, std::ios::binary);
  747. if (!file.is_open()) {
  748. sendErrorResponse(res, "Output file not accessible", 404);
  749. return;
  750. }
  751. // Read file contents
  752. std::ostringstream fileContent;
  753. fileContent << file.rdbuf();
  754. file.close();
  755. // Determine content type based on file extension
  756. std::string contentType = "application/octet-stream";
  757. if (Utils::endsWith(filename, ".png")) {
  758. contentType = "image/png";
  759. } else if (Utils::endsWith(filename, ".jpg") || Utils::endsWith(filename, ".jpeg")) {
  760. contentType = "image/jpeg";
  761. } else if (Utils::endsWith(filename, ".mp4")) {
  762. contentType = "video/mp4";
  763. } else if (Utils::endsWith(filename, ".gif")) {
  764. contentType = "image/gif";
  765. }
  766. // Set response headers
  767. res.set_header("Content-Type", contentType);
  768. //res.set_header("Content-Disposition", "attachment; filename=\"" + filename + "\"");
  769. res.set_content(fileContent.str(), contentType);
  770. res.status = 200;
  771. } catch (const std::exception& e) {
  772. sendErrorResponse(res, std::string("Failed to download file: ") + e.what(), 500);
  773. }
  774. }
  775. void Server::sendJsonResponse(httplib::Response& res, const nlohmann::json& json, int status_code) {
  776. res.set_header("Content-Type", "application/json");
  777. res.status = status_code;
  778. res.body = json.dump();
  779. }
  780. void Server::sendErrorResponse(httplib::Response& res, const std::string& message, int status_code,
  781. const std::string& error_code, const std::string& request_id) {
  782. json errorResponse = {
  783. {"error", {
  784. {"message", message},
  785. {"status_code", status_code},
  786. {"error_code", error_code},
  787. {"request_id", request_id},
  788. {"timestamp", std::chrono::duration_cast<std::chrono::seconds>(
  789. std::chrono::system_clock::now().time_since_epoch()).count()}
  790. }}
  791. };
  792. sendJsonResponse(res, errorResponse, status_code);
  793. }
  794. std::pair<bool, std::string> Server::validateGenerationParameters(const nlohmann::json& params) {
  795. // Validate required fields
  796. if (!params.contains("prompt") || !params["prompt"].is_string()) {
  797. return {false, "Missing or invalid 'prompt' field"};
  798. }
  799. const std::string& prompt = params["prompt"];
  800. if (prompt.empty()) {
  801. return {false, "Prompt cannot be empty"};
  802. }
  803. if (prompt.length() > 10000) {
  804. return {false, "Prompt too long (max 10000 characters)"};
  805. }
  806. // Validate negative prompt if present
  807. if (params.contains("negative_prompt")) {
  808. if (!params["negative_prompt"].is_string()) {
  809. return {false, "Invalid 'negative_prompt' field, must be string"};
  810. }
  811. if (params["negative_prompt"].get<std::string>().length() > 10000) {
  812. return {false, "Negative prompt too long (max 10000 characters)"};
  813. }
  814. }
  815. // Validate width
  816. if (params.contains("width")) {
  817. if (!params["width"].is_number_integer()) {
  818. return {false, "Invalid 'width' field, must be integer"};
  819. }
  820. int width = params["width"];
  821. if (width < 64 || width > 2048 || width % 64 != 0) {
  822. return {false, "Width must be between 64 and 2048 and divisible by 64"};
  823. }
  824. }
  825. // Validate height
  826. if (params.contains("height")) {
  827. if (!params["height"].is_number_integer()) {
  828. return {false, "Invalid 'height' field, must be integer"};
  829. }
  830. int height = params["height"];
  831. if (height < 64 || height > 2048 || height % 64 != 0) {
  832. return {false, "Height must be between 64 and 2048 and divisible by 64"};
  833. }
  834. }
  835. // Validate batch count
  836. if (params.contains("batch_count")) {
  837. if (!params["batch_count"].is_number_integer()) {
  838. return {false, "Invalid 'batch_count' field, must be integer"};
  839. }
  840. int batchCount = params["batch_count"];
  841. if (batchCount < 1 || batchCount > 100) {
  842. return {false, "Batch count must be between 1 and 100"};
  843. }
  844. }
  845. // Validate steps
  846. if (params.contains("steps")) {
  847. if (!params["steps"].is_number_integer()) {
  848. return {false, "Invalid 'steps' field, must be integer"};
  849. }
  850. int steps = params["steps"];
  851. if (steps < 1 || steps > 150) {
  852. return {false, "Steps must be between 1 and 150"};
  853. }
  854. }
  855. // Validate CFG scale
  856. if (params.contains("cfg_scale")) {
  857. if (!params["cfg_scale"].is_number()) {
  858. return {false, "Invalid 'cfg_scale' field, must be number"};
  859. }
  860. float cfgScale = params["cfg_scale"];
  861. if (cfgScale < 1.0f || cfgScale > 30.0f) {
  862. return {false, "CFG scale must be between 1.0 and 30.0"};
  863. }
  864. }
  865. // Validate seed
  866. if (params.contains("seed")) {
  867. if (!params["seed"].is_string() && !params["seed"].is_number_integer()) {
  868. return {false, "Invalid 'seed' field, must be string or integer"};
  869. }
  870. }
  871. // Validate sampling method
  872. if (params.contains("sampling_method")) {
  873. if (!params["sampling_method"].is_string()) {
  874. return {false, "Invalid 'sampling_method' field, must be string"};
  875. }
  876. std::string method = params["sampling_method"];
  877. std::vector<std::string> validMethods = {
  878. "euler", "euler_a", "heun", "dpm2", "dpm++2s_a", "dpm++2m",
  879. "dpm++2mv2", "ipndm", "ipndm_v", "lcm", "ddim_trailing", "tcd", "default"
  880. };
  881. if (std::find(validMethods.begin(), validMethods.end(), method) == validMethods.end()) {
  882. return {false, "Invalid sampling method"};
  883. }
  884. }
  885. // Validate scheduler
  886. if (params.contains("scheduler")) {
  887. if (!params["scheduler"].is_string()) {
  888. return {false, "Invalid 'scheduler' field, must be string"};
  889. }
  890. std::string scheduler = params["scheduler"];
  891. std::vector<std::string> validSchedulers = {
  892. "discrete", "karras", "exponential", "ays", "gits",
  893. "smoothstep", "sgm_uniform", "simple", "default"
  894. };
  895. if (std::find(validSchedulers.begin(), validSchedulers.end(), scheduler) == validSchedulers.end()) {
  896. return {false, "Invalid scheduler"};
  897. }
  898. }
  899. // Validate strength
  900. if (params.contains("strength")) {
  901. if (!params["strength"].is_number()) {
  902. return {false, "Invalid 'strength' field, must be number"};
  903. }
  904. float strength = params["strength"];
  905. if (strength < 0.0f || strength > 1.0f) {
  906. return {false, "Strength must be between 0.0 and 1.0"};
  907. }
  908. }
  909. // Validate control strength
  910. if (params.contains("control_strength")) {
  911. if (!params["control_strength"].is_number()) {
  912. return {false, "Invalid 'control_strength' field, must be number"};
  913. }
  914. float controlStrength = params["control_strength"];
  915. if (controlStrength < 0.0f || controlStrength > 1.0f) {
  916. return {false, "Control strength must be between 0.0 and 1.0"};
  917. }
  918. }
  919. // Validate clip skip
  920. if (params.contains("clip_skip")) {
  921. if (!params["clip_skip"].is_number_integer()) {
  922. return {false, "Invalid 'clip_skip' field, must be integer"};
  923. }
  924. int clipSkip = params["clip_skip"];
  925. if (clipSkip < -1 || clipSkip > 12) {
  926. return {false, "Clip skip must be between -1 and 12"};
  927. }
  928. }
  929. // Validate threads
  930. if (params.contains("threads")) {
  931. if (!params["threads"].is_number_integer()) {
  932. return {false, "Invalid 'threads' field, must be integer"};
  933. }
  934. int threads = params["threads"];
  935. if (threads < -1 || threads > 32) {
  936. return {false, "Threads must be between -1 (auto) and 32"};
  937. }
  938. }
  939. return {true, ""};
  940. }
  941. SamplingMethod Server::parseSamplingMethod(const std::string& method) {
  942. if (method == "euler") return SamplingMethod::EULER;
  943. else if (method == "euler_a") return SamplingMethod::EULER_A;
  944. else if (method == "heun") return SamplingMethod::HEUN;
  945. else if (method == "dpm2") return SamplingMethod::DPM2;
  946. else if (method == "dpm++2s_a") return SamplingMethod::DPMPP2S_A;
  947. else if (method == "dpm++2m") return SamplingMethod::DPMPP2M;
  948. else if (method == "dpm++2mv2") return SamplingMethod::DPMPP2MV2;
  949. else if (method == "ipndm") return SamplingMethod::IPNDM;
  950. else if (method == "ipndm_v") return SamplingMethod::IPNDM_V;
  951. else if (method == "lcm") return SamplingMethod::LCM;
  952. else if (method == "ddim_trailing") return SamplingMethod::DDIM_TRAILING;
  953. else if (method == "tcd") return SamplingMethod::TCD;
  954. else return SamplingMethod::DEFAULT;
  955. }
  956. Scheduler Server::parseScheduler(const std::string& scheduler) {
  957. if (scheduler == "discrete") return Scheduler::DISCRETE;
  958. else if (scheduler == "karras") return Scheduler::KARRAS;
  959. else if (scheduler == "exponential") return Scheduler::EXPONENTIAL;
  960. else if (scheduler == "ays") return Scheduler::AYS;
  961. else if (scheduler == "gits") return Scheduler::GITS;
  962. else if (scheduler == "smoothstep") return Scheduler::SMOOTHSTEP;
  963. else if (scheduler == "sgm_uniform") return Scheduler::SGM_UNIFORM;
  964. else if (scheduler == "simple") return Scheduler::SIMPLE;
  965. else return Scheduler::DEFAULT;
  966. }
  967. std::string Server::generateRequestId() {
  968. std::random_device rd;
  969. std::mt19937 gen(rd());
  970. std::uniform_int_distribution<> dis(100000, 999999);
  971. return "req_" + std::to_string(dis(gen));
  972. }
  973. std::tuple<std::vector<uint8_t>, int, int, int, bool, std::string>
  974. Server::loadImageFromInput(const std::string& input) {
  975. std::vector<uint8_t> imageData;
  976. int width = 0, height = 0, channels = 0;
  977. // Auto-detect input source type
  978. // 1. Check if input is a URL (starts with http:// or https://)
  979. if (Utils::startsWith(input, "http://") || Utils::startsWith(input, "https://")) {
  980. // Parse URL to extract host and path
  981. std::string url = input;
  982. std::string scheme, host, path;
  983. int port = 80;
  984. // Determine scheme and port
  985. if (Utils::startsWith(url, "https://")) {
  986. scheme = "https";
  987. port = 443;
  988. url = url.substr(8); // Remove "https://"
  989. } else {
  990. scheme = "http";
  991. port = 80;
  992. url = url.substr(7); // Remove "http://"
  993. }
  994. // Extract host and path
  995. size_t slashPos = url.find('/');
  996. if (slashPos != std::string::npos) {
  997. host = url.substr(0, slashPos);
  998. path = url.substr(slashPos);
  999. } else {
  1000. host = url;
  1001. path = "/";
  1002. }
  1003. // Check for custom port
  1004. size_t colonPos = host.find(':');
  1005. if (colonPos != std::string::npos) {
  1006. try {
  1007. port = std::stoi(host.substr(colonPos + 1));
  1008. host = host.substr(0, colonPos);
  1009. } catch (...) {
  1010. return {imageData, 0, 0, 0, false, "Invalid port in URL"};
  1011. }
  1012. }
  1013. // Download image using httplib
  1014. try {
  1015. httplib::Result res;
  1016. if (scheme == "https") {
  1017. #ifdef CPPHTTPLIB_OPENSSL_SUPPORT
  1018. httplib::SSLClient client(host, port);
  1019. client.set_follow_location(true);
  1020. client.set_connection_timeout(30, 0); // 30 seconds
  1021. client.set_read_timeout(60, 0); // 60 seconds
  1022. res = client.Get(path.c_str());
  1023. #else
  1024. return {imageData, 0, 0, 0, false, "HTTPS not supported (OpenSSL not available)"};
  1025. #endif
  1026. } else {
  1027. httplib::Client client(host, port);
  1028. client.set_follow_location(true);
  1029. client.set_connection_timeout(30, 0); // 30 seconds
  1030. client.set_read_timeout(60, 0); // 60 seconds
  1031. res = client.Get(path.c_str());
  1032. }
  1033. if (!res) {
  1034. return {imageData, 0, 0, 0, false, "Failed to download image from URL: Connection error"};
  1035. }
  1036. if (res->status != 200) {
  1037. return {imageData, 0, 0, 0, false, "Failed to download image from URL: HTTP " + std::to_string(res->status)};
  1038. }
  1039. // Convert response body to vector
  1040. std::vector<uint8_t> downloadedData(res->body.begin(), res->body.end());
  1041. // Load image from memory
  1042. int w, h, c;
  1043. unsigned char* pixels = stbi_load_from_memory(
  1044. downloadedData.data(),
  1045. downloadedData.size(),
  1046. &w, &h, &c,
  1047. 3 // Force RGB
  1048. );
  1049. if (!pixels) {
  1050. return {imageData, 0, 0, 0, false, "Failed to decode image from URL"};
  1051. }
  1052. width = w;
  1053. height = h;
  1054. channels = 3;
  1055. size_t dataSize = width * height * channels;
  1056. imageData.resize(dataSize);
  1057. std::memcpy(imageData.data(), pixels, dataSize);
  1058. stbi_image_free(pixels);
  1059. } catch (const std::exception& e) {
  1060. return {imageData, 0, 0, 0, false, "Failed to download image from URL: " + std::string(e.what())};
  1061. }
  1062. }
  1063. // 2. Check if input is base64 encoded data URI (starts with "data:image")
  1064. else if (Utils::startsWith(input, "data:image")) {
  1065. // Extract base64 data after the comma
  1066. size_t commaPos = input.find(',');
  1067. if (commaPos == std::string::npos) {
  1068. return {imageData, 0, 0, 0, false, "Invalid data URI format"};
  1069. }
  1070. std::string base64Data = input.substr(commaPos + 1);
  1071. std::vector<uint8_t> decodedData = Utils::base64Decode(base64Data);
  1072. // Load image from memory using stb_image
  1073. int w, h, c;
  1074. unsigned char* pixels = stbi_load_from_memory(
  1075. decodedData.data(),
  1076. decodedData.size(),
  1077. &w, &h, &c,
  1078. 3 // Force RGB
  1079. );
  1080. if (!pixels) {
  1081. return {imageData, 0, 0, 0, false, "Failed to decode image from base64 data URI"};
  1082. }
  1083. width = w;
  1084. height = h;
  1085. channels = 3; // We forced RGB
  1086. // Copy pixel data
  1087. size_t dataSize = width * height * channels;
  1088. imageData.resize(dataSize);
  1089. std::memcpy(imageData.data(), pixels, dataSize);
  1090. stbi_image_free(pixels);
  1091. }
  1092. // 3. Check if input is raw base64 (long string without slashes, likely base64)
  1093. else if (input.length() > 100 && input.find('/') == std::string::npos && input.find('.') == std::string::npos) {
  1094. // Likely raw base64 without data URI prefix
  1095. std::vector<uint8_t> decodedData = Utils::base64Decode(input);
  1096. int w, h, c;
  1097. unsigned char* pixels = stbi_load_from_memory(
  1098. decodedData.data(),
  1099. decodedData.size(),
  1100. &w, &h, &c,
  1101. 3 // Force RGB
  1102. );
  1103. if (!pixels) {
  1104. return {imageData, 0, 0, 0, false, "Failed to decode image from base64"};
  1105. }
  1106. width = w;
  1107. height = h;
  1108. channels = 3;
  1109. size_t dataSize = width * height * channels;
  1110. imageData.resize(dataSize);
  1111. std::memcpy(imageData.data(), pixels, dataSize);
  1112. stbi_image_free(pixels);
  1113. }
  1114. // 4. Treat as local file path
  1115. else {
  1116. int w, h, c;
  1117. unsigned char* pixels = stbi_load(input.c_str(), &w, &h, &c, 3);
  1118. if (!pixels) {
  1119. return {imageData, 0, 0, 0, false, "Failed to load image from file: " + input};
  1120. }
  1121. width = w;
  1122. height = h;
  1123. channels = 3;
  1124. size_t dataSize = width * height * channels;
  1125. imageData.resize(dataSize);
  1126. std::memcpy(imageData.data(), pixels, dataSize);
  1127. stbi_image_free(pixels);
  1128. }
  1129. return {imageData, width, height, channels, true, ""};
  1130. }
  1131. std::string Server::samplingMethodToString(SamplingMethod method) {
  1132. switch (method) {
  1133. case SamplingMethod::EULER: return "euler";
  1134. case SamplingMethod::EULER_A: return "euler_a";
  1135. case SamplingMethod::HEUN: return "heun";
  1136. case SamplingMethod::DPM2: return "dpm2";
  1137. case SamplingMethod::DPMPP2S_A: return "dpm++2s_a";
  1138. case SamplingMethod::DPMPP2M: return "dpm++2m";
  1139. case SamplingMethod::DPMPP2MV2: return "dpm++2mv2";
  1140. case SamplingMethod::IPNDM: return "ipndm";
  1141. case SamplingMethod::IPNDM_V: return "ipndm_v";
  1142. case SamplingMethod::LCM: return "lcm";
  1143. case SamplingMethod::DDIM_TRAILING: return "ddim_trailing";
  1144. case SamplingMethod::TCD: return "tcd";
  1145. default: return "default";
  1146. }
  1147. }
  1148. std::string Server::schedulerToString(Scheduler scheduler) {
  1149. switch (scheduler) {
  1150. case Scheduler::DISCRETE: return "discrete";
  1151. case Scheduler::KARRAS: return "karras";
  1152. case Scheduler::EXPONENTIAL: return "exponential";
  1153. case Scheduler::AYS: return "ays";
  1154. case Scheduler::GITS: return "gits";
  1155. case Scheduler::SMOOTHSTEP: return "smoothstep";
  1156. case Scheduler::SGM_UNIFORM: return "sgm_uniform";
  1157. case Scheduler::SIMPLE: return "simple";
  1158. default: return "default";
  1159. }
  1160. }
  1161. uint64_t Server::estimateGenerationTime(const GenerationRequest& request) {
  1162. // Basic estimation based on parameters
  1163. uint64_t baseTime = 1000; // 1 second base time
  1164. // Factor in steps
  1165. baseTime *= request.steps;
  1166. // Factor in resolution
  1167. double resolutionFactor = (request.width * request.height) / (512.0 * 512.0);
  1168. baseTime = static_cast<uint64_t>(baseTime * resolutionFactor);
  1169. // Factor in batch count
  1170. baseTime *= request.batchCount;
  1171. // Adjust for sampling method (some are faster than others)
  1172. switch (request.samplingMethod) {
  1173. case SamplingMethod::LCM:
  1174. baseTime /= 4; // LCM is much faster
  1175. break;
  1176. case SamplingMethod::EULER:
  1177. case SamplingMethod::EULER_A:
  1178. baseTime *= 0.8; // Euler methods are faster
  1179. break;
  1180. case SamplingMethod::DPM2:
  1181. case SamplingMethod::DPMPP2S_A:
  1182. baseTime *= 1.2; // DPM methods are slower
  1183. break;
  1184. default:
  1185. break;
  1186. }
  1187. return baseTime;
  1188. }
  1189. size_t Server::estimateMemoryUsage(const GenerationRequest& request) {
  1190. // Basic memory estimation in bytes
  1191. size_t baseMemory = 1024 * 1024 * 1024; // 1GB base
  1192. // Factor in resolution
  1193. double resolutionFactor = (request.width * request.height) / (512.0 * 512.0);
  1194. baseMemory = static_cast<size_t>(baseMemory * resolutionFactor);
  1195. // Factor in batch count
  1196. baseMemory *= request.batchCount;
  1197. // Additional memory for certain features
  1198. if (request.diffusionFlashAttn) {
  1199. baseMemory += 512 * 1024 * 1024; // Extra 512MB for flash attention
  1200. }
  1201. if (!request.controlNetPath.empty()) {
  1202. baseMemory += 1024 * 1024 * 1024; // Extra 1GB for ControlNet
  1203. }
  1204. return baseMemory;
  1205. }
  1206. // Specialized generation endpoints
  1207. void Server::handleText2Img(const httplib::Request& req, httplib::Response& res) {
  1208. std::string requestId = generateRequestId();
  1209. try {
  1210. if (!m_generationQueue) {
  1211. sendErrorResponse(res, "Generation queue not available", 500, "QUEUE_UNAVAILABLE", requestId);
  1212. return;
  1213. }
  1214. json requestJson = json::parse(req.body);
  1215. // Validate required fields for text2img
  1216. if (!requestJson.contains("prompt") || !requestJson["prompt"].is_string()) {
  1217. sendErrorResponse(res, "Missing or invalid 'prompt' field", 400, "INVALID_PARAMETERS", requestId);
  1218. return;
  1219. }
  1220. // Validate all parameters
  1221. auto [isValid, errorMessage] = validateGenerationParameters(requestJson);
  1222. if (!isValid) {
  1223. sendErrorResponse(res, errorMessage, 400, "INVALID_PARAMETERS", requestId);
  1224. return;
  1225. }
  1226. // Check if any model is loaded
  1227. if (!m_modelManager) {
  1228. sendErrorResponse(res, "Model manager not available", 500, "MODEL_MANAGER_UNAVAILABLE", requestId);
  1229. return;
  1230. }
  1231. // Get currently loaded checkpoint model
  1232. auto allModels = m_modelManager->getAllModels();
  1233. std::string loadedModelName;
  1234. for (const auto& [modelName, modelInfo] : allModels) {
  1235. if (modelInfo.type == ModelType::CHECKPOINT && modelInfo.isLoaded) {
  1236. loadedModelName = modelName;
  1237. break;
  1238. }
  1239. }
  1240. if (loadedModelName.empty()) {
  1241. sendErrorResponse(res, "No checkpoint model loaded. Please load a checkpoint model first using POST /api/models/{hash}/load", 400, "NO_CHECKPOINT_LOADED", requestId);
  1242. return;
  1243. }
  1244. // Create generation request specifically for text2img
  1245. GenerationRequest genRequest;
  1246. genRequest.id = requestId;
  1247. genRequest.modelName = loadedModelName; // Use the currently loaded model
  1248. genRequest.prompt = requestJson["prompt"];
  1249. genRequest.negativePrompt = requestJson.value("negative_prompt", "");
  1250. genRequest.width = requestJson.value("width", 512);
  1251. genRequest.height = requestJson.value("height", 512);
  1252. genRequest.batchCount = requestJson.value("batch_count", 1);
  1253. genRequest.steps = requestJson.value("steps", 20);
  1254. genRequest.cfgScale = requestJson.value("cfg_scale", 7.5f);
  1255. genRequest.seed = requestJson.value("seed", "random");
  1256. // Parse optional parameters
  1257. if (requestJson.contains("sampling_method")) {
  1258. genRequest.samplingMethod = parseSamplingMethod(requestJson["sampling_method"]);
  1259. }
  1260. if (requestJson.contains("scheduler")) {
  1261. genRequest.scheduler = parseScheduler(requestJson["scheduler"]);
  1262. }
  1263. // Set text2img specific defaults
  1264. genRequest.strength = 1.0f; // Full strength for text2img
  1265. // Optional VAE model
  1266. if (requestJson.contains("vae_model") && requestJson["vae_model"].is_string()) {
  1267. std::string vaeModelId = requestJson["vae_model"];
  1268. if (!vaeModelId.empty()) {
  1269. auto vaeInfo = m_modelManager->getModelInfo(vaeModelId);
  1270. if (!vaeInfo.name.empty() && vaeInfo.type == ModelType::VAE) {
  1271. genRequest.vaePath = vaeInfo.path;
  1272. } else {
  1273. sendErrorResponse(res, "VAE model not found or invalid: " + vaeModelId, 400, "INVALID_VAE_MODEL", requestId);
  1274. return;
  1275. }
  1276. }
  1277. }
  1278. // Optional TAESD model
  1279. if (requestJson.contains("taesd_model") && requestJson["taesd_model"].is_string()) {
  1280. std::string taesdModelId = requestJson["taesd_model"];
  1281. if (!taesdModelId.empty()) {
  1282. auto taesdInfo = m_modelManager->getModelInfo(taesdModelId);
  1283. if (!taesdInfo.name.empty() && taesdInfo.type == ModelType::TAESD) {
  1284. genRequest.taesdPath = taesdInfo.path;
  1285. } else {
  1286. sendErrorResponse(res, "TAESD model not found or invalid: " + taesdModelId, 400, "INVALID_TAESD_MODEL", requestId);
  1287. return;
  1288. }
  1289. }
  1290. }
  1291. // Enqueue request
  1292. auto future = m_generationQueue->enqueueRequest(genRequest);
  1293. json params = {
  1294. {"prompt", genRequest.prompt},
  1295. {"negative_prompt", genRequest.negativePrompt},
  1296. {"model", genRequest.modelName},
  1297. {"width", genRequest.width},
  1298. {"height", genRequest.height},
  1299. {"batch_count", genRequest.batchCount},
  1300. {"steps", genRequest.steps},
  1301. {"cfg_scale", genRequest.cfgScale},
  1302. {"seed", genRequest.seed},
  1303. {"sampling_method", samplingMethodToString(genRequest.samplingMethod)},
  1304. {"scheduler", schedulerToString(genRequest.scheduler)}
  1305. };
  1306. // Add VAE/TAESD if specified
  1307. if (!genRequest.vaePath.empty()) {
  1308. params["vae_model"] = requestJson.value("vae_model", "");
  1309. }
  1310. if (!genRequest.taesdPath.empty()) {
  1311. params["taesd_model"] = requestJson.value("taesd_model", "");
  1312. }
  1313. json response = {
  1314. {"request_id", requestId},
  1315. {"status", "queued"},
  1316. {"message", "Text-to-image generation request queued successfully"},
  1317. {"queue_position", m_generationQueue->getQueueSize()},
  1318. {"estimated_time_seconds", estimateGenerationTime(genRequest) / 1000},
  1319. {"estimated_memory_mb", estimateMemoryUsage(genRequest) / (1024 * 1024)},
  1320. {"type", "text2img"},
  1321. {"parameters", params}
  1322. };
  1323. sendJsonResponse(res, response, 202);
  1324. } catch (const json::parse_error& e) {
  1325. sendErrorResponse(res, std::string("Invalid JSON: ") + e.what(), 400, "JSON_PARSE_ERROR", requestId);
  1326. } catch (const std::exception& e) {
  1327. sendErrorResponse(res, std::string("Text-to-image request failed: ") + e.what(), 500, "INTERNAL_ERROR", requestId);
  1328. }
  1329. }
  1330. void Server::handleImg2Img(const httplib::Request& req, httplib::Response& res) {
  1331. std::string requestId = generateRequestId();
  1332. try {
  1333. if (!m_generationQueue) {
  1334. sendErrorResponse(res, "Generation queue not available", 500, "QUEUE_UNAVAILABLE", requestId);
  1335. return;
  1336. }
  1337. json requestJson = json::parse(req.body);
  1338. // Validate required fields for img2img
  1339. if (!requestJson.contains("prompt") || !requestJson["prompt"].is_string()) {
  1340. sendErrorResponse(res, "Missing or invalid 'prompt' field", 400, "INVALID_PARAMETERS", requestId);
  1341. return;
  1342. }
  1343. if (!requestJson.contains("init_image") || !requestJson["init_image"].is_string()) {
  1344. sendErrorResponse(res, "Missing or invalid 'init_image' field", 400, "INVALID_PARAMETERS", requestId);
  1345. return;
  1346. }
  1347. // Validate all parameters
  1348. auto [isValid, errorMessage] = validateGenerationParameters(requestJson);
  1349. if (!isValid) {
  1350. sendErrorResponse(res, errorMessage, 400, "INVALID_PARAMETERS", requestId);
  1351. return;
  1352. }
  1353. // Check if any model is loaded
  1354. if (!m_modelManager) {
  1355. sendErrorResponse(res, "Model manager not available", 500, "MODEL_MANAGER_UNAVAILABLE", requestId);
  1356. return;
  1357. }
  1358. // Get currently loaded checkpoint model
  1359. auto allModels = m_modelManager->getAllModels();
  1360. std::string loadedModelName;
  1361. for (const auto& [modelName, modelInfo] : allModels) {
  1362. if (modelInfo.type == ModelType::CHECKPOINT && modelInfo.isLoaded) {
  1363. loadedModelName = modelName;
  1364. break;
  1365. }
  1366. }
  1367. if (loadedModelName.empty()) {
  1368. sendErrorResponse(res, "No checkpoint model loaded. Please load a checkpoint model first using POST /api/models/{hash}/load", 400, "NO_CHECKPOINT_LOADED", requestId);
  1369. return;
  1370. }
  1371. // Load the init image
  1372. std::string initImageInput = requestJson["init_image"];
  1373. auto [imageData, imgWidth, imgHeight, imgChannels, success, loadError] = loadImageFromInput(initImageInput);
  1374. if (!success) {
  1375. sendErrorResponse(res, "Failed to load init image: " + loadError, 400, "IMAGE_LOAD_ERROR", requestId);
  1376. return;
  1377. }
  1378. // Create generation request specifically for img2img
  1379. GenerationRequest genRequest;
  1380. genRequest.id = requestId;
  1381. genRequest.requestType = GenerationRequest::RequestType::IMG2IMG;
  1382. genRequest.modelName = loadedModelName; // Use the currently loaded model
  1383. genRequest.prompt = requestJson["prompt"];
  1384. genRequest.negativePrompt = requestJson.value("negative_prompt", "");
  1385. genRequest.width = requestJson.value("width", imgWidth); // Default to input image dimensions
  1386. genRequest.height = requestJson.value("height", imgHeight);
  1387. genRequest.batchCount = requestJson.value("batch_count", 1);
  1388. genRequest.steps = requestJson.value("steps", 20);
  1389. genRequest.cfgScale = requestJson.value("cfg_scale", 7.5f);
  1390. genRequest.seed = requestJson.value("seed", "random");
  1391. genRequest.strength = requestJson.value("strength", 0.75f);
  1392. // Set init image data
  1393. genRequest.initImageData = imageData;
  1394. genRequest.initImageWidth = imgWidth;
  1395. genRequest.initImageHeight = imgHeight;
  1396. genRequest.initImageChannels = imgChannels;
  1397. // Parse optional parameters
  1398. if (requestJson.contains("sampling_method")) {
  1399. genRequest.samplingMethod = parseSamplingMethod(requestJson["sampling_method"]);
  1400. }
  1401. if (requestJson.contains("scheduler")) {
  1402. genRequest.scheduler = parseScheduler(requestJson["scheduler"]);
  1403. }
  1404. // Optional VAE model
  1405. if (requestJson.contains("vae_model") && requestJson["vae_model"].is_string()) {
  1406. std::string vaeModelId = requestJson["vae_model"];
  1407. if (!vaeModelId.empty()) {
  1408. auto vaeInfo = m_modelManager->getModelInfo(vaeModelId);
  1409. if (!vaeInfo.name.empty() && vaeInfo.type == ModelType::VAE) {
  1410. genRequest.vaePath = vaeInfo.path;
  1411. } else {
  1412. sendErrorResponse(res, "VAE model not found or invalid: " + vaeModelId, 400, "INVALID_VAE_MODEL", requestId);
  1413. return;
  1414. }
  1415. }
  1416. }
  1417. // Optional TAESD model
  1418. if (requestJson.contains("taesd_model") && requestJson["taesd_model"].is_string()) {
  1419. std::string taesdModelId = requestJson["taesd_model"];
  1420. if (!taesdModelId.empty()) {
  1421. auto taesdInfo = m_modelManager->getModelInfo(taesdModelId);
  1422. if (!taesdInfo.name.empty() && taesdInfo.type == ModelType::TAESD) {
  1423. genRequest.taesdPath = taesdInfo.path;
  1424. } else {
  1425. sendErrorResponse(res, "TAESD model not found or invalid: " + taesdModelId, 400, "INVALID_TAESD_MODEL", requestId);
  1426. return;
  1427. }
  1428. }
  1429. }
  1430. // Enqueue request
  1431. auto future = m_generationQueue->enqueueRequest(genRequest);
  1432. json params = {
  1433. {"prompt", genRequest.prompt},
  1434. {"negative_prompt", genRequest.negativePrompt},
  1435. {"init_image", requestJson["init_image"]},
  1436. {"model", genRequest.modelName},
  1437. {"width", genRequest.width},
  1438. {"height", genRequest.height},
  1439. {"batch_count", genRequest.batchCount},
  1440. {"steps", genRequest.steps},
  1441. {"cfg_scale", genRequest.cfgScale},
  1442. {"seed", genRequest.seed},
  1443. {"strength", genRequest.strength},
  1444. {"sampling_method", samplingMethodToString(genRequest.samplingMethod)},
  1445. {"scheduler", schedulerToString(genRequest.scheduler)}
  1446. };
  1447. // Add VAE/TAESD if specified
  1448. if (!genRequest.vaePath.empty()) {
  1449. params["vae_model"] = requestJson.value("vae_model", "");
  1450. }
  1451. if (!genRequest.taesdPath.empty()) {
  1452. params["taesd_model"] = requestJson.value("taesd_model", "");
  1453. }
  1454. json response = {
  1455. {"request_id", requestId},
  1456. {"status", "queued"},
  1457. {"message", "Image-to-image generation request queued successfully"},
  1458. {"queue_position", m_generationQueue->getQueueSize()},
  1459. {"estimated_time_seconds", estimateGenerationTime(genRequest) / 1000},
  1460. {"estimated_memory_mb", estimateMemoryUsage(genRequest) / (1024 * 1024)},
  1461. {"type", "img2img"},
  1462. {"parameters", params}
  1463. };
  1464. sendJsonResponse(res, response, 202);
  1465. } catch (const json::parse_error& e) {
  1466. sendErrorResponse(res, std::string("Invalid JSON: ") + e.what(), 400, "JSON_PARSE_ERROR", requestId);
  1467. } catch (const std::exception& e) {
  1468. sendErrorResponse(res, std::string("Image-to-image request failed: ") + e.what(), 500, "INTERNAL_ERROR", requestId);
  1469. }
  1470. }
  1471. void Server::handleControlNet(const httplib::Request& req, httplib::Response& res) {
  1472. std::string requestId = generateRequestId();
  1473. try {
  1474. if (!m_generationQueue) {
  1475. sendErrorResponse(res, "Generation queue not available", 500, "QUEUE_UNAVAILABLE", requestId);
  1476. return;
  1477. }
  1478. json requestJson = json::parse(req.body);
  1479. // Validate required fields for ControlNet
  1480. if (!requestJson.contains("prompt") || !requestJson["prompt"].is_string()) {
  1481. sendErrorResponse(res, "Missing or invalid 'prompt' field", 400, "INVALID_PARAMETERS", requestId);
  1482. return;
  1483. }
  1484. if (!requestJson.contains("control_image") || !requestJson["control_image"].is_string()) {
  1485. sendErrorResponse(res, "Missing or invalid 'control_image' field", 400, "INVALID_PARAMETERS", requestId);
  1486. return;
  1487. }
  1488. // Validate all parameters
  1489. auto [isValid, errorMessage] = validateGenerationParameters(requestJson);
  1490. if (!isValid) {
  1491. sendErrorResponse(res, errorMessage, 400, "INVALID_PARAMETERS", requestId);
  1492. return;
  1493. }
  1494. // Check if any model is loaded
  1495. if (!m_modelManager) {
  1496. sendErrorResponse(res, "Model manager not available", 500, "MODEL_MANAGER_UNAVAILABLE", requestId);
  1497. return;
  1498. }
  1499. // Get currently loaded checkpoint model
  1500. auto allModels = m_modelManager->getAllModels();
  1501. std::string loadedModelName;
  1502. for (const auto& [modelName, modelInfo] : allModels) {
  1503. if (modelInfo.type == ModelType::CHECKPOINT && modelInfo.isLoaded) {
  1504. loadedModelName = modelName;
  1505. break;
  1506. }
  1507. }
  1508. if (loadedModelName.empty()) {
  1509. sendErrorResponse(res, "No checkpoint model loaded. Please load a checkpoint model first using POST /api/models/{hash}/load", 400, "NO_CHECKPOINT_LOADED", requestId);
  1510. return;
  1511. }
  1512. // Create generation request specifically for ControlNet
  1513. GenerationRequest genRequest;
  1514. genRequest.id = requestId;
  1515. genRequest.modelName = loadedModelName; // Use the currently loaded model
  1516. genRequest.prompt = requestJson["prompt"];
  1517. genRequest.negativePrompt = requestJson.value("negative_prompt", "");
  1518. genRequest.width = requestJson.value("width", 512);
  1519. genRequest.height = requestJson.value("height", 512);
  1520. genRequest.batchCount = requestJson.value("batch_count", 1);
  1521. genRequest.steps = requestJson.value("steps", 20);
  1522. genRequest.cfgScale = requestJson.value("cfg_scale", 7.5f);
  1523. genRequest.seed = requestJson.value("seed", "random");
  1524. genRequest.controlStrength = requestJson.value("control_strength", 0.9f);
  1525. genRequest.controlNetPath = requestJson.value("control_net_model", "");
  1526. // Parse optional parameters
  1527. if (requestJson.contains("sampling_method")) {
  1528. genRequest.samplingMethod = parseSamplingMethod(requestJson["sampling_method"]);
  1529. }
  1530. if (requestJson.contains("scheduler")) {
  1531. genRequest.scheduler = parseScheduler(requestJson["scheduler"]);
  1532. }
  1533. // Optional VAE model
  1534. if (requestJson.contains("vae_model") && requestJson["vae_model"].is_string()) {
  1535. std::string vaeModelId = requestJson["vae_model"];
  1536. if (!vaeModelId.empty()) {
  1537. auto vaeInfo = m_modelManager->getModelInfo(vaeModelId);
  1538. if (!vaeInfo.name.empty() && vaeInfo.type == ModelType::VAE) {
  1539. genRequest.vaePath = vaeInfo.path;
  1540. } else {
  1541. sendErrorResponse(res, "VAE model not found or invalid: " + vaeModelId, 400, "INVALID_VAE_MODEL", requestId);
  1542. return;
  1543. }
  1544. }
  1545. }
  1546. // Optional TAESD model
  1547. if (requestJson.contains("taesd_model") && requestJson["taesd_model"].is_string()) {
  1548. std::string taesdModelId = requestJson["taesd_model"];
  1549. if (!taesdModelId.empty()) {
  1550. auto taesdInfo = m_modelManager->getModelInfo(taesdModelId);
  1551. if (!taesdInfo.name.empty() && taesdInfo.type == ModelType::TAESD) {
  1552. genRequest.taesdPath = taesdInfo.path;
  1553. } else {
  1554. sendErrorResponse(res, "TAESD model not found or invalid: " + taesdModelId, 400, "INVALID_TAESD_MODEL", requestId);
  1555. return;
  1556. }
  1557. }
  1558. }
  1559. // Store control image path (would be handled in actual implementation)
  1560. genRequest.outputPath = requestJson.value("control_image", "");
  1561. // Enqueue request
  1562. auto future = m_generationQueue->enqueueRequest(genRequest);
  1563. json params = {
  1564. {"prompt", genRequest.prompt},
  1565. {"negative_prompt", genRequest.negativePrompt},
  1566. {"control_image", requestJson["control_image"]},
  1567. {"control_net_model", genRequest.controlNetPath},
  1568. {"model", genRequest.modelName},
  1569. {"width", genRequest.width},
  1570. {"height", genRequest.height},
  1571. {"batch_count", genRequest.batchCount},
  1572. {"steps", genRequest.steps},
  1573. {"cfg_scale", genRequest.cfgScale},
  1574. {"seed", genRequest.seed},
  1575. {"control_strength", genRequest.controlStrength},
  1576. {"sampling_method", samplingMethodToString(genRequest.samplingMethod)},
  1577. {"scheduler", schedulerToString(genRequest.scheduler)}
  1578. };
  1579. // Add VAE/TAESD if specified
  1580. if (!genRequest.vaePath.empty()) {
  1581. params["vae_model"] = requestJson.value("vae_model", "");
  1582. }
  1583. if (!genRequest.taesdPath.empty()) {
  1584. params["taesd_model"] = requestJson.value("taesd_model", "");
  1585. }
  1586. json response = {
  1587. {"request_id", requestId},
  1588. {"status", "queued"},
  1589. {"message", "ControlNet generation request queued successfully"},
  1590. {"queue_position", m_generationQueue->getQueueSize()},
  1591. {"estimated_time_seconds", estimateGenerationTime(genRequest) / 1000},
  1592. {"estimated_memory_mb", estimateMemoryUsage(genRequest) / (1024 * 1024)},
  1593. {"type", "controlnet"},
  1594. {"parameters", params}
  1595. };
  1596. sendJsonResponse(res, response, 202);
  1597. } catch (const json::parse_error& e) {
  1598. sendErrorResponse(res, std::string("Invalid JSON: ") + e.what(), 400, "JSON_PARSE_ERROR", requestId);
  1599. } catch (const std::exception& e) {
  1600. sendErrorResponse(res, std::string("ControlNet request failed: ") + e.what(), 500, "INTERNAL_ERROR", requestId);
  1601. }
  1602. }
  1603. void Server::handleUpscale(const httplib::Request& req, httplib::Response& res) {
  1604. std::string requestId = generateRequestId();
  1605. try {
  1606. if (!m_generationQueue) {
  1607. sendErrorResponse(res, "Generation queue not available", 500, "QUEUE_UNAVAILABLE", requestId);
  1608. return;
  1609. }
  1610. json requestJson = json::parse(req.body);
  1611. // Validate required fields for upscaler
  1612. if (!requestJson.contains("image") || !requestJson["image"].is_string()) {
  1613. sendErrorResponse(res, "Missing or invalid 'image' field", 400, "INVALID_PARAMETERS", requestId);
  1614. return;
  1615. }
  1616. if (!requestJson.contains("esrgan_model") || !requestJson["esrgan_model"].is_string()) {
  1617. sendErrorResponse(res, "Missing or invalid 'esrgan_model' field (model hash or name)", 400, "INVALID_PARAMETERS", requestId);
  1618. return;
  1619. }
  1620. // Check if model manager is available
  1621. if (!m_modelManager) {
  1622. sendErrorResponse(res, "Model manager not available", 500, "MODEL_MANAGER_UNAVAILABLE", requestId);
  1623. return;
  1624. }
  1625. // Get the ESRGAN/upscaler model
  1626. std::string esrganModelId = requestJson["esrgan_model"];
  1627. auto modelInfo = m_modelManager->getModelInfo(esrganModelId);
  1628. if (modelInfo.name.empty()) {
  1629. sendErrorResponse(res, "ESRGAN model not found: " + esrganModelId, 404, "MODEL_NOT_FOUND", requestId);
  1630. return;
  1631. }
  1632. if (modelInfo.type != ModelType::ESRGAN && modelInfo.type != ModelType::UPSCALER) {
  1633. sendErrorResponse(res, "Model is not an ESRGAN/upscaler model", 400, "INVALID_MODEL_TYPE", requestId);
  1634. return;
  1635. }
  1636. // Load the input image
  1637. std::string imageInput = requestJson["image"];
  1638. auto [imageData, imgWidth, imgHeight, imgChannels, success, loadError] = loadImageFromInput(imageInput);
  1639. if (!success) {
  1640. sendErrorResponse(res, "Failed to load image: " + loadError, 400, "IMAGE_LOAD_ERROR", requestId);
  1641. return;
  1642. }
  1643. // Create upscaler request
  1644. GenerationRequest genRequest;
  1645. genRequest.id = requestId;
  1646. genRequest.requestType = GenerationRequest::RequestType::UPSCALER;
  1647. genRequest.esrganPath = modelInfo.path;
  1648. genRequest.upscaleFactor = requestJson.value("upscale_factor", 4);
  1649. genRequest.nThreads = requestJson.value("threads", -1);
  1650. genRequest.offloadParamsToCpu = requestJson.value("offload_to_cpu", false);
  1651. genRequest.diffusionConvDirect = requestJson.value("direct", false);
  1652. // Set input image data
  1653. genRequest.initImageData = imageData;
  1654. genRequest.initImageWidth = imgWidth;
  1655. genRequest.initImageHeight = imgHeight;
  1656. genRequest.initImageChannels = imgChannels;
  1657. // Enqueue request
  1658. auto future = m_generationQueue->enqueueRequest(genRequest);
  1659. json response = {
  1660. {"request_id", requestId},
  1661. {"status", "queued"},
  1662. {"message", "Upscale request queued successfully"},
  1663. {"queue_position", m_generationQueue->getQueueSize()},
  1664. {"type", "upscale"},
  1665. {"parameters", {
  1666. {"esrgan_model", esrganModelId},
  1667. {"upscale_factor", genRequest.upscaleFactor},
  1668. {"input_width", imgWidth},
  1669. {"input_height", imgHeight},
  1670. {"output_width", imgWidth * genRequest.upscaleFactor},
  1671. {"output_height", imgHeight * genRequest.upscaleFactor}
  1672. }}
  1673. };
  1674. sendJsonResponse(res, response, 202);
  1675. } catch (const json::parse_error& e) {
  1676. sendErrorResponse(res, std::string("Invalid JSON: ") + e.what(), 400, "JSON_PARSE_ERROR", requestId);
  1677. } catch (const std::exception& e) {
  1678. sendErrorResponse(res, std::string("Upscale request failed: ") + e.what(), 500, "INTERNAL_ERROR", requestId);
  1679. }
  1680. }
  1681. // Utility endpoints
  1682. void Server::handleSamplers(const httplib::Request& req, httplib::Response& res) {
  1683. try {
  1684. json samplers = {
  1685. {"samplers", {
  1686. {
  1687. {"name", "euler"},
  1688. {"description", "Euler sampler - fast and simple"},
  1689. {"recommended_steps", 20}
  1690. },
  1691. {
  1692. {"name", "euler_a"},
  1693. {"description", "Euler ancestral sampler - adds randomness"},
  1694. {"recommended_steps", 20}
  1695. },
  1696. {
  1697. {"name", "heun"},
  1698. {"description", "Heun sampler - more accurate but slower"},
  1699. {"recommended_steps", 20}
  1700. },
  1701. {
  1702. {"name", "dpm2"},
  1703. {"description", "DPM2 sampler - second-order DPM"},
  1704. {"recommended_steps", 20}
  1705. },
  1706. {
  1707. {"name", "dpm++2s_a"},
  1708. {"description", "DPM++ 2s ancestral sampler"},
  1709. {"recommended_steps", 20}
  1710. },
  1711. {
  1712. {"name", "dpm++2m"},
  1713. {"description", "DPM++ 2m sampler - multistep"},
  1714. {"recommended_steps", 20}
  1715. },
  1716. {
  1717. {"name", "dpm++2mv2"},
  1718. {"description", "DPM++ 2m v2 sampler - improved multistep"},
  1719. {"recommended_steps", 20}
  1720. },
  1721. {
  1722. {"name", "ipndm"},
  1723. {"description", "IPNDM sampler - improved noise prediction"},
  1724. {"recommended_steps", 20}
  1725. },
  1726. {
  1727. {"name", "ipndm_v"},
  1728. {"description", "IPNDM v sampler - variant of IPNDM"},
  1729. {"recommended_steps", 20}
  1730. },
  1731. {
  1732. {"name", "lcm"},
  1733. {"description", "LCM sampler - Latent Consistency Model, very fast"},
  1734. {"recommended_steps", 4}
  1735. },
  1736. {
  1737. {"name", "ddim_trailing"},
  1738. {"description", "DDIM trailing sampler - deterministic"},
  1739. {"recommended_steps", 20}
  1740. },
  1741. {
  1742. {"name", "tcd"},
  1743. {"description", "TCD sampler - Trajectory Consistency Distillation"},
  1744. {"recommended_steps", 8}
  1745. },
  1746. {
  1747. {"name", "default"},
  1748. {"description", "Use model's default sampler"},
  1749. {"recommended_steps", 20}
  1750. }
  1751. }}
  1752. };
  1753. sendJsonResponse(res, samplers);
  1754. } catch (const std::exception& e) {
  1755. sendErrorResponse(res, std::string("Failed to get samplers: ") + e.what(), 500);
  1756. }
  1757. }
  1758. void Server::handleSchedulers(const httplib::Request& req, httplib::Response& res) {
  1759. try {
  1760. json schedulers = {
  1761. {"schedulers", {
  1762. {
  1763. {"name", "discrete"},
  1764. {"description", "Discrete scheduler - standard noise schedule"}
  1765. },
  1766. {
  1767. {"name", "karras"},
  1768. {"description", "Karras scheduler - improved noise schedule"}
  1769. },
  1770. {
  1771. {"name", "exponential"},
  1772. {"description", "Exponential scheduler - exponential noise decay"}
  1773. },
  1774. {
  1775. {"name", "ays"},
  1776. {"description", "AYS scheduler - Adaptive Your Scheduler"}
  1777. },
  1778. {
  1779. {"name", "gits"},
  1780. {"description", "GITS scheduler - Generalized Iterative Time Steps"}
  1781. },
  1782. {
  1783. {"name", "smoothstep"},
  1784. {"description", "Smoothstep scheduler - smooth transition function"}
  1785. },
  1786. {
  1787. {"name", "sgm_uniform"},
  1788. {"description", "SGM uniform scheduler - uniform noise schedule"}
  1789. },
  1790. {
  1791. {"name", "simple"},
  1792. {"description", "Simple scheduler - basic linear schedule"}
  1793. },
  1794. {
  1795. {"name", "default"},
  1796. {"description", "Use model's default scheduler"}
  1797. }
  1798. }}
  1799. };
  1800. sendJsonResponse(res, schedulers);
  1801. } catch (const std::exception& e) {
  1802. sendErrorResponse(res, std::string("Failed to get schedulers: ") + e.what(), 500);
  1803. }
  1804. }
  1805. void Server::handleParameters(const httplib::Request& req, httplib::Response& res) {
  1806. try {
  1807. json parameters = {
  1808. {"parameters", {
  1809. {
  1810. {"name", "prompt"},
  1811. {"type", "string"},
  1812. {"required", true},
  1813. {"description", "Text prompt for image generation"},
  1814. {"min_length", 1},
  1815. {"max_length", 10000},
  1816. {"example", "a beautiful landscape with mountains"}
  1817. },
  1818. {
  1819. {"name", "negative_prompt"},
  1820. {"type", "string"},
  1821. {"required", false},
  1822. {"description", "Negative prompt to guide generation away from"},
  1823. {"min_length", 0},
  1824. {"max_length", 10000},
  1825. {"example", "blurry, low quality, distorted"}
  1826. },
  1827. {
  1828. {"name", "width"},
  1829. {"type", "integer"},
  1830. {"required", false},
  1831. {"description", "Image width in pixels"},
  1832. {"min", 64},
  1833. {"max", 2048},
  1834. {"multiple_of", 64},
  1835. {"default", 512}
  1836. },
  1837. {
  1838. {"name", "height"},
  1839. {"type", "integer"},
  1840. {"required", false},
  1841. {"description", "Image height in pixels"},
  1842. {"min", 64},
  1843. {"max", 2048},
  1844. {"multiple_of", 64},
  1845. {"default", 512}
  1846. },
  1847. {
  1848. {"name", "steps"},
  1849. {"type", "integer"},
  1850. {"required", false},
  1851. {"description", "Number of diffusion steps"},
  1852. {"min", 1},
  1853. {"max", 150},
  1854. {"default", 20}
  1855. },
  1856. {
  1857. {"name", "cfg_scale"},
  1858. {"type", "number"},
  1859. {"required", false},
  1860. {"description", "Classifier-Free Guidance scale"},
  1861. {"min", 1.0},
  1862. {"max", 30.0},
  1863. {"default", 7.5}
  1864. },
  1865. {
  1866. {"name", "seed"},
  1867. {"type", "string|integer"},
  1868. {"required", false},
  1869. {"description", "Seed for generation (use 'random' for random seed)"},
  1870. {"example", "42"}
  1871. },
  1872. {
  1873. {"name", "sampling_method"},
  1874. {"type", "string"},
  1875. {"required", false},
  1876. {"description", "Sampling method to use"},
  1877. {"enum", {"euler", "euler_a", "heun", "dpm2", "dpm++2s_a", "dpm++2m", "dpm++2mv2", "ipndm", "ipndm_v", "lcm", "ddim_trailing", "tcd", "default"}},
  1878. {"default", "default"}
  1879. },
  1880. {
  1881. {"name", "scheduler"},
  1882. {"type", "string"},
  1883. {"required", false},
  1884. {"description", "Scheduler to use"},
  1885. {"enum", {"discrete", "karras", "exponential", "ays", "gits", "smoothstep", "sgm_uniform", "simple", "default"}},
  1886. {"default", "default"}
  1887. },
  1888. {
  1889. {"name", "batch_count"},
  1890. {"type", "integer"},
  1891. {"required", false},
  1892. {"description", "Number of images to generate"},
  1893. {"min", 1},
  1894. {"max", 100},
  1895. {"default", 1}
  1896. },
  1897. {
  1898. {"name", "strength"},
  1899. {"type", "number"},
  1900. {"required", false},
  1901. {"description", "Strength for img2img (0.0-1.0)"},
  1902. {"min", 0.0},
  1903. {"max", 1.0},
  1904. {"default", 0.75}
  1905. },
  1906. {
  1907. {"name", "control_strength"},
  1908. {"type", "number"},
  1909. {"required", false},
  1910. {"description", "ControlNet strength (0.0-1.0)"},
  1911. {"min", 0.0},
  1912. {"max", 1.0},
  1913. {"default", 0.9}
  1914. }
  1915. }},
  1916. {"openapi", {
  1917. {"version", "3.0.0"},
  1918. {"info", {
  1919. {"title", "Stable Diffusion REST API"},
  1920. {"version", "1.0.0"},
  1921. {"description", "Comprehensive REST API for stable-diffusion.cpp functionality"}
  1922. }},
  1923. {"components", {
  1924. {"schemas", {
  1925. {"GenerationRequest", {
  1926. {"type", "object"},
  1927. {"required", {"prompt"}},
  1928. {"properties", {
  1929. {"prompt", {{"type", "string"}, {"description", "Text prompt for generation"}}},
  1930. {"negative_prompt", {{"type", "string"}, {"description", "Negative prompt"}}},
  1931. {"width", {{"type", "integer"}, {"minimum", 64}, {"maximum", 2048}, {"default", 512}}},
  1932. {"height", {{"type", "integer"}, {"minimum", 64}, {"maximum", 2048}, {"default", 512}}},
  1933. {"steps", {{"type", "integer"}, {"minimum", 1}, {"maximum", 150}, {"default", 20}}},
  1934. {"cfg_scale", {{"type", "number"}, {"minimum", 1.0}, {"maximum", 30.0}, {"default", 7.5}}}
  1935. }}
  1936. }}
  1937. }}
  1938. }}
  1939. }}
  1940. };
  1941. sendJsonResponse(res, parameters);
  1942. } catch (const std::exception& e) {
  1943. sendErrorResponse(res, std::string("Failed to get parameters: ") + e.what(), 500);
  1944. }
  1945. }
  1946. void Server::handleValidate(const httplib::Request& req, httplib::Response& res) {
  1947. std::string requestId = generateRequestId();
  1948. try {
  1949. json requestJson = json::parse(req.body);
  1950. // Validate parameters
  1951. auto [isValid, errorMessage] = validateGenerationParameters(requestJson);
  1952. json response = {
  1953. {"request_id", requestId},
  1954. {"valid", isValid},
  1955. {"message", isValid ? "Parameters are valid" : errorMessage},
  1956. {"errors", isValid ? json::array() : json::array({errorMessage})}
  1957. };
  1958. sendJsonResponse(res, response, isValid ? 200 : 400);
  1959. } catch (const json::parse_error& e) {
  1960. sendErrorResponse(res, std::string("Invalid JSON: ") + e.what(), 400, "JSON_PARSE_ERROR", requestId);
  1961. } catch (const std::exception& e) {
  1962. sendErrorResponse(res, std::string("Validation failed: ") + e.what(), 500, "INTERNAL_ERROR", requestId);
  1963. }
  1964. }
  1965. void Server::handleEstimate(const httplib::Request& req, httplib::Response& res) {
  1966. std::string requestId = generateRequestId();
  1967. try {
  1968. json requestJson = json::parse(req.body);
  1969. // Validate parameters first
  1970. auto [isValid, errorMessage] = validateGenerationParameters(requestJson);
  1971. if (!isValid) {
  1972. sendErrorResponse(res, errorMessage, 400, "INVALID_PARAMETERS", requestId);
  1973. return;
  1974. }
  1975. // Create a temporary request to estimate
  1976. GenerationRequest genRequest;
  1977. genRequest.prompt = requestJson["prompt"];
  1978. genRequest.width = requestJson.value("width", 512);
  1979. genRequest.height = requestJson.value("height", 512);
  1980. genRequest.batchCount = requestJson.value("batch_count", 1);
  1981. genRequest.steps = requestJson.value("steps", 20);
  1982. genRequest.diffusionFlashAttn = requestJson.value("diffusion_flash_attn", false);
  1983. genRequest.controlNetPath = requestJson.value("control_net_path", "");
  1984. if (requestJson.contains("sampling_method")) {
  1985. genRequest.samplingMethod = parseSamplingMethod(requestJson["sampling_method"]);
  1986. }
  1987. // Calculate estimates
  1988. uint64_t estimatedTime = estimateGenerationTime(genRequest);
  1989. size_t estimatedMemory = estimateMemoryUsage(genRequest);
  1990. json response = {
  1991. {"request_id", requestId},
  1992. {"estimated_time_seconds", estimatedTime / 1000},
  1993. {"estimated_memory_mb", estimatedMemory / (1024 * 1024)},
  1994. {"parameters", {
  1995. {"resolution", std::to_string(genRequest.width) + "x" + std::to_string(genRequest.height)},
  1996. {"steps", genRequest.steps},
  1997. {"batch_count", genRequest.batchCount},
  1998. {"sampling_method", samplingMethodToString(genRequest.samplingMethod)}
  1999. }}
  2000. };
  2001. sendJsonResponse(res, response);
  2002. } catch (const json::parse_error& e) {
  2003. sendErrorResponse(res, std::string("Invalid JSON: ") + e.what(), 400, "JSON_PARSE_ERROR", requestId);
  2004. } catch (const std::exception& e) {
  2005. sendErrorResponse(res, std::string("Estimation failed: ") + e.what(), 500, "INTERNAL_ERROR", requestId);
  2006. }
  2007. }
  2008. void Server::handleConfig(const httplib::Request& req, httplib::Response& res) {
  2009. std::string requestId = generateRequestId();
  2010. try {
  2011. // Get current configuration
  2012. json config = {
  2013. {"request_id", requestId},
  2014. {"config", {
  2015. {"server", {
  2016. {"host", m_host},
  2017. {"port", m_port},
  2018. {"max_concurrent_generations", 1}
  2019. }},
  2020. {"generation", {
  2021. {"default_width", 512},
  2022. {"default_height", 512},
  2023. {"default_steps", 20},
  2024. {"default_cfg_scale", 7.5},
  2025. {"max_batch_count", 100},
  2026. {"max_steps", 150},
  2027. {"max_resolution", 2048}
  2028. }},
  2029. {"rate_limiting", {
  2030. {"requests_per_minute", 60},
  2031. {"enabled", true}
  2032. }}
  2033. }}
  2034. };
  2035. sendJsonResponse(res, config);
  2036. } catch (const std::exception& e) {
  2037. sendErrorResponse(res, std::string("Config operation failed: ") + e.what(), 500, "INTERNAL_ERROR", requestId);
  2038. }
  2039. }
  2040. void Server::handleSystem(const httplib::Request& req, httplib::Response& res) {
  2041. try {
  2042. json system = {
  2043. {"system", {
  2044. {"version", "1.0.0"},
  2045. {"build", "stable-diffusion.cpp-rest"},
  2046. {"uptime", std::chrono::duration_cast<std::chrono::seconds>(
  2047. std::chrono::steady_clock::now().time_since_epoch()).count()},
  2048. {"capabilities", {
  2049. {"text2img", true},
  2050. {"img2img", true},
  2051. {"controlnet", true},
  2052. {"batch_generation", true},
  2053. {"parameter_validation", true},
  2054. {"estimation", true}
  2055. }},
  2056. {"supported_formats", {
  2057. {"input", {"png", "jpg", "jpeg", "webp"}},
  2058. {"output", {"png", "jpg", "jpeg", "webp"}}
  2059. }},
  2060. {"limits", {
  2061. {"max_resolution", 2048},
  2062. {"max_steps", 150},
  2063. {"max_batch_count", 100},
  2064. {"max_prompt_length", 10000}
  2065. }}
  2066. }},
  2067. {"hardware", {
  2068. {"cpu_threads", std::thread::hardware_concurrency()}
  2069. }}
  2070. };
  2071. sendJsonResponse(res, system);
  2072. } catch (const std::exception& e) {
  2073. sendErrorResponse(res, std::string("System info failed: ") + e.what(), 500);
  2074. }
  2075. }
  2076. void Server::handleSystemRestart(const httplib::Request& req, httplib::Response& res) {
  2077. try {
  2078. json response = {
  2079. {"message", "Server restart initiated. The server will shut down gracefully and exit. Please use a process manager to automatically restart it."},
  2080. {"status", "restarting"}
  2081. };
  2082. sendJsonResponse(res, response);
  2083. // Schedule server stop after response is sent
  2084. // Using a separate thread to allow the response to be sent first
  2085. std::thread([this]() {
  2086. std::this_thread::sleep_for(std::chrono::seconds(1));
  2087. this->stop();
  2088. // Exit with code 42 to signal restart intent to process manager
  2089. std::exit(42);
  2090. }).detach();
  2091. } catch (const std::exception& e) {
  2092. sendErrorResponse(res, std::string("Restart failed: ") + e.what(), 500);
  2093. }
  2094. }
  2095. // Helper methods for model management
  2096. json Server::getModelCapabilities(ModelType type) {
  2097. json capabilities = json::object();
  2098. switch (type) {
  2099. case ModelType::CHECKPOINT:
  2100. capabilities = {
  2101. {"text2img", true},
  2102. {"img2img", true},
  2103. {"inpainting", true},
  2104. {"outpainting", true},
  2105. {"controlnet", true},
  2106. {"lora", true},
  2107. {"vae", true},
  2108. {"sampling_methods", {"euler", "euler_a", "heun", "dpm2", "dpm++2s_a", "dpm++2m", "dpm++2mv2", "ipndm", "ipndm_v", "lcm", "ddim_trailing", "tcd"}},
  2109. {"schedulers", {"discrete", "karras", "exponential", "ays", "gits", "smoothstep", "sgm_uniform", "simple"}},
  2110. {"recommended_resolution", "512x512"},
  2111. {"max_resolution", "2048x2048"},
  2112. {"supports_batch", true}
  2113. };
  2114. break;
  2115. case ModelType::LORA:
  2116. capabilities = {
  2117. {"text2img", true},
  2118. {"img2img", true},
  2119. {"inpainting", true},
  2120. {"controlnet", false},
  2121. {"lora", true},
  2122. {"vae", false},
  2123. {"requires_checkpoint", true},
  2124. {"strength_range", {0.0, 2.0}},
  2125. {"recommended_strength", 1.0}
  2126. };
  2127. break;
  2128. case ModelType::CONTROLNET:
  2129. capabilities = {
  2130. {"text2img", false},
  2131. {"img2img", true},
  2132. {"inpainting", true},
  2133. {"controlnet", true},
  2134. {"requires_checkpoint", true},
  2135. {"control_modes", {"canny", "depth", "pose", "scribble", "hed", "mlsd", "normal", "seg"}},
  2136. {"strength_range", {0.0, 1.0}},
  2137. {"recommended_strength", 0.9}
  2138. };
  2139. break;
  2140. case ModelType::VAE:
  2141. capabilities = {
  2142. {"text2img", false},
  2143. {"img2img", false},
  2144. {"inpainting", false},
  2145. {"vae", true},
  2146. {"requires_checkpoint", true},
  2147. {"encoding", true},
  2148. {"decoding", true},
  2149. {"precision", {"fp16", "fp32"}}
  2150. };
  2151. break;
  2152. case ModelType::EMBEDDING:
  2153. capabilities = {
  2154. {"text2img", true},
  2155. {"img2img", true},
  2156. {"inpainting", true},
  2157. {"embedding", true},
  2158. {"requires_checkpoint", true},
  2159. {"token_count", 1},
  2160. {"compatible_with", {"checkpoint", "lora"}}
  2161. };
  2162. break;
  2163. case ModelType::TAESD:
  2164. capabilities = {
  2165. {"text2img", false},
  2166. {"img2img", false},
  2167. {"inpainting", false},
  2168. {"vae", true},
  2169. {"requires_checkpoint", true},
  2170. {"fast_decoding", true},
  2171. {"real_time", true},
  2172. {"precision", {"fp16", "fp32"}}
  2173. };
  2174. break;
  2175. case ModelType::ESRGAN:
  2176. capabilities = {
  2177. {"text2img", false},
  2178. {"img2img", false},
  2179. {"inpainting", false},
  2180. {"upscaling", true},
  2181. {"scale_factors", {2, 4}},
  2182. {"models", {"ESRGAN", "RealESRGAN", "SwinIR"}},
  2183. {"supports_alpha", false}
  2184. };
  2185. break;
  2186. default:
  2187. capabilities = {
  2188. {"text2img", false},
  2189. {"img2img", false},
  2190. {"inpainting", false},
  2191. {"capabilities", {}}
  2192. };
  2193. break;
  2194. }
  2195. return capabilities;
  2196. }
  2197. json Server::getModelTypeStatistics() {
  2198. if (!m_modelManager) return json::object();
  2199. json stats = json::object();
  2200. auto allModels = m_modelManager->getAllModels();
  2201. // Initialize counters for each type
  2202. std::map<ModelType, int> typeCounts;
  2203. std::map<ModelType, int> loadedCounts;
  2204. std::map<ModelType, size_t> sizeByType;
  2205. for (const auto& pair : allModels) {
  2206. ModelType type = pair.second.type;
  2207. typeCounts[type]++;
  2208. if (pair.second.isLoaded) {
  2209. loadedCounts[type]++;
  2210. }
  2211. sizeByType[type] += pair.second.fileSize;
  2212. }
  2213. // Build statistics JSON
  2214. for (const auto& count : typeCounts) {
  2215. std::string typeName = ModelManager::modelTypeToString(count.first);
  2216. stats[typeName] = {
  2217. {"total_count", count.second},
  2218. {"loaded_count", loadedCounts[count.first]},
  2219. {"total_size_bytes", sizeByType[count.first]},
  2220. {"total_size_mb", sizeByType[count.first] / (1024.0 * 1024.0)},
  2221. {"average_size_mb", count.second > 0 ? (sizeByType[count.first] / (1024.0 * 1024.0)) / count.second : 0.0}
  2222. };
  2223. }
  2224. return stats;
  2225. }
  2226. // Additional helper methods for model management
  2227. json Server::getModelCompatibility(const ModelManager::ModelInfo& modelInfo) {
  2228. json compatibility = {
  2229. {"is_compatible", true},
  2230. {"compatibility_score", 100},
  2231. {"issues", json::array()},
  2232. {"warnings", json::array()},
  2233. {"requirements", {
  2234. {"min_memory_mb", 1024},
  2235. {"recommended_memory_mb", 2048},
  2236. {"supported_formats", {"safetensors", "ckpt", "gguf"}},
  2237. {"required_dependencies", {}}
  2238. }}
  2239. };
  2240. // Check for specific compatibility issues based on model type
  2241. if (modelInfo.type == ModelType::LORA) {
  2242. compatibility["requirements"]["required_dependencies"] = {"checkpoint"};
  2243. } else if (modelInfo.type == ModelType::CONTROLNET) {
  2244. compatibility["requirements"]["required_dependencies"] = {"checkpoint"};
  2245. } else if (modelInfo.type == ModelType::VAE) {
  2246. compatibility["requirements"]["required_dependencies"] = {"checkpoint"};
  2247. }
  2248. return compatibility;
  2249. }
  2250. json Server::getModelRequirements(ModelType type) {
  2251. json requirements = {
  2252. {"min_memory_mb", 1024},
  2253. {"recommended_memory_mb", 2048},
  2254. {"min_disk_space_mb", 1024},
  2255. {"supported_formats", {"safetensors", "ckpt", "gguf"}},
  2256. {"required_dependencies", json::array()},
  2257. {"optional_dependencies", json::array()},
  2258. {"system_requirements", {
  2259. {"cpu_cores", 4},
  2260. {"cpu_architecture", "x86_64"},
  2261. {"os", "Linux/Windows/macOS"},
  2262. {"gpu_memory_mb", 2048},
  2263. {"gpu_compute_capability", "3.5+"}
  2264. }}
  2265. };
  2266. switch (type) {
  2267. case ModelType::CHECKPOINT:
  2268. requirements["min_memory_mb"] = 2048;
  2269. requirements["recommended_memory_mb"] = 4096;
  2270. requirements["min_disk_space_mb"] = 2048;
  2271. requirements["supported_formats"] = {"safetensors", "ckpt", "gguf"};
  2272. break;
  2273. case ModelType::LORA:
  2274. requirements["min_memory_mb"] = 512;
  2275. requirements["recommended_memory_mb"] = 1024;
  2276. requirements["min_disk_space_mb"] = 100;
  2277. requirements["supported_formats"] = {"safetensors", "ckpt"};
  2278. requirements["required_dependencies"] = {"checkpoint"};
  2279. break;
  2280. case ModelType::CONTROLNET:
  2281. requirements["min_memory_mb"] = 1024;
  2282. requirements["recommended_memory_mb"] = 2048;
  2283. requirements["min_disk_space_mb"] = 500;
  2284. requirements["supported_formats"] = {"safetensors", "pth"};
  2285. requirements["required_dependencies"] = {"checkpoint"};
  2286. break;
  2287. case ModelType::VAE:
  2288. requirements["min_memory_mb"] = 512;
  2289. requirements["recommended_memory_mb"] = 1024;
  2290. requirements["min_disk_space_mb"] = 200;
  2291. requirements["supported_formats"] = {"safetensors", "pt", "ckpt", "gguf"};
  2292. requirements["required_dependencies"] = {"checkpoint"};
  2293. break;
  2294. case ModelType::EMBEDDING:
  2295. requirements["min_memory_mb"] = 64;
  2296. requirements["recommended_memory_mb"] = 256;
  2297. requirements["min_disk_space_mb"] = 10;
  2298. requirements["supported_formats"] = {"safetensors", "pt"};
  2299. requirements["required_dependencies"] = {"checkpoint"};
  2300. break;
  2301. case ModelType::TAESD:
  2302. requirements["min_memory_mb"] = 256;
  2303. requirements["recommended_memory_mb"] = 512;
  2304. requirements["min_disk_space_mb"] = 100;
  2305. requirements["supported_formats"] = {"safetensors", "pth", "gguf"};
  2306. requirements["required_dependencies"] = {"checkpoint"};
  2307. break;
  2308. case ModelType::ESRGAN:
  2309. requirements["min_memory_mb"] = 1024;
  2310. requirements["recommended_memory_mb"] = 2048;
  2311. requirements["min_disk_space_mb"] = 500;
  2312. requirements["supported_formats"] = {"pth", "pt"};
  2313. requirements["optional_dependencies"] = {"checkpoint"};
  2314. break;
  2315. default:
  2316. break;
  2317. }
  2318. return requirements;
  2319. }
  2320. json Server::getRecommendedUsage(ModelType type) {
  2321. json usage = {
  2322. {"text2img", false},
  2323. {"img2img", false},
  2324. {"inpainting", false},
  2325. {"controlnet", false},
  2326. {"lora", false},
  2327. {"vae", false},
  2328. {"recommended_resolution", "512x512"},
  2329. {"recommended_steps", 20},
  2330. {"recommended_cfg_scale", 7.5},
  2331. {"recommended_batch_size", 1}
  2332. };
  2333. switch (type) {
  2334. case ModelType::CHECKPOINT:
  2335. usage = {
  2336. {"text2img", true},
  2337. {"img2img", true},
  2338. {"inpainting", true},
  2339. {"controlnet", true},
  2340. {"lora", true},
  2341. {"vae", true},
  2342. {"recommended_resolution", "512x512"},
  2343. {"recommended_steps", 20},
  2344. {"recommended_cfg_scale", 7.5},
  2345. {"recommended_batch_size", 1}
  2346. };
  2347. break;
  2348. case ModelType::LORA:
  2349. usage = {
  2350. {"text2img", true},
  2351. {"img2img", true},
  2352. {"inpainting", true},
  2353. {"controlnet", false},
  2354. {"lora", true},
  2355. {"vae", false},
  2356. {"recommended_strength", 1.0},
  2357. {"recommended_usage", "Style transfer, character customization"}
  2358. };
  2359. break;
  2360. case ModelType::CONTROLNET:
  2361. usage = {
  2362. {"text2img", false},
  2363. {"img2img", true},
  2364. {"inpainting", true},
  2365. {"controlnet", true},
  2366. {"lora", false},
  2367. {"vae", false},
  2368. {"recommended_strength", 0.9},
  2369. {"recommended_usage", "Precise control over output"}
  2370. };
  2371. break;
  2372. case ModelType::VAE:
  2373. usage = {
  2374. {"text2img", false},
  2375. {"img2img", false},
  2376. {"inpainting", false},
  2377. {"controlnet", false},
  2378. {"lora", false},
  2379. {"vae", true},
  2380. {"recommended_usage", "Improved encoding/decoding quality"}
  2381. };
  2382. break;
  2383. case ModelType::EMBEDDING:
  2384. usage = {
  2385. {"text2img", true},
  2386. {"img2img", true},
  2387. {"inpainting", true},
  2388. {"controlnet", false},
  2389. {"lora", false},
  2390. {"vae", false},
  2391. {"embedding", true},
  2392. {"recommended_usage", "Concept control, style words"}
  2393. };
  2394. break;
  2395. case ModelType::TAESD:
  2396. usage = {
  2397. {"text2img", false},
  2398. {"img2img", false},
  2399. {"inpainting", false},
  2400. {"controlnet", false},
  2401. {"lora", false},
  2402. {"vae", true},
  2403. {"recommended_usage", "Real-time decoding"}
  2404. };
  2405. break;
  2406. case ModelType::ESRGAN:
  2407. usage = {
  2408. {"text2img", false},
  2409. {"img2img", false},
  2410. {"inpainting", false},
  2411. {"controlnet", false},
  2412. {"lora", false},
  2413. {"vae", false},
  2414. {"upscaling", true},
  2415. {"recommended_usage", "Image upscaling and quality enhancement"}
  2416. };
  2417. break;
  2418. default:
  2419. break;
  2420. }
  2421. return usage;
  2422. }
  2423. std::string Server::getModelTypeFromDirectoryName(const std::string& dirName) {
  2424. if (dirName == "stable-diffusion" || dirName == "checkpoints") {
  2425. return "checkpoint";
  2426. } else if (dirName == "lora") {
  2427. return "lora";
  2428. } else if (dirName == "controlnet") {
  2429. return "controlnet";
  2430. } else if (dirName == "vae") {
  2431. return "vae";
  2432. } else if (dirName == "taesd") {
  2433. return "taesd";
  2434. } else if (dirName == "esrgan" || dirName == "upscaler") {
  2435. return "esrgan";
  2436. } else if (dirName == "embeddings" || dirName == "textual-inversion") {
  2437. return "embedding";
  2438. } else {
  2439. return "unknown";
  2440. }
  2441. }
  2442. std::string Server::getDirectoryDescription(const std::string& dirName) {
  2443. if (dirName == "stable-diffusion" || dirName == "checkpoints") {
  2444. return "Main stable diffusion model files";
  2445. } else if (dirName == "lora") {
  2446. return "LoRA adapter models for style transfer";
  2447. } else if (dirName == "controlnet") {
  2448. return "ControlNet models for precise control";
  2449. } else if (dirName == "vae") {
  2450. return "VAE models for improved encoding/decoding";
  2451. } else if (dirName == "taesd") {
  2452. return "TAESD models for real-time decoding";
  2453. } else if (dirName == "esrgan" || dirName == "upscaler") {
  2454. return "ESRGAN models for image upscaling";
  2455. } else if (dirName == "embeddings" || dirName == "textual-inversion") {
  2456. return "Text embeddings for concept control";
  2457. } else {
  2458. return "Unknown model directory";
  2459. }
  2460. }
  2461. json Server::getDirectoryContents(const std::string& dirPath) {
  2462. json contents = json::array();
  2463. try {
  2464. if (std::filesystem::exists(dirPath) && std::filesystem::is_directory(dirPath)) {
  2465. for (const auto& entry : std::filesystem::directory_iterator(dirPath)) {
  2466. if (entry.is_regular_file()) {
  2467. json file = {
  2468. {"name", entry.path().filename().string()},
  2469. {"path", entry.path().string()},
  2470. {"size", std::filesystem::file_size(entry.path())},
  2471. {"size_mb", std::filesystem::file_size(entry.path()) / (1024.0 * 1024.0)},
  2472. {"last_modified", std::chrono::duration_cast<std::chrono::seconds>(
  2473. std::filesystem::last_write_time(entry.path()).time_since_epoch()).count()}
  2474. };
  2475. contents.push_back(file);
  2476. }
  2477. }
  2478. }
  2479. } catch (const std::exception& e) {
  2480. // Return empty array if directory access fails
  2481. }
  2482. return contents;
  2483. }
  2484. json Server::getLargestModel(const std::map<std::string, ModelManager::ModelInfo>& allModels) {
  2485. json largest = json::object();
  2486. size_t maxSize = 0;
  2487. std::string largestName;
  2488. for (const auto& pair : allModels) {
  2489. if (pair.second.fileSize > maxSize) {
  2490. maxSize = pair.second.fileSize;
  2491. largestName = pair.second.name;
  2492. }
  2493. }
  2494. if (!largestName.empty()) {
  2495. largest = {
  2496. {"name", largestName},
  2497. {"size", maxSize},
  2498. {"size_mb", maxSize / (1024.0 * 1024.0)},
  2499. {"type", ModelManager::modelTypeToString(allModels.at(largestName).type)}
  2500. };
  2501. }
  2502. return largest;
  2503. }
  2504. json Server::getSmallestModel(const std::map<std::string, ModelManager::ModelInfo>& allModels) {
  2505. json smallest = json::object();
  2506. size_t minSize = SIZE_MAX;
  2507. std::string smallestName;
  2508. for (const auto& pair : allModels) {
  2509. if (pair.second.fileSize < minSize) {
  2510. minSize = pair.second.fileSize;
  2511. smallestName = pair.second.name;
  2512. }
  2513. }
  2514. if (!smallestName.empty()) {
  2515. smallest = {
  2516. {"name", smallestName},
  2517. {"size", minSize},
  2518. {"size_mb", minSize / (1024.0 * 1024.0)},
  2519. {"type", ModelManager::modelTypeToString(allModels.at(smallestName).type)}
  2520. };
  2521. }
  2522. return smallest;
  2523. }
  2524. json Server::validateModelFile(const std::string& modelPath, const std::string& modelType) {
  2525. json validation = {
  2526. {"is_valid", false},
  2527. {"errors", json::array()},
  2528. {"warnings", json::array()},
  2529. {"file_info", json::object()},
  2530. {"compatibility", json::object()},
  2531. {"recommendations", json::array()}
  2532. };
  2533. try {
  2534. if (!std::filesystem::exists(modelPath)) {
  2535. validation["errors"].push_back("File does not exist");
  2536. return validation;
  2537. }
  2538. if (!std::filesystem::is_regular_file(modelPath)) {
  2539. validation["errors"].push_back("Path is not a regular file");
  2540. return validation;
  2541. }
  2542. // Check file extension
  2543. std::string extension = std::filesystem::path(modelPath).extension().string();
  2544. if (extension.empty()) {
  2545. validation["errors"].push_back("Missing file extension");
  2546. return validation;
  2547. }
  2548. // Remove dot and convert to lowercase
  2549. if (extension[0] == '.') {
  2550. extension = extension.substr(1);
  2551. }
  2552. std::transform(extension.begin(), extension.end(), extension.begin(), ::tolower);
  2553. // Validate extension based on model type
  2554. ModelType type = ModelManager::stringToModelType(modelType);
  2555. bool validExtension = false;
  2556. switch (type) {
  2557. case ModelType::CHECKPOINT:
  2558. validExtension = (extension == "safetensors" || extension == "ckpt" || extension == "gguf");
  2559. break;
  2560. case ModelType::LORA:
  2561. validExtension = (extension == "safetensors" || extension == "ckpt");
  2562. break;
  2563. case ModelType::CONTROLNET:
  2564. validExtension = (extension == "safetensors" || extension == "pth");
  2565. break;
  2566. case ModelType::VAE:
  2567. validExtension = (extension == "safetensors" || extension == "pt" || extension == "ckpt" || extension == "gguf");
  2568. break;
  2569. case ModelType::EMBEDDING:
  2570. validExtension = (extension == "safetensors" || extension == "pt");
  2571. break;
  2572. case ModelType::TAESD:
  2573. validExtension = (extension == "safetensors" || extension == "pth" || extension == "gguf");
  2574. break;
  2575. case ModelType::ESRGAN:
  2576. validExtension = (extension == "pth" || extension == "pt");
  2577. break;
  2578. default:
  2579. break;
  2580. }
  2581. if (!validExtension) {
  2582. validation["errors"].push_back("Invalid file extension for model type: " + extension);
  2583. }
  2584. // Check file size
  2585. size_t fileSize = std::filesystem::file_size(modelPath);
  2586. if (fileSize == 0) {
  2587. validation["errors"].push_back("File is empty");
  2588. } else if (fileSize > 8ULL * 1024 * 1024 * 1024) { // 8GB
  2589. validation["warnings"].push_back("Very large file may cause performance issues");
  2590. }
  2591. // Build file info
  2592. validation["file_info"] = {
  2593. {"path", modelPath},
  2594. {"size", fileSize},
  2595. {"size_mb", fileSize / (1024.0 * 1024.0)},
  2596. {"extension", extension},
  2597. {"last_modified", std::chrono::duration_cast<std::chrono::seconds>(
  2598. std::filesystem::last_write_time(modelPath).time_since_epoch()).count()}
  2599. };
  2600. // Check compatibility
  2601. validation["compatibility"] = {
  2602. {"extension_valid", validExtension},
  2603. {"size_appropriate", fileSize <= 4ULL * 1024 * 1024 * 1024}, // 4GB
  2604. {"recommended_format", "safetensors"}
  2605. };
  2606. // Add recommendations
  2607. if (!validExtension) {
  2608. validation["recommendations"].push_back("Convert to SafeTensors format for better security and performance");
  2609. }
  2610. if (fileSize > 2ULL * 1024 * 1024 * 1024) { // 2GB
  2611. validation["recommendations"].push_back("Consider using a smaller model for better performance");
  2612. }
  2613. // If no errors found, mark as valid
  2614. if (validation["errors"].empty()) {
  2615. validation["is_valid"] = true;
  2616. }
  2617. } catch (const std::exception& e) {
  2618. validation["errors"].push_back("Validation failed: " + std::string(e.what()));
  2619. }
  2620. return validation;
  2621. }
  2622. json Server::checkModelCompatibility(const ModelManager::ModelInfo& modelInfo, const std::string& systemInfo) {
  2623. json compatibility = {
  2624. {"is_compatible", true},
  2625. {"compatibility_score", 100},
  2626. {"issues", json::array()},
  2627. {"warnings", json::array()},
  2628. {"requirements", json::object()},
  2629. {"recommendations", json::array()},
  2630. {"system_info", json::object()}
  2631. };
  2632. // Check system compatibility
  2633. if (systemInfo == "auto") {
  2634. compatibility["system_info"] = {
  2635. {"cpu_cores", std::thread::hardware_concurrency()}
  2636. };
  2637. }
  2638. // Check model-specific compatibility issues
  2639. if (modelInfo.type == ModelType::CHECKPOINT) {
  2640. if (modelInfo.fileSize > 4ULL * 1024 * 1024 * 1024) { // 4GB
  2641. compatibility["warnings"].push_back("Large checkpoint model may require significant memory");
  2642. compatibility["compatibility_score"] = 80;
  2643. }
  2644. if (modelInfo.fileSize < 500 * 1024 * 1024) { // 500MB
  2645. compatibility["warnings"].push_back("Small checkpoint model may have limited capabilities");
  2646. compatibility["compatibility_score"] = 85;
  2647. }
  2648. } else if (modelInfo.type == ModelType::LORA) {
  2649. if (modelInfo.fileSize > 500 * 1024 * 1024) { // 500MB
  2650. compatibility["warnings"].push_back("Large LoRA may impact performance");
  2651. compatibility["compatibility_score"] = 75;
  2652. }
  2653. }
  2654. return compatibility;
  2655. }
  2656. json Server::calculateSpecificRequirements(const std::string& modelType, const std::string& resolution, const std::string& batchSize) {
  2657. json specific = {
  2658. {"memory_requirements", json::object()},
  2659. {"performance_impact", json::object()},
  2660. {"quality_expectations", json::object()}
  2661. };
  2662. // Parse resolution
  2663. int width = 512, height = 512;
  2664. try {
  2665. size_t xPos = resolution.find('x');
  2666. if (xPos != std::string::npos) {
  2667. width = std::stoi(resolution.substr(0, xPos));
  2668. height = std::stoi(resolution.substr(xPos + 1));
  2669. }
  2670. } catch (...) {
  2671. // Use defaults if parsing fails
  2672. }
  2673. // Parse batch size
  2674. int batch = 1;
  2675. try {
  2676. batch = std::stoi(batchSize);
  2677. } catch (...) {
  2678. // Use default if parsing fails
  2679. }
  2680. // Calculate memory requirements based on resolution and batch
  2681. size_t pixels = width * height;
  2682. size_t baseMemory = 1024 * 1024 * 1024; // 1GB base
  2683. size_t resolutionMemory = (pixels * 4) / (512 * 512); // Scale based on 512x512
  2684. size_t batchMemory = (batch - 1) * baseMemory * 0.5; // Additional memory for batch
  2685. specific["memory_requirements"] = {
  2686. {"base_memory_mb", baseMemory / (1024 * 1024)},
  2687. {"resolution_memory_mb", resolutionMemory / (1024 * 1024)},
  2688. {"batch_memory_mb", batchMemory / (1024 * 1024)},
  2689. {"total_memory_mb", (baseMemory + resolutionMemory + batchMemory) / (1024 * 1024)}
  2690. };
  2691. // Calculate performance impact
  2692. double performanceFactor = 1.0;
  2693. if (pixels > 512 * 512) {
  2694. performanceFactor = 1.5;
  2695. }
  2696. if (batch > 1) {
  2697. performanceFactor *= 1.2;
  2698. }
  2699. specific["performance_impact"] = {
  2700. {"resolution_factor", pixels > 512 * 512 ? 1.5 : 1.0},
  2701. {"batch_factor", batch > 1 ? 1.2 : 1.0},
  2702. {"overall_factor", performanceFactor}
  2703. };
  2704. return specific;
  2705. }
  2706. // Enhanced model management endpoint implementations
  2707. void Server::handleModelInfo(const httplib::Request& req, httplib::Response& res) {
  2708. std::string requestId = generateRequestId();
  2709. try {
  2710. if (!m_modelManager) {
  2711. sendErrorResponse(res, "Model manager not available", 500, "MODEL_MANAGER_UNAVAILABLE", requestId);
  2712. return;
  2713. }
  2714. // Extract model ID from URL path
  2715. std::string modelId = req.matches[1].str();
  2716. if (modelId.empty()) {
  2717. sendErrorResponse(res, "Missing model ID", 400, "MISSING_MODEL_ID", requestId);
  2718. return;
  2719. }
  2720. // Get model information
  2721. auto modelInfo = m_modelManager->getModelInfo(modelId);
  2722. if (modelInfo.name.empty()) {
  2723. sendErrorResponse(res, "Model not found", 404, "MODEL_NOT_FOUND", requestId);
  2724. return;
  2725. }
  2726. // Build comprehensive model information
  2727. json response = {
  2728. {"model", {
  2729. {"name", modelInfo.name},
  2730. {"path", modelInfo.path},
  2731. {"type", ModelManager::modelTypeToString(modelInfo.type)},
  2732. {"is_loaded", modelInfo.isLoaded},
  2733. {"file_size", modelInfo.fileSize},
  2734. {"file_size_mb", modelInfo.fileSize / (1024.0 * 1024.0)},
  2735. {"description", modelInfo.description},
  2736. {"metadata", modelInfo.metadata},
  2737. {"capabilities", getModelCapabilities(modelInfo.type)},
  2738. {"compatibility", getModelCompatibility(modelInfo)},
  2739. {"requirements", getModelRequirements(modelInfo.type)},
  2740. {"recommended_usage", getRecommendedUsage(modelInfo.type)},
  2741. {"last_modified", std::chrono::duration_cast<std::chrono::seconds>(
  2742. modelInfo.modifiedAt.time_since_epoch()).count()}
  2743. }},
  2744. {"request_id", requestId}
  2745. };
  2746. sendJsonResponse(res, response);
  2747. } catch (const std::exception& e) {
  2748. sendErrorResponse(res, std::string("Failed to get model info: ") + e.what(), 500, "MODEL_INFO_ERROR", requestId);
  2749. }
  2750. }
  2751. void Server::handleLoadModelById(const httplib::Request& req, httplib::Response& res) {
  2752. std::string requestId = generateRequestId();
  2753. try {
  2754. if (!m_modelManager) {
  2755. sendErrorResponse(res, "Model manager not available", 500, "MODEL_MANAGER_UNAVAILABLE", requestId);
  2756. return;
  2757. }
  2758. // Extract model ID from URL path (could be hash or name)
  2759. std::string modelIdentifier = req.matches[1].str();
  2760. if (modelIdentifier.empty()) {
  2761. sendErrorResponse(res, "Missing model identifier", 400, "MISSING_MODEL_ID", requestId);
  2762. return;
  2763. }
  2764. // Try to find by hash first (if it looks like a hash - 10+ hex chars)
  2765. std::string modelId = modelIdentifier;
  2766. if (modelIdentifier.length() >= 10 &&
  2767. std::all_of(modelIdentifier.begin(), modelIdentifier.end(),
  2768. [](char c) { return std::isxdigit(c); })) {
  2769. std::string foundName = m_modelManager->findModelByHash(modelIdentifier);
  2770. if (!foundName.empty()) {
  2771. modelId = foundName;
  2772. std::cout << "Resolved hash " << modelIdentifier << " to model: " << modelId << std::endl;
  2773. }
  2774. }
  2775. // Parse optional parameters from request body
  2776. json requestJson;
  2777. if (!req.body.empty()) {
  2778. try {
  2779. requestJson = json::parse(req.body);
  2780. } catch (const json::parse_error& e) {
  2781. sendErrorResponse(res, std::string("Invalid JSON: ") + e.what(), 400, "JSON_PARSE_ERROR", requestId);
  2782. return;
  2783. }
  2784. }
  2785. // Unload previous model if one is loaded
  2786. std::string previousModel;
  2787. {
  2788. std::lock_guard<std::mutex> lock(m_currentModelMutex);
  2789. previousModel = m_currentlyLoadedModel;
  2790. }
  2791. if (!previousModel.empty() && previousModel != modelId) {
  2792. std::cout << "Unloading previous model: " << previousModel << std::endl;
  2793. m_modelManager->unloadModel(previousModel);
  2794. }
  2795. // Load model
  2796. bool success = m_modelManager->loadModel(modelId);
  2797. if (success) {
  2798. // Update currently loaded model
  2799. {
  2800. std::lock_guard<std::mutex> lock(m_currentModelMutex);
  2801. m_currentlyLoadedModel = modelId;
  2802. }
  2803. auto modelInfo = m_modelManager->getModelInfo(modelId);
  2804. json response = {
  2805. {"status", "success"},
  2806. {"model", {
  2807. {"name", modelInfo.name},
  2808. {"path", modelInfo.path},
  2809. {"type", ModelManager::modelTypeToString(modelInfo.type)},
  2810. {"is_loaded", modelInfo.isLoaded}
  2811. }},
  2812. {"request_id", requestId}
  2813. };
  2814. sendJsonResponse(res, response);
  2815. } else {
  2816. sendErrorResponse(res, "Failed to load model", 400, "MODEL_LOAD_FAILED", requestId);
  2817. }
  2818. } catch (const std::exception& e) {
  2819. sendErrorResponse(res, std::string("Model load failed: ") + e.what(), 500, "MODEL_LOAD_ERROR", requestId);
  2820. }
  2821. }
  2822. void Server::handleUnloadModelById(const httplib::Request& req, httplib::Response& res) {
  2823. std::string requestId = generateRequestId();
  2824. try {
  2825. if (!m_modelManager) {
  2826. sendErrorResponse(res, "Model manager not available", 500, "MODEL_MANAGER_UNAVAILABLE", requestId);
  2827. return;
  2828. }
  2829. // Extract model ID from URL path
  2830. std::string modelId = req.matches[1].str();
  2831. if (modelId.empty()) {
  2832. sendErrorResponse(res, "Missing model ID", 400, "MISSING_MODEL_ID", requestId);
  2833. return;
  2834. }
  2835. // Unload model
  2836. bool success = m_modelManager->unloadModel(modelId);
  2837. if (success) {
  2838. // Clear currently loaded model if it matches
  2839. {
  2840. std::lock_guard<std::mutex> lock(m_currentModelMutex);
  2841. if (m_currentlyLoadedModel == modelId) {
  2842. m_currentlyLoadedModel = "";
  2843. }
  2844. }
  2845. json response = {
  2846. {"status", "success"},
  2847. {"model", {
  2848. {"name", modelId},
  2849. {"is_loaded", false}
  2850. }},
  2851. {"request_id", requestId}
  2852. };
  2853. sendJsonResponse(res, response);
  2854. } else {
  2855. sendErrorResponse(res, "Failed to unload model or model not found", 404, "MODEL_UNLOAD_FAILED", requestId);
  2856. }
  2857. } catch (const std::exception& e) {
  2858. sendErrorResponse(res, std::string("Model unload failed: ") + e.what(), 500, "MODEL_UNLOAD_ERROR", requestId);
  2859. }
  2860. }
  2861. void Server::handleModelTypes(const httplib::Request& req, httplib::Response& res) {
  2862. std::string requestId = generateRequestId();
  2863. try {
  2864. json types = {
  2865. {"model_types", {
  2866. {
  2867. {"type", "checkpoint"},
  2868. {"description", "Main stable diffusion model files for text-to-image, image-to-image, and inpainting"},
  2869. {"extensions", {"safetensors", "ckpt", "gguf"}},
  2870. {"capabilities", {"text2img", "img2img", "inpainting", "controlnet", "lora", "vae"}},
  2871. {"recommended_for", "General purpose image generation"}
  2872. },
  2873. {
  2874. {"type", "lora"},
  2875. {"description", "LoRA adapter models for style transfer and character customization"},
  2876. {"extensions", {"safetensors", "ckpt"}},
  2877. {"capabilities", {"style_transfer", "character_customization"}},
  2878. {"requires", {"checkpoint"}},
  2879. {"recommended_for", "Style modification and character-specific generation"}
  2880. },
  2881. {
  2882. {"type", "controlnet"},
  2883. {"description", "ControlNet models for precise control over output composition"},
  2884. {"extensions", {"safetensors", "pth"}},
  2885. {"capabilities", {"precise_control", "composition_control"}},
  2886. {"requires", {"checkpoint"}},
  2887. {"recommended_for", "Precise control over image generation"}
  2888. },
  2889. {
  2890. {"type", "vae"},
  2891. {"description", "VAE models for improved encoding and decoding quality"},
  2892. {"extensions", {"safetensors", "pt", "ckpt", "gguf"}},
  2893. {"capabilities", {"encoding", "decoding", "quality_improvement"}},
  2894. {"requires", {"checkpoint"}},
  2895. {"recommended_for", "Improved image quality and encoding"}
  2896. },
  2897. {
  2898. {"type", "embedding"},
  2899. {"description", "Text embeddings for concept control and style words"},
  2900. {"extensions", {"safetensors", "pt"}},
  2901. {"capabilities", {"concept_control", "style_words"}},
  2902. {"requires", {"checkpoint"}},
  2903. {"recommended_for", "Concept control and specific styles"}
  2904. },
  2905. {
  2906. {"type", "taesd"},
  2907. {"description", "TAESD models for real-time decoding"},
  2908. {"extensions", {"safetensors", "pth", "gguf"}},
  2909. {"capabilities", {"real_time_decoding", "fast_preview"}},
  2910. {"requires", {"checkpoint"}},
  2911. {"recommended_for", "Real-time applications and fast previews"}
  2912. },
  2913. {
  2914. {"type", "esrgan"},
  2915. {"description", "ESRGAN models for image upscaling and enhancement"},
  2916. {"extensions", {"pth", "pt"}},
  2917. {"capabilities", {"upscaling", "enhancement", "quality_improvement"}},
  2918. {"recommended_for", "Image upscaling and quality enhancement"}
  2919. }
  2920. }},
  2921. {"request_id", requestId}
  2922. };
  2923. sendJsonResponse(res, types);
  2924. } catch (const std::exception& e) {
  2925. sendErrorResponse(res, std::string("Failed to get model types: ") + e.what(), 500, "MODEL_TYPES_ERROR", requestId);
  2926. }
  2927. }
  2928. void Server::handleModelDirectories(const httplib::Request& req, httplib::Response& res) {
  2929. std::string requestId = generateRequestId();
  2930. try {
  2931. if (!m_modelManager) {
  2932. sendErrorResponse(res, "Model manager not available", 500, "MODEL_MANAGER_UNAVAILABLE", requestId);
  2933. return;
  2934. }
  2935. std::string modelsDir = m_modelManager->getModelsDirectory();
  2936. json directories = json::array();
  2937. // Define expected model directories
  2938. std::vector<std::string> modelDirs = {
  2939. "stable-diffusion", "checkpoints", "lora", "controlnet",
  2940. "vae", "taesd", "esrgan", "embeddings"
  2941. };
  2942. for (const auto& dirName : modelDirs) {
  2943. std::string dirPath = modelsDir + "/" + dirName;
  2944. std::string type = getModelTypeFromDirectoryName(dirName);
  2945. std::string description = getDirectoryDescription(dirName);
  2946. json dirInfo = {
  2947. {"name", dirName},
  2948. {"path", dirPath},
  2949. {"type", type},
  2950. {"description", description},
  2951. {"exists", std::filesystem::exists(dirPath) && std::filesystem::is_directory(dirPath)},
  2952. {"contents", getDirectoryContents(dirPath)}
  2953. };
  2954. directories.push_back(dirInfo);
  2955. }
  2956. json response = {
  2957. {"models_directory", modelsDir},
  2958. {"directories", directories},
  2959. {"request_id", requestId}
  2960. };
  2961. sendJsonResponse(res, response);
  2962. } catch (const std::exception& e) {
  2963. sendErrorResponse(res, std::string("Failed to get model directories: ") + e.what(), 500, "MODEL_DIRECTORIES_ERROR", requestId);
  2964. }
  2965. }
  2966. void Server::handleRefreshModels(const httplib::Request& req, httplib::Response& res) {
  2967. std::string requestId = generateRequestId();
  2968. try {
  2969. if (!m_modelManager) {
  2970. sendErrorResponse(res, "Model manager not available", 500, "MODEL_MANAGER_UNAVAILABLE", requestId);
  2971. return;
  2972. }
  2973. // Force refresh of model cache
  2974. bool success = m_modelManager->scanModelsDirectory();
  2975. if (success) {
  2976. json response = {
  2977. {"status", "success"},
  2978. {"message", "Model cache refreshed successfully"},
  2979. {"models_found", m_modelManager->getAvailableModelsCount()},
  2980. {"models_loaded", m_modelManager->getLoadedModelsCount()},
  2981. {"models_directory", m_modelManager->getModelsDirectory()},
  2982. {"request_id", requestId}
  2983. };
  2984. sendJsonResponse(res, response);
  2985. } else {
  2986. sendErrorResponse(res, "Failed to refresh model cache", 500, "MODEL_REFRESH_FAILED", requestId);
  2987. }
  2988. } catch (const std::exception& e) {
  2989. sendErrorResponse(res, std::string("Model refresh failed: ") + e.what(), 500, "MODEL_REFRESH_ERROR", requestId);
  2990. }
  2991. }
  2992. void Server::handleHashModels(const httplib::Request& req, httplib::Response& res) {
  2993. std::string requestId = generateRequestId();
  2994. try {
  2995. if (!m_generationQueue || !m_modelManager) {
  2996. sendErrorResponse(res, "Services not available", 500, "SERVICE_UNAVAILABLE", requestId);
  2997. return;
  2998. }
  2999. // Parse request body
  3000. json requestJson;
  3001. if (!req.body.empty()) {
  3002. requestJson = json::parse(req.body);
  3003. }
  3004. HashRequest hashReq;
  3005. hashReq.id = requestId;
  3006. hashReq.forceRehash = requestJson.value("force_rehash", false);
  3007. if (requestJson.contains("models") && requestJson["models"].is_array()) {
  3008. for (const auto& model : requestJson["models"]) {
  3009. hashReq.modelNames.push_back(model.get<std::string>());
  3010. }
  3011. }
  3012. // Enqueue hash request
  3013. auto future = m_generationQueue->enqueueHashRequest(hashReq);
  3014. json response = {
  3015. {"request_id", requestId},
  3016. {"status", "queued"},
  3017. {"message", "Hash job queued successfully"},
  3018. {"models_to_hash", hashReq.modelNames.empty() ? "all_unhashed" : std::to_string(hashReq.modelNames.size())}
  3019. };
  3020. sendJsonResponse(res, response, 202);
  3021. } catch (const json::parse_error& e) {
  3022. sendErrorResponse(res, std::string("Invalid JSON: ") + e.what(), 400, "JSON_PARSE_ERROR", requestId);
  3023. } catch (const std::exception& e) {
  3024. sendErrorResponse(res, std::string("Hash request failed: ") + e.what(), 500, "INTERNAL_ERROR", requestId);
  3025. }
  3026. }
  3027. void Server::handleConvertModel(const httplib::Request& req, httplib::Response& res) {
  3028. std::string requestId = generateRequestId();
  3029. try {
  3030. if (!m_generationQueue || !m_modelManager) {
  3031. sendErrorResponse(res, "Services not available", 500, "SERVICE_UNAVAILABLE", requestId);
  3032. return;
  3033. }
  3034. // Parse request body
  3035. json requestJson;
  3036. try {
  3037. requestJson = json::parse(req.body);
  3038. } catch (const json::parse_error& e) {
  3039. sendErrorResponse(res, std::string("Invalid JSON: ") + e.what(), 400, "JSON_PARSE_ERROR", requestId);
  3040. return;
  3041. }
  3042. // Validate required fields
  3043. if (!requestJson.contains("model_name")) {
  3044. sendErrorResponse(res, "Missing required field: model_name", 400, "MISSING_FIELD", requestId);
  3045. return;
  3046. }
  3047. if (!requestJson.contains("quantization_type")) {
  3048. sendErrorResponse(res, "Missing required field: quantization_type", 400, "MISSING_FIELD", requestId);
  3049. return;
  3050. }
  3051. std::string modelName = requestJson["model_name"].get<std::string>();
  3052. std::string quantizationType = requestJson["quantization_type"].get<std::string>();
  3053. // Validate quantization type
  3054. const std::vector<std::string> validTypes = {"f32", "f16", "q4_0", "q4_1", "q5_0", "q5_1", "q8_0", "q2_K", "q3_K", "q4_K"};
  3055. if (std::find(validTypes.begin(), validTypes.end(), quantizationType) == validTypes.end()) {
  3056. sendErrorResponse(res, "Invalid quantization_type. Valid types: f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0, q2_K, q3_K, q4_K",
  3057. 400, "INVALID_QUANTIZATION_TYPE", requestId);
  3058. return;
  3059. }
  3060. // Get model info to find the full path
  3061. auto modelInfo = m_modelManager->getModelInfo(modelName);
  3062. if (modelInfo.name.empty()) {
  3063. sendErrorResponse(res, "Model not found: " + modelName, 404, "MODEL_NOT_FOUND", requestId);
  3064. return;
  3065. }
  3066. // Check if model is already GGUF
  3067. if (modelInfo.fullPath.find(".gguf") != std::string::npos) {
  3068. sendErrorResponse(res, "Model is already in GGUF format. Cannot convert GGUF to GGUF.",
  3069. 400, "ALREADY_GGUF", requestId);
  3070. return;
  3071. }
  3072. // Build output path
  3073. std::string outputPath = requestJson.value("output_path", "");
  3074. if (outputPath.empty()) {
  3075. // Generate default output path: model_name_quantization.gguf
  3076. namespace fs = std::filesystem;
  3077. fs::path inputPath(modelInfo.fullPath);
  3078. std::string baseName = inputPath.stem().string();
  3079. std::string outputDir = inputPath.parent_path().string();
  3080. outputPath = outputDir + "/" + baseName + "_" + quantizationType + ".gguf";
  3081. }
  3082. // Create conversion request
  3083. ConversionRequest convReq;
  3084. convReq.id = requestId;
  3085. convReq.modelName = modelName;
  3086. convReq.modelPath = modelInfo.fullPath;
  3087. convReq.outputPath = outputPath;
  3088. convReq.quantizationType = quantizationType;
  3089. // Enqueue conversion request
  3090. auto future = m_generationQueue->enqueueConversionRequest(convReq);
  3091. json response = {
  3092. {"request_id", requestId},
  3093. {"status", "queued"},
  3094. {"message", "Model conversion queued successfully"},
  3095. {"model_name", modelName},
  3096. {"input_path", modelInfo.fullPath},
  3097. {"output_path", outputPath},
  3098. {"quantization_type", quantizationType}
  3099. };
  3100. sendJsonResponse(res, response, 202);
  3101. } catch (const std::exception& e) {
  3102. sendErrorResponse(res, std::string("Conversion request failed: ") + e.what(), 500, "INTERNAL_ERROR", requestId);
  3103. }
  3104. }
  3105. void Server::handleModelStats(const httplib::Request& req, httplib::Response& res) {
  3106. std::string requestId = generateRequestId();
  3107. try {
  3108. if (!m_modelManager) {
  3109. sendErrorResponse(res, "Model manager not available", 500, "MODEL_MANAGER_UNAVAILABLE", requestId);
  3110. return;
  3111. }
  3112. auto allModels = m_modelManager->getAllModels();
  3113. json response = {
  3114. {"statistics", {
  3115. {"total_models", allModels.size()},
  3116. {"loaded_models", m_modelManager->getLoadedModelsCount()},
  3117. {"available_models", m_modelManager->getAvailableModelsCount()},
  3118. {"model_types", getModelTypeStatistics()},
  3119. {"largest_model", getLargestModel(allModels)},
  3120. {"smallest_model", getSmallestModel(allModels)}
  3121. }},
  3122. {"request_id", requestId}
  3123. };
  3124. sendJsonResponse(res, response);
  3125. } catch (const std::exception& e) {
  3126. sendErrorResponse(res, std::string("Failed to get model stats: ") + e.what(), 500, "MODEL_STATS_ERROR", requestId);
  3127. }
  3128. }
  3129. void Server::handleBatchModels(const httplib::Request& req, httplib::Response& res) {
  3130. std::string requestId = generateRequestId();
  3131. try {
  3132. if (!m_modelManager) {
  3133. sendErrorResponse(res, "Model manager not available", 500, "MODEL_MANAGER_UNAVAILABLE", requestId);
  3134. return;
  3135. }
  3136. // Parse JSON request body
  3137. json requestJson = json::parse(req.body);
  3138. if (!requestJson.contains("operation") || !requestJson["operation"].is_string()) {
  3139. sendErrorResponse(res, "Missing or invalid 'operation' field", 400, "INVALID_OPERATION", requestId);
  3140. return;
  3141. }
  3142. if (!requestJson.contains("models") || !requestJson["models"].is_array()) {
  3143. sendErrorResponse(res, "Missing or invalid 'models' field", 400, "INVALID_MODELS", requestId);
  3144. return;
  3145. }
  3146. std::string operation = requestJson["operation"];
  3147. json models = requestJson["models"];
  3148. json results = json::array();
  3149. for (const auto& model : models) {
  3150. if (!model.is_string()) {
  3151. results.push_back({
  3152. {"model", model},
  3153. {"success", false},
  3154. {"error", "Invalid model name"}
  3155. });
  3156. continue;
  3157. }
  3158. std::string modelName = model;
  3159. bool success = false;
  3160. std::string error = "";
  3161. if (operation == "load") {
  3162. success = m_modelManager->loadModel(modelName);
  3163. if (!success) error = "Failed to load model";
  3164. } else if (operation == "unload") {
  3165. success = m_modelManager->unloadModel(modelName);
  3166. if (!success) error = "Failed to unload model";
  3167. } else {
  3168. error = "Unsupported operation";
  3169. }
  3170. results.push_back({
  3171. {"model", modelName},
  3172. {"success", success},
  3173. {"error", error.empty() ? json(nullptr) : json(error)}
  3174. });
  3175. }
  3176. json response = {
  3177. {"operation", operation},
  3178. {"results", results},
  3179. {"successful_count", std::count_if(results.begin(), results.end(),
  3180. [](const json& result) { return result["success"].get<bool>(); })},
  3181. {"failed_count", std::count_if(results.begin(), results.end(),
  3182. [](const json& result) { return !result["success"].get<bool>(); })},
  3183. {"request_id", requestId}
  3184. };
  3185. sendJsonResponse(res, response);
  3186. } catch (const json::parse_error& e) {
  3187. sendErrorResponse(res, std::string("Invalid JSON: ") + e.what(), 400, "JSON_PARSE_ERROR", requestId);
  3188. } catch (const std::exception& e) {
  3189. sendErrorResponse(res, std::string("Batch operation failed: ") + e.what(), 500, "BATCH_OPERATION_ERROR", requestId);
  3190. }
  3191. }
  3192. void Server::handleValidateModel(const httplib::Request& req, httplib::Response& res) {
  3193. std::string requestId = generateRequestId();
  3194. try {
  3195. // Parse JSON request body
  3196. json requestJson = json::parse(req.body);
  3197. if (!requestJson.contains("model_path") || !requestJson["model_path"].is_string()) {
  3198. sendErrorResponse(res, "Missing or invalid 'model_path' field", 400, "INVALID_MODEL_PATH", requestId);
  3199. return;
  3200. }
  3201. std::string modelPath = requestJson["model_path"];
  3202. std::string modelType = requestJson.value("model_type", "checkpoint");
  3203. // Validate model file
  3204. json validation = validateModelFile(modelPath, modelType);
  3205. json response = {
  3206. {"validation", validation},
  3207. {"request_id", requestId}
  3208. };
  3209. sendJsonResponse(res, response);
  3210. } catch (const json::parse_error& e) {
  3211. sendErrorResponse(res, std::string("Invalid JSON: ") + e.what(), 400, "JSON_PARSE_ERROR", requestId);
  3212. } catch (const std::exception& e) {
  3213. sendErrorResponse(res, std::string("Model validation failed: ") + e.what(), 500, "MODEL_VALIDATION_ERROR", requestId);
  3214. }
  3215. }
  3216. void Server::handleCheckCompatibility(const httplib::Request& req, httplib::Response& res) {
  3217. std::string requestId = generateRequestId();
  3218. try {
  3219. if (!m_modelManager) {
  3220. sendErrorResponse(res, "Model manager not available", 500, "MODEL_MANAGER_UNAVAILABLE", requestId);
  3221. return;
  3222. }
  3223. // Parse JSON request body
  3224. json requestJson = json::parse(req.body);
  3225. if (!requestJson.contains("model_name") || !requestJson["model_name"].is_string()) {
  3226. sendErrorResponse(res, "Missing or invalid 'model_name' field", 400, "INVALID_MODEL_NAME", requestId);
  3227. return;
  3228. }
  3229. std::string modelName = requestJson["model_name"];
  3230. std::string systemInfo = requestJson.value("system_info", "auto");
  3231. // Get model information
  3232. auto modelInfo = m_modelManager->getModelInfo(modelName);
  3233. if (modelInfo.name.empty()) {
  3234. sendErrorResponse(res, "Model not found", 404, "MODEL_NOT_FOUND", requestId);
  3235. return;
  3236. }
  3237. // Check compatibility
  3238. json compatibility = checkModelCompatibility(modelInfo, systemInfo);
  3239. json response = {
  3240. {"model", modelName},
  3241. {"compatibility", compatibility},
  3242. {"request_id", requestId}
  3243. };
  3244. sendJsonResponse(res, response);
  3245. } catch (const json::parse_error& e) {
  3246. sendErrorResponse(res, std::string("Invalid JSON: ") + e.what(), 400, "JSON_PARSE_ERROR", requestId);
  3247. } catch (const std::exception& e) {
  3248. sendErrorResponse(res, std::string("Compatibility check failed: ") + e.what(), 500, "COMPATIBILITY_CHECK_ERROR", requestId);
  3249. }
  3250. }
  3251. void Server::handleModelRequirements(const httplib::Request& req, httplib::Response& res) {
  3252. std::string requestId = generateRequestId();
  3253. try {
  3254. // Parse JSON request body
  3255. json requestJson = json::parse(req.body);
  3256. std::string modelType = requestJson.value("model_type", "checkpoint");
  3257. std::string resolution = requestJson.value("resolution", "512x512");
  3258. std::string batchSize = requestJson.value("batch_size", "1");
  3259. // Calculate specific requirements
  3260. json requirements = calculateSpecificRequirements(modelType, resolution, batchSize);
  3261. // Get general requirements for model type
  3262. ModelType type = ModelManager::stringToModelType(modelType);
  3263. json generalRequirements = getModelRequirements(type);
  3264. json response = {
  3265. {"model_type", modelType},
  3266. {"configuration", {
  3267. {"resolution", resolution},
  3268. {"batch_size", batchSize}
  3269. }},
  3270. {"specific_requirements", requirements},
  3271. {"general_requirements", generalRequirements},
  3272. {"request_id", requestId}
  3273. };
  3274. sendJsonResponse(res, response);
  3275. } catch (const json::parse_error& e) {
  3276. sendErrorResponse(res, std::string("Invalid JSON: ") + e.what(), 400, "JSON_PARSE_ERROR", requestId);
  3277. } catch (const std::exception& e) {
  3278. sendErrorResponse(res, std::string("Requirements calculation failed: ") + e.what(), 500, "REQUIREMENTS_ERROR", requestId);
  3279. }
  3280. }
  3281. void Server::serverThreadFunction(const std::string& host, int port) {
  3282. try {
  3283. std::cout << "Server thread starting, attempting to bind to " << host << ":" << port << std::endl;
  3284. // Check if port is available before attempting to bind
  3285. std::cout << "Checking if port " << port << " is available..." << std::endl;
  3286. // Try to create a test socket to check if port is in use
  3287. int test_socket = socket(AF_INET, SOCK_STREAM, 0);
  3288. if (test_socket >= 0) {
  3289. // Set SO_REUSEADDR to avoid TIME_WAIT issues
  3290. int opt = 1;
  3291. setsockopt(test_socket, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt));
  3292. struct sockaddr_in addr;
  3293. addr.sin_family = AF_INET;
  3294. addr.sin_port = htons(port);
  3295. addr.sin_addr.s_addr = INADDR_ANY;
  3296. // Try to bind to the port
  3297. if (bind(test_socket, (struct sockaddr*)&addr, sizeof(addr)) < 0) {
  3298. close(test_socket);
  3299. std::cerr << "ERROR: Port " << port << " is already in use! Cannot start server." << std::endl;
  3300. std::cerr << "Please stop the existing instance or use a different port." << std::endl;
  3301. m_isRunning.store(false);
  3302. m_startupFailed.store(true);
  3303. return;
  3304. }
  3305. close(test_socket);
  3306. }
  3307. std::cout << "Port " << port << " is available, proceeding with server startup..." << std::endl;
  3308. std::cout << "Calling listen()..." << std::endl;
  3309. // The listen() call will block until server is stopped
  3310. // listen() returns true if it successfully binds and starts
  3311. // Once it binds successfully, we set m_isRunning to true via a callback
  3312. // Set up a flag to track if listen started successfully
  3313. std::atomic<bool> listenStarted{false};
  3314. // We need to set m_isRunning after successful bind but before blocking
  3315. // cpp-httplib doesn't provide a callback, so we set it optimistically
  3316. // and clear it if listen() returns false
  3317. m_isRunning.store(true);
  3318. bool listenResult = m_httpServer->listen(host.c_str(), port);
  3319. std::cout << "listen() returned: " << (listenResult ? "true" : "false") << std::endl;
  3320. // If we reach here, server has stopped (either normally or due to error)
  3321. m_isRunning.store(false);
  3322. if (!listenResult) {
  3323. std::cerr << "Server listen failed! This usually means port is in use or permission denied." << std::endl;
  3324. }
  3325. } catch (const std::exception& e) {
  3326. std::cerr << "Exception in server thread: " << e.what() << std::endl;
  3327. m_isRunning.store(false);
  3328. }
  3329. }