page.tsx 24 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676
  1. "use client";
  2. import { useState, useEffect, useRef } from "react";
  3. import { Header } from "@/components/layout";
  4. import { AppLayout } from "@/components/layout";
  5. import { Button } from "@/components/ui/button";
  6. import { Input } from "@/components/ui/input";
  7. import { Progress } from "@/components/ui/progress";
  8. import { PromptTextarea } from "@/components/forms";
  9. import { Label } from "@/components/ui/label";
  10. import { Card, CardContent } from "@/components/ui/card";
  11. import {
  12. Select,
  13. SelectContent,
  14. SelectItem,
  15. SelectTrigger,
  16. SelectValue,
  17. } from "@/components/ui/select";
  18. import {
  19. apiClient,
  20. type GenerationRequest,
  21. type JobInfo,
  22. type JobDetailsResponse,
  23. } from "@/lib/api";
  24. import { Loader2, Download, X, Trash2, RotateCcw, Power } from "lucide-react";
  25. import { downloadImage, downloadAuthenticatedImage } from "@/lib/utils";
  26. import { useLocalStorage, useGeneratedImages } from "@/lib/storage";
  27. import { useModelTypeSelection } from "@/contexts/model-selection-context";
  /**
   * Starting values for the text-to-image form.
   *
   * Passed to `useLocalStorage` as the default for "text2img-form-data" and
   * re-applied verbatim by the "Reset to Defaults" button.
   */
  const defaultFormData: GenerationRequest = {
    // Both prompts start out blank.
    prompt: "",
    negative_prompt: "",
    // Output size in pixels (the inputs step by 64 within 256..2048).
    width: 512,
    height: 512,
    // Step count and CFG (guidance) scale.
    steps: 20,
    cfg_scale: 7.5,
    // Blank seed: the seed input tells users to "Leave empty for random".
    seed: "",
    // Names must match entries returned by the sampler/scheduler endpoints.
    sampling_method: "euler_a",
    scheduler: "default",
    // Batch count (the form caps this at 4).
    batch_count: 1,
  };
  /** Form fields stored as strings; every other input is coerced with Number(). */
  const TEXT2IMG_STRING_FIELDS: ReadonlySet<string> = new Set([
    "prompt",
    "negative_prompt",
    "seed",
    "sampling_method",
    "scheduler",
  ]);
  /** Delay between consecutive job-status requests. */
  const POLL_INTERVAL_MS = 2000;
  /** Poll budget: 300 polls x 2 s is roughly 10 minutes (plus request latency). */
  const MAX_POLL_ATTEMPTS = 300;

  /**
   * Text-to-image generation screen.
   *
   * - Left panel: prompt and sampling parameters (persisted through
   *   `useLocalStorage` under "text2img-form-data"), optional VAE/TAESD model
   *   selection, and utility buttons (clear prompts, reset, restart server).
   * - Right panel: progress of the running job and the generated images.
   *
   * Submitting the form creates a job via `apiClient.text2img`. The job is then
   * polled with `apiClient.getJobStatus` until it completes, fails, is
   * cancelled, or the polling budget runs out.
   */
  function Text2ImgForm() {
    const {
      availableModels: vaeModels,
      selectedModel: selectedVae,
      setSelectedModel: setSelectedVae,
    } = useModelTypeSelection("vae");
    const {
      availableModels: taesdModels,
      selectedModel: selectedTaesd,
      setSelectedModel: setSelectedTaesd,
    } = useModelTypeSelection("taesd");
    const [formData, setFormData] = useLocalStorage<GenerationRequest>(
      "text2img-form-data",
      defaultFormData,
      { excludeLargeData: true, maxSize: 512 * 1024 }, // 512KB limit
    );
    const [loading, setLoading] = useState(false);
    const [jobInfo, setJobInfo] = useState<JobInfo | null>(null);
    const { images: storedImages, addImages } = useGeneratedImages("text2img");
    // Start from the images already recorded by useGeneratedImages for this page.
    const [generatedImages, setGeneratedImages] = useState<string[]>(() =>
      storedImages.map((img) => img.url),
    );
    const [samplers, setSamplers] = useState<
      Array<{ name: string; description: string }>
    >([]);
    const [schedulers, setSchedulers] = useState<
      Array<{ name: string; description: string }>
    >([]);
    const [loraModels, setLoraModels] = useState<string[]>([]);
    const [embeddings, setEmbeddings] = useState<string[]>([]);
    const [error, setError] = useState<string | null>(null);
    // Cleanup for the active poll loop (null when nothing is being polled).
    const pollCleanupRef = useRef<(() => void) | null>(null);
    // Lets async handlers skip state updates and new polling after unmount.
    const isMountedRef = useRef(true);

    /** Stops the active poll loop, if any. Safe to call repeatedly. */
    const stopPolling = () => {
      if (pollCleanupRef.current) {
        pollCleanupRef.current();
        pollCleanupRef.current = null;
      }
    };

    // Track mount state, and stop any polling when the component unmounts.
    useEffect(() => {
      isMountedRef.current = true;
      return () => {
        isMountedRef.current = false;
        if (pollCleanupRef.current) {
          pollCleanupRef.current();
          pollCleanupRef.current = null;
        }
      };
    }, []);

    // Fetch sampler/scheduler choices and LoRA/embedding names once on mount.
    useEffect(() => {
      let cancelled = false;
      const loadOptions = async () => {
        try {
          const [samplersData, schedulersData, loras, embeds] =
            await Promise.all([
              apiClient.getSamplers(),
              apiClient.getSchedulers(),
              apiClient.getModels("lora"),
              apiClient.getModels("embedding"),
            ]);
          // The component may have unmounted while the requests were running.
          if (cancelled) return;
          setSamplers(samplersData);
          setSchedulers(schedulersData);
          setLoraModels(loras.models.map((m) => m.name));
          setEmbeddings(embeds.models.map((m) => m.name));
        } catch (err) {
          console.error("Failed to load options:", err);
        }
      };
      void loadOptions();
      return () => {
        cancelled = true;
      };
    }, []);

    /** Generic change handler for native inputs/selects keyed by their `name`. */
    const handleInputChange = (
      e: React.ChangeEvent<
        HTMLInputElement | HTMLTextAreaElement | HTMLSelectElement
      >,
    ) => {
      const { name, value } = e.target;
      setFormData((prev) => ({
        ...prev,
        [name]: TEXT2IMG_STRING_FIELDS.has(name) ? value : Number(value),
      }));
    };

    /**
     * Builds authenticated image URLs for a completed job.
     * Handles both the newer `outputs` list and the older `result.images` list.
     */
    const collectImageUrls = (
      jobId: string,
      job: JobDetailsResponse["job"],
    ): string[] => {
      if (job.outputs && job.outputs.length > 0) {
        console.log(`[DEBUG] Processing ${job.outputs.length} outputs`);
        return job.outputs.map((output: { filename: string }) => {
          const filename = output.filename;
          const imageUrl = apiClient.getImageUrl(jobId, filename);
          console.log(`[DEBUG] Generated URL for ${filename}: ${imageUrl}`);
          return imageUrl;
        });
      }
      if (job.result?.images && job.result.images.length > 0) {
        console.log(
          `[DEBUG] Using old format with ${job.result.images.length} images`,
        );
        return job.result.images.map((imageUrl: string) => {
          // Full URLs: keep only the filename after "/output/", minus query params.
          if (imageUrl.includes("/output/")) {
            const parts = imageUrl.split("/output/");
            if (parts.length === 2) {
              const filename = parts[1].split("?")[0];
              return apiClient.getImageUrl(jobId, filename);
            }
          }
          // Otherwise treat the value as a bare filename.
          return apiClient.getImageUrl(jobId, imageUrl);
        });
      }
      console.log(`[DEBUG] No outputs or images found in job response`);
      return [];
    };

    /**
     * Starts polling `jobId` every POLL_INTERVAL_MS until the job reaches a
     * terminal state or MAX_POLL_ATTEMPTS is exhausted.
     *
     * The returned cleanup is available synchronously, so callers can register
     * it before the first request resolves. After cleanup, any response still
     * in flight is discarded instead of updating component state.
     */
    const pollJobStatus = (jobId: string): (() => void) => {
      let attempts = 0;
      let isPolling = true;
      let timeoutId: ReturnType<typeof setTimeout> | null = null;

      // Terminal path: clear the busy state and end the loop.
      const finish = () => {
        setLoading(false);
        isPolling = false;
      };

      const poll = async () => {
        if (!isPolling) return;
        try {
          const status: JobDetailsResponse =
            await apiClient.getJobStatus(jobId);
          // Cleanup (cancel, unmount, or a new job) may have run meanwhile.
          if (!isPolling) return;
          const job = status.job;
          setJobInfo(job);
          console.log(
            `[DEBUG] Job ${jobId} status: ${job.status}, progress: ${job.progress}, outputs:`,
            job.outputs,
          );
          if (job.status === "completed") {
            const imageUrls = collectImageUrls(jobId, job);
            console.log(`[DEBUG] Final image URLs:`, imageUrls);
            setGeneratedImages(imageUrls);
            addImages(imageUrls, jobId);
            finish();
          } else if (job.status === "failed") {
            console.log(`[DEBUG] Job failed with error: ${job.error}`);
            setError(job.error || "Generation failed");
            finish();
          } else if (job.status === "cancelled") {
            console.log(`[DEBUG] Job was cancelled`);
            setError("Generation was cancelled");
            finish();
          } else if (attempts < MAX_POLL_ATTEMPTS) {
            attempts++;
            timeoutId = setTimeout(() => void poll(), POLL_INTERVAL_MS);
          } else {
            console.log(
              `[DEBUG] Job polling timeout after ${attempts} attempts`,
            );
            setError("Job polling timeout");
            finish();
          }
        } catch (err) {
          console.log(`[DEBUG] Error polling job status:`, err);
          if (isPolling) {
            setError(
              err instanceof Error ? err.message : "Failed to check job status",
            );
            finish();
          }
        }
      };

      void poll();

      return () => {
        isPolling = false;
        if (timeoutId !== null) {
          clearTimeout(timeoutId);
        }
      };
    };

    const handleGenerate = async (e: React.FormEvent) => {
      e.preventDefault();
      // Never let two poll loops race each other.
      stopPolling();
      setLoading(true);
      setError(null);
      setGeneratedImages([]);
      setJobInfo(null);
      try {
        const requestData = {
          ...formData,
          vae: selectedVae || undefined,
          taesd: selectedTaesd || undefined,
        };
        const job = await apiClient.text2img(requestData);
        // Do not start polling for a component that is already gone.
        if (!isMountedRef.current) return;
        setJobInfo(job);
        const jobId = job.request_id || job.id;
        if (jobId) {
          pollCleanupRef.current = pollJobStatus(jobId);
        } else {
          setError("No job ID returned from server");
          setLoading(false);
        }
      } catch (err) {
        if (!isMountedRef.current) return;
        setError(
          err instanceof Error ? err.message : "Failed to generate image",
        );
        setLoading(false);
      }
    };

    const handleCancel = async () => {
      const jobId = jobInfo?.request_id || jobInfo?.id;
      if (!jobId) return;
      try {
        await apiClient.cancelJob(jobId);
        // Stop polling first so an in-flight status cannot override the message.
        stopPolling();
        setLoading(false);
        setError("Generation cancelled");
      } catch (err) {
        console.error("Failed to cancel job:", err);
      }
    };

    const handleClearPrompts = () => {
      setFormData((prev) => ({ ...prev, prompt: "", negative_prompt: "" }));
    };

    const handleResetToDefaults = () => {
      setFormData(defaultFormData);
    };

    const handleServerRestart = async () => {
      if (
        !confirm(
          "Are you sure you want to restart the server? This will cancel all running jobs.",
        )
      ) {
        return;
      }
      try {
        setLoading(true);
        await apiClient.restartServer();
        setError("Server restart initiated. Please wait...");
        setTimeout(() => {
          window.location.reload();
        }, 3000);
      } catch (err) {
        setError(
          err instanceof Error ? err.message : "Failed to restart server",
        );
        setLoading(false);
      }
    };

    /**
     * Downloads one generated image. It tries the authenticated download
     * first and falls back to a plain download, reusing the same filename.
     */
    const handleDownload = (image: string, index: number) => {
      const filename = `generated-${Date.now()}-${index}.png`;
      const authToken = localStorage.getItem("auth_token");
      const unixUser = localStorage.getItem("unix_user");
      downloadAuthenticatedImage(
        image,
        filename,
        authToken || undefined,
        unixUser || undefined,
      ).catch((err) => {
        console.error("Failed to download image:", err);
        // Fallback to regular download if authenticated download fails
        downloadImage(image, filename);
      });
    };

    /** Select options for a model list, keyed by the most specific id available. */
    const renderModelItems = (models: typeof vaeModels) =>
      models.map((model) => {
        const modelId =
          model.sha256_short || model.sha256 || model.id || model.name;
        const displayName = model.sha256_short
          ? `${model.name} (${model.sha256_short})`
          : model.name;
        return (
          <SelectItem key={modelId} value={modelId}>
            {displayName}
          </SelectItem>
        );
      });

    const progressValue = jobInfo
      ? jobInfo.overall_progress || jobInfo.progress || 0
      : 0;

    return (
      <AppLayout>
        <Header
          title="Text to Image"
          description="Generate images from text prompts"
        />
        <div className="container mx-auto p-6">
          <div className="grid gap-6 lg:grid-cols-2">
            {/* Left Panel - Form */}
            <Card>
              <CardContent className="pt-6">
                <form onSubmit={handleGenerate} className="space-y-4">
                  <div className="space-y-2">
                    <Label htmlFor="prompt">Prompt *</Label>
                    <PromptTextarea
                      value={formData.prompt}
                      onChange={(value) =>
                        setFormData((prev) => ({ ...prev, prompt: value }))
                      }
                      placeholder="a beautiful landscape with mountains and a lake, sunset, highly detailed..."
                      rows={4}
                      loras={loraModels}
                      embeddings={embeddings}
                    />
                    <p className="text-xs text-muted-foreground">
                      Tip: Use &lt;lora:name:weight&gt; for LoRAs (e.g.,
                      &lt;lora:myLora:0.8&gt;) and embedding names directly
                    </p>
                  </div>
                  <div className="space-y-2">
                    <Label htmlFor="negative_prompt">Negative Prompt</Label>
                    <PromptTextarea
                      value={formData.negative_prompt || ""}
                      onChange={(value) =>
                        setFormData((prev) => ({
                          ...prev,
                          negative_prompt: value,
                        }))
                      }
                      placeholder="blurry, low quality, distorted..."
                      rows={2}
                      loras={loraModels}
                      embeddings={embeddings}
                    />
                  </div>
                  {/* Utility Buttons */}
                  <div className="flex gap-2">
                    <Button
                      type="button"
                      variant="outline"
                      size="sm"
                      onClick={handleClearPrompts}
                      disabled={loading}
                      title="Clear both prompts"
                    >
                      <Trash2 className="h-4 w-4 mr-1" />
                      Clear Prompts
                    </Button>
                    <Button
                      type="button"
                      variant="outline"
                      size="sm"
                      onClick={handleResetToDefaults}
                      disabled={loading}
                      title="Reset all fields to defaults"
                    >
                      <RotateCcw className="h-4 w-4 mr-1" />
                      Reset to Defaults
                    </Button>
                    <Button
                      type="button"
                      variant="outline"
                      size="sm"
                      onClick={handleServerRestart}
                      disabled={loading}
                      title="Restart the backend server"
                    >
                      <Power className="h-4 w-4 mr-1" />
                      Restart Server
                    </Button>
                  </div>
                  <div className="grid grid-cols-2 gap-4">
                    <div className="space-y-2">
                      <Label htmlFor="width">Width</Label>
                      <Input
                        id="width"
                        name="width"
                        type="number"
                        value={formData.width}
                        onChange={handleInputChange}
                        step={64}
                        min={256}
                        max={2048}
                      />
                    </div>
                    <div className="space-y-2">
                      <Label htmlFor="height">Height</Label>
                      <Input
                        id="height"
                        name="height"
                        type="number"
                        value={formData.height}
                        onChange={handleInputChange}
                        step={64}
                        min={256}
                        max={2048}
                      />
                    </div>
                  </div>
                  <div className="grid grid-cols-2 gap-4">
                    <div className="space-y-2">
                      <Label htmlFor="steps">Steps</Label>
                      <Input
                        id="steps"
                        name="steps"
                        type="number"
                        value={formData.steps}
                        onChange={handleInputChange}
                        min={1}
                        max={150}
                      />
                    </div>
                    <div className="space-y-2">
                      <Label htmlFor="cfg_scale">CFG Scale</Label>
                      <Input
                        id="cfg_scale"
                        name="cfg_scale"
                        type="number"
                        value={formData.cfg_scale}
                        onChange={handleInputChange}
                        step={0.5}
                        min={1}
                        max={30}
                      />
                    </div>
                  </div>
                  <div className="space-y-2">
                    <Label htmlFor="seed">Seed (optional)</Label>
                    <Input
                      id="seed"
                      name="seed"
                      value={formData.seed}
                      onChange={handleInputChange}
                      placeholder="Leave empty for random"
                    />
                  </div>
                  <div className="space-y-2">
                    <Label htmlFor="sampling_method">Sampling Method</Label>
                    <select
                      id="sampling_method"
                      name="sampling_method"
                      value={formData.sampling_method}
                      onChange={handleInputChange}
                      className="flex h-10 w-full rounded-md border border-input bg-background px-3 py-2 text-sm"
                    >
                      {samplers.length > 0 ? (
                        samplers.map((sampler) => (
                          <option key={sampler.name} value={sampler.name}>
                            {sampler.name.toUpperCase()} - {sampler.description}
                          </option>
                        ))
                      ) : (
                        <option value="euler_a">Loading...</option>
                      )}
                    </select>
                  </div>
                  <div className="space-y-2">
                    <Label htmlFor="scheduler">Scheduler</Label>
                    <select
                      id="scheduler"
                      name="scheduler"
                      value={formData.scheduler}
                      onChange={handleInputChange}
                      className="flex h-10 w-full rounded-md border border-input bg-background px-3 py-2 text-sm"
                    >
                      {schedulers.length > 0 ? (
                        schedulers.map((scheduler) => (
                          <option key={scheduler.name} value={scheduler.name}>
                            {scheduler.name.toUpperCase()} -{" "}
                            {scheduler.description}
                          </option>
                        ))
                      ) : (
                        <option value="default">Loading...</option>
                      )}
                    </select>
                  </div>
                  <div className="space-y-2">
                    <Label>VAE Model (Optional)</Label>
                    <Select
                      value={selectedVae || "none"}
                      onValueChange={(value) =>
                        setSelectedVae(value === "none" ? undefined : value)
                      }
                    >
                      <SelectTrigger>
                        <SelectValue placeholder="Select VAE model" />
                      </SelectTrigger>
                      <SelectContent>
                        <SelectItem value="none">None</SelectItem>
                        {renderModelItems(vaeModels)}
                      </SelectContent>
                    </Select>
                  </div>
                  <div className="space-y-2">
                    <Label>TAESD Model (Optional)</Label>
                    <Select
                      value={selectedTaesd || "none"}
                      onValueChange={(value) =>
                        setSelectedTaesd(value === "none" ? undefined : value)
                      }
                    >
                      <SelectTrigger>
                        <SelectValue placeholder="Select TAESD model" />
                      </SelectTrigger>
                      <SelectContent>
                        <SelectItem value="none">None</SelectItem>
                        {renderModelItems(taesdModels)}
                      </SelectContent>
                    </Select>
                  </div>
                  <div className="space-y-2">
                    <Label htmlFor="batch_count">Batch Count</Label>
                    <Input
                      id="batch_count"
                      name="batch_count"
                      type="number"
                      value={formData.batch_count}
                      onChange={handleInputChange}
                      min={1}
                      max={4}
                    />
                  </div>
                  <div className="flex gap-2">
                    <Button type="submit" disabled={loading} className="flex-1">
                      {loading ? (
                        <>
                          <Loader2 className="h-4 w-4 animate-spin" />
                          Generating...
                        </>
                      ) : (
                        "Generate"
                      )}
                    </Button>
                    {loading && (
                      <Button
                        type="button"
                        variant="destructive"
                        onClick={handleCancel}
                      >
                        <X className="h-4 w-4" />
                        Cancel
                      </Button>
                    )}
                  </div>
                  {error && (
                    <div className="rounded-md bg-destructive/10 p-3 text-sm text-destructive">
                      {error}
                    </div>
                  )}
                </form>
              </CardContent>
            </Card>
            {/* Right Panel - Generated Images */}
            <Card>
              <CardContent className="pt-6">
                <div className="space-y-4">
                  <h3 className="text-lg font-semibold">Generated Images</h3>
                  {/* Progress Display */}
                  {loading && jobInfo && (
                    <div className="space-y-2">
                      <div className="flex justify-between text-sm">
                        <span>Progress</span>
                        <span>{Math.round(progressValue)}%</span>
                      </div>
                      <Progress value={progressValue} className="w-full" />
                      {jobInfo.model_load_progress !== undefined &&
                        jobInfo.generation_progress !== undefined && (
                          <div className="grid grid-cols-2 gap-4 text-xs text-muted-foreground">
                            <div>
                              Model Loading:{" "}
                              {Math.round(jobInfo.model_load_progress)}%
                            </div>
                            <div>
                              Generation:{" "}
                              {Math.round(jobInfo.generation_progress)}%
                            </div>
                          </div>
                        )}
                    </div>
                  )}
                  {generatedImages.length === 0 ? (
                    <div className="flex h-96 items-center justify-center rounded-lg border-2 border-dashed border-border">
                      <p className="text-muted-foreground">
                        {loading
                          ? "Generating..."
                          : "Generated images will appear here"}
                      </p>
                    </div>
                  ) : (
                    <div className="grid gap-4">
                      {generatedImages.map((image, index) => (
                        <div key={index} className="relative group">
                          <img
                            src={image}
                            alt={`Generated ${index + 1}`}
                            className="w-full rounded-lg border border-border"
                          />
                          <Button
                            size="icon"
                            variant="secondary"
                            className="absolute top-2 right-2 opacity-0 group-hover:opacity-100 transition-opacity"
                            onClick={() => handleDownload(image, index)}
                          >
                            <Download className="h-4 w-4" />
                          </Button>
                        </div>
                      ))}
                    </div>
                  )}
                </div>
              </CardContent>
            </Card>
          </div>
        </div>
      </AppLayout>
    );
  }
  /**
   * Default-exported page component for the text-to-image screen.
   * All state and UI live in Text2ImgForm; this wrapper only renders it.
   */
  export default function Text2ImgPage() {
    const content = <Text2ImgForm />;
    return content;
  }