api.ts 32 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246
  1. // API client for stable-diffusion REST API
  // Runtime configuration the server injects into the page as
  // `window.__SERVER_CONFIG__`. The property is optional because the page can
  // run without it (getApiConfig falls back to window.location / env vars).
  declare global {
    interface Window {
      __SERVER_CONFIG__?: {
        /** Authentication scheme the backend expects. */
        authMethod: "none" | "unix" | "jwt";
        authEnabled: boolean;
        /** Origin of the API server; may be a bind-all address such as 0.0.0.0. */
        apiUrl: string;
        /** Path prefix prepended to every endpoint. */
        apiBasePath: string;
        host: string;
        port: number;
      };
    }
  }
  /**
   * Fixed-window, per-key request limiter used to keep polling endpoints from
   * being hammered.
   *
   * A key's window opens on its first admitted request and lasts `timeWindow`
   * ms; at most `maxRequests` requests are admitted per window.
   */
  class RequestThrottler {
    private requests: Map<string, { count: number; resetTime: number }> =
      new Map();

    /**
     * @param maxRequests Maximum requests admitted per key per window (default 10).
     * @param timeWindow  Window length in milliseconds (default 1000).
     */
    constructor(
      private maxRequests: number = 10,
      private timeWindow: number = 1000,
    ) {}

    /**
     * Records a request for `key` and reports whether it is admitted.
     * Rejected requests are not counted against the window.
     */
    canMakeRequest(key: string): boolean {
      const now = Date.now();
      const entry = this.requests.get(key);
      if (!entry || now >= entry.resetTime) {
        this.requests.set(key, { count: 1, resetTime: now + this.timeWindow });
        return true;
      }
      if (entry.count >= this.maxRequests) {
        return false;
      }
      entry.count++;
      return true;
    }

    /**
     * Milliseconds to wait before a request for `key` would be admitted.
     *
     * @returns 0 when no window is active, the window has expired, or the
     *   current window still has capacity; otherwise the time until it resets.
     */
    getWaitTime(key: string): number {
      const entry = this.requests.get(key);
      if (!entry) return 0;
      const remaining = entry.resetTime - Date.now();
      if (remaining <= 0) return 0;
      // Only an exhausted window requires waiting. Reporting the full reset time
      // while capacity remains would make callers that sleep on this value stall
      // on every request after the first in each window.
      return entry.count >= this.maxRequests ? remaining : 0;
    }
  }

  // Shared throttler used by ApiClient for polling endpoints.
  const throttler = new RequestThrottler();
  /**
   * Wraps `func` so that bursts of calls collapse into one invocation.
   *
   * - Trailing mode (default): `func` runs once, `wait` ms after the last call
   *   of a burst, with that last call's arguments.
   * - Leading mode (`immediate = true`): `func` runs synchronously on the first
   *   call of a burst; calls arriving less than `wait` ms apart are dropped.
   *
   * @param func      Function to debounce.
   * @param wait      Quiet period in milliseconds.
   * @param immediate Fire on the leading edge instead of the trailing edge.
   */
  function debounce<T extends (...args: any[]) => any>(
    func: T,
    wait: number,
    immediate?: boolean,
  ): (...args: Parameters<T>) => void {
    // ReturnType<typeof setTimeout> is `number` in the browser and `Timeout` in
    // Node, so this compiles without requiring @types/node in a browser build.
    let timeout: ReturnType<typeof setTimeout> | null = null;
    return function executedFunction(...args: Parameters<T>) {
      const later = () => {
        timeout = null;
        if (!immediate) func(...args);
      };
      const callNow = immediate && !timeout;
      if (timeout) clearTimeout(timeout);
      timeout = setTimeout(later, wait);
      if (callNow) func(...args);
    };
  }
  /**
   * In-memory TTL cache for API responses, keyed by string.
   *
   * Expired entries are evicted lazily when read. `get` returns `null` both for
   * a miss and for an entry whose stored value is `null`.
   */
  class ApiCache {
    // Keep this field named `cache`: ApiClient.clearCacheByPrefix reads it
    // reflectively to enumerate keys.
    private cache: Map<string, { data: any; timestamp: number; ttl: number }> =
      new Map();
    private defaultTtl: number = 5000; // 5 seconds default TTL

    /**
     * Stores `data` under `key`.
     *
     * @param ttl Lifetime in milliseconds; defaults to 5000 when omitted. An
     *   explicit 0 is honoured (the entry expires as soon as the clock advances).
     */
    set(key: string, data: any, ttl?: number): void {
      this.cache.set(key, {
        data,
        timestamp: Date.now(),
        // `??` instead of `||` so an explicit ttl of 0 is not replaced by the default.
        ttl: ttl ?? this.defaultTtl,
      });
    }

    /** Returns the cached value, or `null` when absent or expired. */
    get(key: string): any | null {
      const entry = this.cache.get(key);
      if (entry === undefined) return null;
      if (Date.now() - entry.timestamp > entry.ttl) {
        this.cache.delete(key);
        return null;
      }
      return entry.data;
    }

    /** Removes every entry. */
    clear(): void {
      this.cache.clear();
    }

    /** Removes a single entry (no-op when absent). */
    delete(key: string): void {
      this.cache.delete(key);
    }
  }

  // Shared response cache used by ApiClient.
  const cache = new ApiCache();
  /**
   * Resolves the API origin and base path at call time.
   *
   * Resolution order:
   * 1. `window.__SERVER_CONFIG__` when the server injected it. A bind-all
   *    `0.0.0.0` apiUrl is unreachable from a browser, so it is replaced with
   *    the page's own origin.
   * 2. In a browser without injected config (development), the page's origin
   *    with base path "/api".
   * 3. Outside a browser, `NEXT_PUBLIC_API_URL` / `NEXT_PUBLIC_API_BASE_PATH`,
   *    defaulting to "http://localhost:8081" and "/api".
   *
   * @returns `{ apiUrl, apiBase }`, to be concatenated as `${apiUrl}${apiBase}`.
   */
  function getApiConfig() {
    // Page origin; omits the ":" entirely when the page is served on the
    // protocol's default port (window.location.port is "" then).
    const browserOrigin = (): string => {
      const { protocol, hostname, port } = window.location;
      return `${protocol}//${hostname}${port ? ":" + port : ""}`;
    };

    if (typeof window !== "undefined" && window.__SERVER_CONFIG__) {
      let apiUrl = window.__SERVER_CONFIG__.apiUrl;
      if (apiUrl && apiUrl.includes("0.0.0.0")) {
        apiUrl = browserOrigin();
      }
      return {
        apiUrl,
        apiBase: window.__SERVER_CONFIG__.apiBasePath,
      };
    }

    if (typeof window !== "undefined") {
      return {
        apiUrl: browserOrigin(),
        apiBase: "/api",
      };
    }

    // Server-side fallback. Keep the literal `process.env.NEXT_PUBLIC_*` access
    // so build-time inlining of public env vars keeps working.
    return {
      apiUrl: process.env.NEXT_PUBLIC_API_URL || "http://localhost:8081",
      apiBase: process.env.NEXT_PUBLIC_API_BASE_PATH || "/api",
    };
  }
  /** Request body for the generation endpoints (text2img, img2img, inpainting). */
  export interface GenerationRequest {
    /** The only mandatory field. */
    prompt: string;
    negative_prompt?: string;
    model?: string;

    // Output geometry.
    width?: number;
    height?: number;

    // Sampling controls.
    steps?: number;
    cfg_scale?: number;
    sampling_method?: string;
    scheduler?: string;
    seed?: string;
    clip_skip?: number;
    batch_count?: number;

    // Strength knobs; meaning depends on the endpoint — confirm against backend.
    strength?: number;
    control_strength?: number;
  }

  /** A generation job as reported by the queue endpoints. */
  export interface JobInfo {
    // Identifiers: both optional; which one is populated depends on the
    // endpoint — TODO confirm against backend.
    id?: string;
    request_id?: string;

    status:
      | "queued"
      | "pending"
      | "processing"
      | "completed"
      | "failed"
      | "cancelled";
    progress?: number;
    queue_position?: number;
    position?: number;

    prompt?: string;
    message?: string;
    error?: string;
    error_message?: string;

    result?: { images: string[] };
    outputs?: { filename: string; url: string; path: string }[];

    created_at?: string;
    updated_at?: string;
    queued_time?: number;
    start_time?: number;
    end_time?: number;
  }

  /** Wrapper object returned by the job-details endpoint. */
  export interface JobDetailsResponse {
    job: JobInfo;
  }
  /** A model the backend lists as required by another model. */
  export interface RequiredModelInfo {
    type: string;
    name?: string;
    description?: string;
    optional?: boolean;
    priority?: number;
  }

  /** A companion model the backend recommends, with an optional reason. */
  export interface RecommendedModelInfo {
    type: string;
    name?: string;
    description?: string;
    reason?: string;
  }

  /** A model file known to the backend. */
  export interface ModelInfo {
    name: string;
    type: string;
    id?: string;
    path?: string;
    architecture?: string;
    loaded?: boolean;

    // Size and hash fields as reported by the backend.
    size?: number;
    file_size?: number;
    file_size_mb?: number;
    sha256?: string | null;
    sha256_short?: string | null;

    // Dependency / recommendation information.
    required_models?: RequiredModelInfo[];
    recommended_vae?: RecommendedModelInfo;
    recommended_loras?: RecommendedModelInfo[];
    recommended_textual_inversions?: RecommendedModelInfo[];

    metadata?: Record<string, any>;
  }

  /** State of automatic model selection. */
  export interface AutoSelectionState {
    isAutoSelecting: boolean;
    /** modelType -> modelName */
    selectedModels: Record<string, string>;
    /** modelType -> modelName */
    autoSelectedModels: Record<string, string>;
    /** Model type names that are missing. */
    missingModels: string[];
    warnings: string[];
    errors: string[];
  }

  /** Response body of the model listing endpoint. */
  export interface EnhancedModelsResponse {
    models: ModelInfo[];
    pagination: {
      page: number;
      limit: number;
      total_count: number;
      total_pages: number;
      has_prev: boolean;
      has_next: boolean;
    };
    statistics: any;
    auto_selection?: AutoSelectionState;
  }
  /** Snapshot of the generation queue (the `queue` field of the queue status response). */
  export interface QueueStatus {
    running: boolean;
    size: number;
    active_generations: number;
    jobs: JobInfo[];
  }

  /** Normalized result of a health check. */
  export interface HealthStatus {
    status: "ok" | "degraded" | "error";
    message: string;
    /** ISO-8601 timestamp string. */
    timestamp: string;
    uptime?: number;
    version?: string;
  }

  /** Build / version metadata. */
  export interface VersionInfo {
    version: string;
    type: string;
    branch: string;
    commit: { short: string; full: string };
    clean: boolean;
    build_time: string;
  }
  /**
   * Client for the stable-diffusion REST API.
   *
   * - The base URL is resolved lazily on first use via `getApiConfig()`, so
   *   configuration injected into `window.__SERVER_CONFIG__` is picked up.
   * - Requests sent to the API carry the configured auth header
   *   (`X-Unix-User` or `Authorization: Bearer`) when credentials are stored.
   * - Read-mostly responses are memoised in the module-level `cache`; mutating
   *   calls invalidate the entries they make stale.
   * - Polling endpoints (`/queue/status`, `/health`) go through the
   *   module-level `throttler` when called via `request`.
   */
  class ApiClient {
    private baseUrl: string = "";
    private isInitialized: boolean = false;

    // ----- URL, auth and response helpers -----------------------------------

    /** Resolves and memoises `${apiUrl}${apiBase}` on first call. */
    private initBaseUrl(): string {
      if (!this.isInitialized) {
        const config = getApiConfig();
        this.baseUrl = `${config.apiUrl}${config.apiBase}`;
        this.isInitialized = true;
      }
      return this.baseUrl;
    }

    /** Base URL for all endpoints, resolved at runtime so server config is loaded first. */
    private getBaseUrl(): string {
      return this.initBaseUrl();
    }

    /**
     * Auth headers for the configured auth method.
     *
     * Unix auth with a stored `unix_user` sends `X-Unix-User`; otherwise a
     * stored `auth_token` is sent as `Authorization: Bearer`. Returns an empty
     * object outside the browser or when nothing is stored.
     */
    private getAuthHeaders(): Record<string, string> {
      if (typeof window === "undefined") return {};
      const authMethod = window.__SERVER_CONFIG__
        ? window.__SERVER_CONFIG__.authMethod
        : "jwt";
      const token = localStorage.getItem("auth_token");
      const unixUser = localStorage.getItem("unix_user");
      if (authMethod === "unix" && unixUser) {
        return { "X-Unix-User": unixUser };
      }
      if (token) {
        return { Authorization: `Bearer ${token}` };
      }
      return {};
    }

    /**
     * Extracts a human-readable message from a failed response.
     *
     * Tries `error.message`, `message`, then `error`, accepting only non-empty
     * strings. This avoids `new Error(someObject)` producing "[object Object]",
     * e.g. when `statusText` is empty (HTTP/2) and the body is not JSON.
     */
    private async extractErrorMessage(
      response: Response,
      fallback: string,
    ): Promise<string> {
      const errorData = await response.json().catch(() => ({
        error: { message: response.statusText },
      }));
      const candidates: unknown[] = [
        errorData?.error?.message,
        errorData?.message,
        errorData?.error,
      ];
      const message = candidates.find(
        (c): c is string => typeof c === "string" && c.length > 0,
      );
      return message ?? fallback;
    }

    /**
     * Parses a JSON response body.
     *
     * An empty body (e.g. 204 No Content from a void endpoint) yields
     * `undefined` instead of throwing, as `response.json()` would. Malformed
     * JSON still throws a SyntaxError.
     */
    private async parseJsonBody<T>(response: Response): Promise<T> {
      const text = await response.text();
      return (text.trim() ? JSON.parse(text) : undefined) as T;
    }

    /** Fills in `id`, `size` and `path` on every model of a listing response. */
    private normalizeModelsResponse(
      response: EnhancedModelsResponse,
    ): EnhancedModelsResponse {
      const models = (response.models ?? []).map((model) => ({
        ...model,
        id: model.sha256_short || model.name,
        size: model.file_size || model.size,
        path: model.path || model.name,
      }));
      return { ...response, models };
    }

    /** Drops cached data that depends on a model's loaded state. */
    private invalidateModelCaches(modelId: string): void {
      cache.delete(`model_info_${modelId}`);
      // Listings (including `loaded=true` filters and auto-selection) embed `loaded`.
      this.clearCacheByPrefix("models_");
    }

    /**
     * Sends an authenticated JSON request to `${baseUrl}${endpoint}`.
     *
     * @throws Error when the endpoint is throttled past its limit or the
     *   response is not 2xx (message taken from the error body when possible).
     * @returns The parsed JSON body, or `undefined` for an empty body.
     */
    private async request<T>(
      endpoint: string,
      options: RequestInit = {},
    ): Promise<T> {
      const url = `${this.getBaseUrl()}${endpoint}`;
      // Polling endpoints are rate-limited client-side.
      const needsThrottling =
        endpoint.includes("/queue/status") || endpoint.includes("/health");
      if (needsThrottling) {
        const waitTime = throttler.getWaitTime(endpoint);
        if (waitTime > 0) {
          await new Promise((resolve) => setTimeout(resolve, waitTime));
        }
        if (!throttler.canMakeRequest(endpoint)) {
          throw new Error(
            "Too many requests. Please wait before making another request.",
          );
        }
      }
      // Auth headers are spread last so they take precedence over caller headers.
      const headers: Record<string, string> = {
        "Content-Type": "application/json",
        ...(options.headers as Record<string, string>),
        ...this.getAuthHeaders(),
      };
      const response = await fetch(url, {
        ...options,
        headers,
      });
      if (!response.ok) {
        throw new Error(
          await this.extractErrorMessage(response, "API request failed"),
        );
      }
      return this.parseJsonBody<T>(response);
    }

    // ----- Health -------------------------------------------------------------

    /**
     * Probes several endpoints in order and reports health from the first one
     * that answers 2xx with JSON. Results are cached for 10 seconds.
     *
     * @throws Error when every probe fails.
     */
    async checkHealth(): Promise<HealthStatus> {
      const cacheKey = "health_check";
      const cachedResult = cache.get(cacheKey);
      if (cachedResult) {
        return cachedResult;
      }
      const endpoints = ["/queue/status", "/health", "/status", "/"];
      for (const endpoint of endpoints) {
        try {
          const response = await fetch(`${this.getBaseUrl()}${endpoint}`, {
            method: "GET",
            headers: {
              "Content-Type": "application/json",
              ...this.getAuthHeaders(),
            },
            // Bound each probe so an unresponsive endpoint cannot hang the check.
            signal: AbortSignal.timeout(3000),
          });
          if (!response.ok) continue;
          const data = await response.json();
          // A queue-status payload with a `queue` object counts as healthy.
          if (endpoint === "/queue/status" && data.queue) {
            const result: HealthStatus = {
              status: "ok",
              message: "API is running and queue is accessible",
              timestamp: new Date().toISOString(),
            };
            cache.set(cacheKey, result, 10000);
            return result;
          }
          const healthStatus: HealthStatus = {
            status: "ok",
            message: "API is running",
            timestamp: new Date().toISOString(),
            uptime: data.uptime,
            version: data.version || data.build || data.git_version,
          };
          // Map the backend's own status vocabulary, when present.
          if (data.status) {
            if (
              data.status === "healthy" ||
              data.status === "running" ||
              data.status === "ok"
            ) {
              healthStatus.status = "ok";
              healthStatus.message = data.message || "API is running";
            } else if (data.status === "degraded") {
              healthStatus.status = "degraded";
              healthStatus.message =
                data.message || "API is running in degraded mode";
            } else {
              healthStatus.status = "error";
              healthStatus.message = data.message || "API status is unknown";
            }
          }
          cache.set(cacheKey, healthStatus, 10000);
          return healthStatus;
        } catch (error) {
          // Best-effort probe: log and try the next endpoint.
          console.warn(`Health check failed for endpoint ${endpoint}:`, error);
        }
      }
      throw new Error("All health check endpoints are unavailable");
    }

    /**
     * Cheap reachability check: HEAD on the base URL. Any status below 500
     * (including 4xx) counts as reachable. Cached for 5 seconds, failures included.
     */
    async checkConnectivity(): Promise<boolean> {
      const cacheKey = "connectivity_check";
      const cachedResult = cache.get(cacheKey);
      if (cachedResult !== null) {
        return cachedResult;
      }
      try {
        const response = await fetch(`${this.getBaseUrl()}`, {
          method: "HEAD",
          signal: AbortSignal.timeout(2000),
        });
        const result = response.ok || response.status < 500;
        cache.set(cacheKey, result, 5000);
        return result;
      } catch {
        cache.set(cacheKey, false, 5000);
        return false;
      }
    }

    // ----- Generation ---------------------------------------------------------

    /** Alias of `text2img`, kept for existing callers. */
    async generateImage(params: GenerationRequest): Promise<JobInfo> {
      return this.text2img(params);
    }

    /** Queues a text-to-image job. */
    async text2img(params: GenerationRequest): Promise<JobInfo> {
      return this.request<JobInfo>("/generate/text2img", {
        method: "POST",
        body: JSON.stringify(params),
      });
    }

    /** Queues an image-to-image job; the frontend's `image` is sent as `init_image`. */
    async img2img(
      params: GenerationRequest & { image: string },
    ): Promise<JobInfo> {
      const { image, ...rest } = params;
      return this.request<JobInfo>("/generate/img2img", {
        method: "POST",
        body: JSON.stringify({ ...rest, init_image: image }),
      });
    }

    /** Queues an inpainting job. */
    async inpainting(
      params: GenerationRequest & { source_image: string; mask_image: string },
    ): Promise<JobInfo> {
      return this.request<JobInfo>("/generate/inpainting", {
        method: "POST",
        body: JSON.stringify(params),
      });
    }

    // ----- Jobs ---------------------------------------------------------------

    /**
     * Fetches a job's details. Cached for 2 s while the job is queued or
     * processing, and for 10 s otherwise.
     */
    async getJobStatus(jobId: string): Promise<JobDetailsResponse> {
      const cacheKey = `job_status_${jobId}`;
      const cachedResult = cache.get(cacheKey);
      if (cachedResult) {
        return cachedResult;
      }
      const result = await this.request<JobDetailsResponse>(
        `/queue/job/${jobId}`,
      );
      const status = result?.job?.status;
      const isActive = status === "processing" || status === "queued";
      cache.set(cacheKey, result, isActive ? 2000 : 10000);
      return result;
    }

    /**
     * URL of a job output file, with a `t` timestamp query param for cache-busting.
     *
     * Each path segment of `filename` is percent-encoded, so characters such as
     * space, `#` or `?` cannot corrupt the URL. Slashes are kept as separators.
     */
    getImageUrl(jobId: string, filename: string): string {
      const baseUrl = this.getBaseUrl();
      const safeJobId = encodeURIComponent(jobId);
      const safeFilename = filename
        .split("/")
        .map((segment) => encodeURIComponent(segment))
        .join("/");
      const timestamp = Date.now();
      return `${baseUrl}/queue/job/${safeJobId}/output/${safeFilename}?t=${timestamp}`;
    }

    /**
     * Downloads a job output file with authentication.
     *
     * @throws Error with the server's message when the response is not 2xx.
     */
    async downloadImage(jobId: string, filename: string): Promise<Blob> {
      const url = this.getImageUrl(jobId, filename);
      const response = await fetch(url, {
        headers: this.getAuthHeaders(),
      });
      if (!response.ok) {
        throw new Error(
          await this.extractErrorMessage(response, "Failed to download image"),
        );
      }
      return response.blob();
    }

    /**
     * Fetches an external image through the server-side proxy (avoids CORS) and
     * returns it base64-encoded, together with any temp file the server created.
     *
     * @throws Error with the server's message when the response is not 2xx.
     */
    async downloadImageFromUrl(url: string): Promise<{
      mimeType: string;
      filename: string;
      base64Data: string;
      tempUrl?: string;
      tempFilename?: string;
    }> {
      const apiUrl = `${this.getBaseUrl()}/image/download?url=${encodeURIComponent(url)}`;
      // Send auth like every other call to the API.
      const response = await fetch(apiUrl, {
        headers: this.getAuthHeaders(),
      });
      if (!response.ok) {
        throw new Error(
          await this.extractErrorMessage(
            response,
            "Failed to download image from URL",
          ),
        );
      }
      const result = await response.json();
      return {
        mimeType: result.mime_type,
        filename: result.filename,
        base64Data: result.base64_data,
        tempUrl: result.temp_url,
        tempFilename: result.temp_filename,
      };
    }

    /** Cancels a job and invalidates its cached status and the queue snapshot. */
    async cancelJob(jobId: string): Promise<void> {
      const jobKey = `job_status_${jobId}`;
      cache.delete(jobKey);
      const result = await this.request<void>("/queue/cancel", {
        method: "POST",
        body: JSON.stringify({ job_id: jobId }),
      });
      // A poll that raced with the cancel may have re-cached the pre-cancel state.
      cache.delete(jobKey);
      cache.delete("queue_status");
      return result;
    }

    /**
     * Returns the queue snapshot. Cached for 1 s while any job is queued or
     * processing, and for 5 s when the queue is idle.
     */
    async getQueueStatus(): Promise<QueueStatus> {
      const cacheKey = "queue_status";
      const cachedResult = cache.get(cacheKey);
      if (cachedResult) {
        return cachedResult;
      }
      const response = await this.request<{ queue: QueueStatus }>(
        "/queue/status",
      );
      const hasActiveJobs = (response.queue.jobs ?? []).some(
        (job) => job.status === "processing" || job.status === "queued",
      );
      cache.set(cacheKey, response.queue, hasActiveJobs ? 1000 : 5000);
      return response.queue;
    }

    /** Clears the queue and invalidates cached queue and job statuses. */
    async clearQueue(): Promise<void> {
      cache.delete("queue_status");
      const result = await this.request<void>("/queue/clear", {
        method: "POST",
      });
      cache.delete("queue_status");
      this.clearCacheByPrefix("job_status_");
      return result;
    }

    // ----- Models -------------------------------------------------------------

    /**
     * Lists models, optionally filtered, with normalized `id`/`size`/`path`.
     *
     * @param type   Model type filter; the pseudo-type "loaded" means "loaded models only".
     * @param loaded When true, only loaded models are returned.
     * @param page   Page number, used only when `limit > 0`.
     * @param limit  Page size; any non-positive value (default -1) requests all models.
     * @param search Optional name search string.
     * Results are cached for 30 seconds.
     */
    async getModels(
      type?: string,
      loaded?: boolean,
      page: number = 1,
      limit: number = -1,
      search?: string,
    ): Promise<EnhancedModelsResponse> {
      const cacheKey = `models_${type || "all"}_${loaded ? "loaded" : "all"}_${page}_${limit}_${search || "all"}`;
      const cachedResult = cache.get(cacheKey);
      if (cachedResult) {
        return cachedResult;
      }
      const params: string[] = [];
      if (type && type !== "loaded") params.push(`type=${encodeURIComponent(type)}`);
      if (type === "loaded" || loaded) params.push("loaded=true");
      if (limit > 0) {
        params.push(`page=${page}`);
        params.push(`limit=${limit}`);
      } else {
        // Non-positive limit: request everything; limit=0 disables server pagination.
        params.push("limit=0");
      }
      if (search) params.push(`search=${encodeURIComponent(search)}`);
      params.push("include_metadata=true");
      const response = await this.request<EnhancedModelsResponse>(
        `/models?${params.join("&")}`,
      );
      const result = this.normalizeModelsResponse(response);
      cache.set(cacheKey, result, 30000);
      return result;
    }

    /**
     * Lists all models with metadata and requirement information, optionally
     * relative to `checkpointModel`. Cached for 30 seconds.
     */
    async getModelsForAutoSelection(
      checkpointModel?: string,
    ): Promise<EnhancedModelsResponse> {
      const cacheKey = `models_auto_selection_${checkpointModel || "none"}`;
      const cachedResult = cache.get(cacheKey);
      if (cachedResult) {
        return cachedResult;
      }
      const params = [
        "include_metadata=true",
        "include_requirements=true",
        ...(checkpointModel
          ? [`checkpoint=${encodeURIComponent(checkpointModel)}`]
          : []),
        "limit=0", // all models
      ];
      const response = await this.request<EnhancedModelsResponse>(
        `/models?${params.join("&")}`,
      );
      const result = this.normalizeModelsResponse(response);
      cache.set(cacheKey, result, 30000);
      return result;
    }

    /** Models whose `type` matches `type`, case-insensitively. */
    getModelsByType(models: ModelInfo[], type: string): ModelInfo[] {
      const wanted = type.toLowerCase();
      return models.filter((model) => model.type.toLowerCase() === wanted);
    }

    /** Models whose name contains `namePattern`, case-insensitively. */
    findModelsByName(models: ModelInfo[], namePattern: string): ModelInfo[] {
      const pattern = namePattern.toLowerCase();
      return models.filter((model) => model.name.toLowerCase().includes(pattern));
    }

    /** Loaded models of the given type. */
    getLoadedModelsByType(models: ModelInfo[], type: string): ModelInfo[] {
      return this.getModelsByType(models, type).filter((model) => model.loaded);
    }

    /** Fetches every page of models, 100 at a time (kept for backward compatibility). */
    async getAllModels(type?: string, loaded?: boolean): Promise<ModelInfo[]> {
      const allModels: ModelInfo[] = [];
      const limit = 100;
      for (let page = 1; ; page++) {
        const response = await this.getModels(type, loaded, page, limit);
        allModels.push(...response.models);
        // Stop when no further pages are reported (or pagination info is
        // missing), and also on an empty page so a server that keeps reporting
        // has_next with no data cannot loop forever.
        if (!response.pagination?.has_next || response.models.length === 0) {
          break;
        }
      }
      return allModels;
    }

    /** Fetches a single model's details. Cached for 30 seconds. */
    async getModelInfo(modelId: string): Promise<ModelInfo> {
      const cacheKey = `model_info_${modelId}`;
      const cachedResult = cache.get(cacheKey);
      if (cachedResult) {
        return cachedResult;
      }
      const result = await this.request<ModelInfo>(`/models/${modelId}`);
      cache.set(cacheKey, result, 30000);
      return result;
    }

    /** Loads a model and invalidates caches that embed its loaded state. */
    async loadModel(modelId: string): Promise<void> {
      cache.delete(`model_info_${modelId}`);
      const result = await this.request<void>(`/models/${modelId}/load`, {
        method: "POST",
      });
      this.invalidateModelCaches(modelId);
      return result;
    }

    /** Unloads a model and invalidates caches that embed its loaded state. */
    async unloadModel(modelId: string): Promise<void> {
      cache.delete(`model_info_${modelId}`);
      const result = await this.request<void>(`/models/${modelId}/unload`, {
        method: "POST",
      });
      this.invalidateModelCaches(modelId);
      return result;
    }

    /** Asks the server to rescan model files; clears the whole response cache. */
    async scanModels(): Promise<void> {
      cache.clear();
      return this.request<void>("/models/refresh", {
        method: "POST",
      });
    }

    /** Lists supported model types. Cached for 1 minute. */
    async getModelTypes(): Promise<
      Array<{
        type: string;
        description: string;
        extensions: string[];
        capabilities: string[];
        requires?: string[];
        recommended_for: string;
      }>
    > {
      const cacheKey = "model_types";
      const cachedResult = cache.get(cacheKey);
      if (cachedResult) {
        return cachedResult;
      }
      const response = await this.request<{
        model_types: Array<{
          type: string;
          description: string;
          extensions: string[];
          capabilities: string[];
          requires?: string[];
          recommended_for: string;
        }>;
      }>("/models/types");
      cache.set(cacheKey, response.model_types, 60000);
      return response.model_types;
    }

    /** Starts a model conversion/quantization job. */
    async convertModel(
      modelName: string,
      quantizationType: string,
      outputPath?: string,
    ): Promise<{ request_id: string; message: string }> {
      return this.request<{ request_id: string; message: string }>(
        "/models/convert",
        {
          method: "POST",
          body: JSON.stringify({
            model_name: modelName,
            quantization_type: quantizationType,
            output_path: outputPath,
          }),
        },
      );
    }

    // ----- System -------------------------------------------------------------

    async getHealth(): Promise<{ status: string }> {
      return this.request<{ status: string }>("/health");
    }

    async getStatus(): Promise<any> {
      return this.request<any>("/status");
    }

    async getSystemInfo(): Promise<any> {
      return this.request<any>("/system");
    }

    async restartServer(): Promise<{ message: string }> {
      return this.request<{ message: string }>("/system/restart", {
        method: "POST",
        body: JSON.stringify({}),
      });
    }

    // ----- Image manipulation -------------------------------------------------

    async resizeImage(
      image: string,
      width: number,
      height: number,
    ): Promise<{ image: string }> {
      return this.request<{ image: string }>("/image/resize", {
        method: "POST",
        body: JSON.stringify({
          image,
          width,
          height,
        }),
      });
    }

    async cropImage(
      image: string,
      x: number,
      y: number,
      width: number,
      height: number,
    ): Promise<{ image: string }> {
      return this.request<{ image: string }>("/image/crop", {
        method: "POST",
        body: JSON.stringify({
          image,
          x,
          y,
          width,
          height,
        }),
      });
    }

    // ----- Configuration (cached for 1 minute) --------------------------------

    async getSamplers(): Promise<
      Array<{ name: string; description: string; recommended_steps: number }>
    > {
      const cacheKey = "samplers";
      const cachedResult = cache.get(cacheKey);
      if (cachedResult) {
        return cachedResult;
      }
      const response = await this.request<{
        samplers: Array<{
          name: string;
          description: string;
          recommended_steps: number;
        }>;
      }>("/samplers");
      cache.set(cacheKey, response.samplers, 60000);
      return response.samplers;
    }

    async getSchedulers(): Promise<Array<{ name: string; description: string }>> {
      const cacheKey = "schedulers";
      const cachedResult = cache.get(cacheKey);
      if (cachedResult) {
        return cachedResult;
      }
      const response = await this.request<{
        schedulers: Array<{ name: string; description: string }>;
      }>("/schedulers");
      cache.set(cacheKey, response.schedulers, 60000);
      return response.schedulers;
    }

    // ----- Cache management ---------------------------------------------------

    /** Empties the whole response cache. */
    clearCache(): void {
      cache.clear();
    }

    /** Removes every cached entry whose key starts with `prefix`. */
    clearCacheByPrefix(prefix: string): void {
      // ApiCache exposes no key enumeration, so read its private map directly
      // (`private` is compile-time only). Snapshot keys before deleting.
      // eslint-disable-next-line @typescript-eslint/no-explicit-any
      const entries: Map<string, unknown> = (cache as any).cache;
      const matching = Array.from(entries.keys()).filter((key) =>
        key.startsWith(prefix),
      );
      matching.forEach((key) => cache.delete(key));
    }
  }
  855. // Generic API request function for authentication
  856. export async function apiRequest(
  857. endpoint: string,
  858. options: RequestInit = {},
  859. ): Promise<Response> {
  860. const { apiUrl, apiBase } = getApiConfig();
  861. const url = `${apiUrl}${apiBase}${endpoint}`;
  862. // Get authentication method from server config
  863. const authMethod =
  864. typeof window !== "undefined" && window.__SERVER_CONFIG__
  865. ? window.__SERVER_CONFIG__.authMethod
  866. : "jwt";
  867. // Add auth token or Unix user header based on auth method
  868. const token =
  869. typeof window !== "undefined" ? localStorage.getItem("auth_token") : null;
  870. const unixUser =
  871. typeof window !== "undefined" ? localStorage.getItem("unix_user") : null;
  872. const headers: Record<string, string> = {
  873. "Content-Type": "application/json",
  874. ...(options.headers as Record<string, string>),
  875. };
  876. if (authMethod === "unix" && unixUser) {
  877. // For Unix auth, send the username in X-Unix-User header
  878. headers["X-Unix-User"] = unixUser;
  879. } else if (token) {
  880. // For JWT auth, send the token in Authorization header
  881. headers["Authorization"] = `Bearer ${token}`;
  882. }
  883. return fetch(url, {
  884. ...options,
  885. headers,
  886. });
  887. }
  888. // Authentication API endpoints
  889. export const authApi = {
  890. async login(username: string, password?: string) {
  891. // Get authentication method from server config
  892. const authMethod =
  893. typeof window !== "undefined" && window.__SERVER_CONFIG__
  894. ? window.__SERVER_CONFIG__.authMethod
  895. : "jwt";
  896. // For both Unix and JWT auth, send username and password
  897. // The server will handle whether password is required based on PAM availability
  898. const response = await apiRequest("/auth/login", {
  899. method: "POST",
  900. body: JSON.stringify({ username, password }),
  901. });
  902. if (!response.ok) {
  903. const error = await response
  904. .json()
  905. .catch(() => ({ message: "Login failed" }));
  906. throw new Error(error.message || "Login failed");
  907. }
  908. return response.json();
  909. },
  910. async validateToken(token: string) {
  911. const response = await apiRequest("/auth/validate", {
  912. headers: { Authorization: `Bearer ${token}` },
  913. });
  914. if (!response.ok) {
  915. throw new Error("Token validation failed");
  916. }
  917. return response.json();
  918. },
  919. async refreshToken() {
  920. const response = await apiRequest("/auth/refresh", {
  921. method: "POST",
  922. });
  923. if (!response.ok) {
  924. throw new Error("Token refresh failed");
  925. }
  926. return response.json();
  927. },
  928. async logout() {
  929. await apiRequest("/auth/logout", {
  930. method: "POST",
  931. });
  932. },
  933. async getCurrentUser() {
  934. const response = await apiRequest("/auth/me");
  935. if (!response.ok) {
  936. throw new Error("Failed to get current user");
  937. }
  938. return response.json();
  939. },
  940. async changePassword(oldPassword: string, newPassword: string) {
  941. const response = await apiRequest("/auth/change-password", {
  942. method: "POST",
  943. body: JSON.stringify({
  944. old_password: oldPassword,
  945. new_password: newPassword,
  946. }),
  947. });
  948. if (!response.ok) {
  949. const error = await response
  950. .json()
  951. .catch(() => ({ message: "Password change failed" }));
  952. throw new Error(error.message || "Password change failed");
  953. }
  954. return response.json();
  955. },
  956. // API Key management
  957. async getApiKeys() {
  958. const response = await apiRequest("/auth/api-keys");
  959. if (!response.ok) {
  960. throw new Error("Failed to get API keys");
  961. }
  962. return response.json();
  963. },
  964. async createApiKey(name: string, scopes?: string[]) {
  965. const response = await apiRequest("/auth/api-keys", {
  966. method: "POST",
  967. body: JSON.stringify({ name, scopes }),
  968. });
  969. if (!response.ok) {
  970. const error = await response
  971. .json()
  972. .catch(() => ({ message: "Failed to create API key" }));
  973. throw new Error(error.message || "Failed to create API key");
  974. }
  975. return response.json();
  976. },
  977. async deleteApiKey(keyId: string) {
  978. const response = await apiRequest(`/auth/api-keys/${keyId}`, {
  979. method: "DELETE",
  980. });
  981. if (!response.ok) {
  982. throw new Error("Failed to delete API key");
  983. }
  984. return response.json();
  985. },
  986. // User management (admin only)
  987. async getUsers() {
  988. const response = await apiRequest("/auth/users");
  989. if (!response.ok) {
  990. throw new Error("Failed to get users");
  991. }
  992. return response.json();
  993. },
  994. async createUser(userData: {
  995. username: string;
  996. email?: string;
  997. password: string;
  998. role?: string;
  999. }) {
  1000. const response = await apiRequest("/auth/users", {
  1001. method: "POST",
  1002. body: JSON.stringify(userData),
  1003. });
  1004. if (!response.ok) {
  1005. const error = await response
  1006. .json()
  1007. .catch(() => ({ message: "Failed to create user" }));
  1008. throw new Error(error.message || "Failed to create user");
  1009. }
  1010. return response.json();
  1011. },
  1012. async updateUser(
  1013. userId: string,
  1014. userData: { email?: string; role?: string; active?: boolean },
  1015. ) {
  1016. const response = await apiRequest(`/auth/users/${userId}`, {
  1017. method: "PUT",
  1018. body: JSON.stringify(userData),
  1019. });
  1020. if (!response.ok) {
  1021. const error = await response
  1022. .json()
  1023. .catch(() => ({ message: "Failed to update user" }));
  1024. throw new Error(error.message || "Failed to update user");
  1025. }
  1026. return response.json();
  1027. },
  1028. async deleteUser(userId: string) {
  1029. const response = await apiRequest(`/auth/users/${userId}`, {
  1030. method: "DELETE",
  1031. });
  1032. if (!response.ok) {
  1033. throw new Error("Failed to delete user");
  1034. }
  1035. return response.json();
  1036. },
  1037. };
  1038. // Version API
  1039. export async function getVersion(): Promise<VersionInfo> {
  1040. const response = await apiRequest("/version");
  1041. if (!response.ok) {
  1042. throw new Error("Failed to get version information");
  1043. }
  1044. return response.json();
  1045. }
  1046. export const apiClient = new ApiClient();