page.tsx 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567
  1. "use client";
  2. import { useState, useEffect, useRef } from "react";
  3. import { Header } from "@/components/layout";
  4. import { AppLayout } from "@/components/layout";
  5. import { Button } from "@/components/ui/button";
  6. import { Input } from "@/components/ui/input";
  7. import { PromptTextarea } from "@/components/forms";
  8. import { Label } from "@/components/ui/label";
  9. import { Card, CardContent } from "@/components/ui/card";
  10. import {
  11. apiClient,
  12. type GenerationRequest,
  13. type JobInfo,
  14. type JobDetailsResponse,
  15. } from "@/lib/api";
  16. import { Loader2, Download, X, Trash2, RotateCcw, Power } from "lucide-react";
  17. import { downloadImage, downloadAuthenticatedImage } from "@/lib/utils";
  18. import { useLocalStorage } from "@/lib/storage";
  /**
   * Initial values for the text-to-image form. Passed as the initial value to
   * `useLocalStorage("text2img-form-data", ...)` (presumably used when nothing
   * is stored yet) and restored verbatim by "Reset to Defaults".
   */
  const defaultFormData: GenerationRequest = {
    // Prompt text
    prompt: "", negative_prompt: "",
    // Output size; the width/height inputs step in multiples of 64
    width: 512, height: 512,
    // Sampler iterations and guidance strength
    steps: 20, cfg_scale: 7.5,
    // Kept as a string; an empty value is labelled "Leave empty for random" in the form
    seed: "",
    // Must match names returned by apiClient.getSamplers()/getSchedulers()
    sampling_method: "euler_a", scheduler: "default",
    // Number of images requested per job
    batch_count: 1,
  };
  /** Delay between job-status polls, in milliseconds. */
  const POLL_INTERVAL_MS = 2000;

  /**
   * Maximum number of follow-up status polls after the first one.
   * 300 polls with POLL_INTERVAL_MS (2 s) between them is roughly 10 minutes
   * before the page gives up with "Job polling timeout".
   */
  const MAX_POLL_ATTEMPTS = 300;

  /** Form fields stored verbatim as strings; every other field goes through Number(). */
  const STRING_FIELDS: ReadonlySet<string> = new Set([
    "prompt",
    "negative_prompt",
    "seed",
    "sampling_method",
    "scheduler",
  ]);

  /**
   * Builds authenticated image URLs for a completed job.
   *
   * Handles both response shapes this page has to deal with:
   * - newer `job.outputs`: objects carrying a `filename`;
   * - older `job.result.images`: bare filenames, or URLs containing `/output/`
   *   whose trailing filename (minus any query string) is extracted.
   *
   * @param jobId - Job identifier passed through to `apiClient.getImageUrl`.
   * @param job - The `job` object from a job-status response.
   * @returns A new array of image URLs (empty when the job reported none).
   */
  function getCompletedImageUrls(
    jobId: string,
    job: JobDetailsResponse["job"],
  ): string[] {
    if (job.outputs && job.outputs.length > 0) {
      return job.outputs.map((output: { filename: string }) =>
        apiClient.getImageUrl(jobId, output.filename),
      );
    }
    const legacyImages = job.result?.images;
    if (legacyImages && legacyImages.length > 0) {
      return legacyImages.map((imageUrl: string) => {
        if (imageUrl.includes("/output/")) {
          const parts = imageUrl.split("/output/");
          if (parts.length === 2) {
            // Strip query params so only the filename is passed on.
            const filename = parts[1].split("?")[0];
            return apiClient.getImageUrl(jobId, filename);
          }
        }
        // Otherwise treat the entry as a plain filename.
        return apiClient.getImageUrl(jobId, imageUrl);
      });
    }
    return [];
  }

  /**
   * Text-to-image form and result viewer.
   *
   * Form values persist through `useLocalStorage` under "text2img-form-data".
   * Submitting calls `apiClient.text2img` and then polls
   * `apiClient.getJobStatus` every POLL_INTERVAL_MS until the job completes,
   * fails, is cancelled, or MAX_POLL_ATTEMPTS is exceeded. Completed images
   * are shown in the right-hand panel with a download button.
   */
  function Text2ImgForm() {
    const [formData, setFormData] = useLocalStorage<GenerationRequest>(
      "text2img-form-data",
      defaultFormData,
      { excludeLargeData: true, maxSize: 512 * 1024 }, // 512KB limit
    );
    const [loading, setLoading] = useState(false);
    const [jobInfo, setJobInfo] = useState<JobInfo | null>(null);
    const [generatedImages, setGeneratedImages] = useState<string[]>([]);
    const [samplers, setSamplers] = useState<
      Array<{ name: string; description: string }>
    >([]);
    const [schedulers, setSchedulers] = useState<
      Array<{ name: string; description: string }>
    >([]);
    const [loraModels, setLoraModels] = useState<string[]>([]);
    const [embeddings, setEmbeddings] = useState<string[]>([]);
    const [error, setError] = useState<string | null>(null);
    // Stop function for the currently running status poll, if any.
    const pollCleanupRef = useRef<(() => void) | null>(null);

    /** Stops the active status poll (if one is running) and forgets it. */
    const stopPolling = () => {
      if (pollCleanupRef.current) {
        pollCleanupRef.current();
        pollCleanupRef.current = null;
      }
    };

    // Cleanup polling on unmount.
    useEffect(() => {
      return () => {
        if (pollCleanupRef.current) {
          pollCleanupRef.current();
          pollCleanupRef.current = null;
        }
      };
    }, []);

    // Load sampler, scheduler, LoRA and embedding choices once on mount.
    useEffect(() => {
      // Set by the effect cleanup so late responses don't update an unmounted component.
      let cancelled = false;
      const loadOptions = async () => {
        try {
          const [samplersData, schedulersData, loras, embeds] =
            await Promise.all([
              apiClient.getSamplers(),
              apiClient.getSchedulers(),
              apiClient.getModels("lora"),
              apiClient.getModels("embedding"),
            ]);
          if (cancelled) return;
          setSamplers(samplersData);
          setSchedulers(schedulersData);
          setLoraModels(loras.models.map((m) => m.name));
          setEmbeddings(embeds.models.map((m) => m.name));
        } catch (err) {
          console.error("Failed to load options:", err);
        }
      };
      void loadOptions();
      return () => {
        cancelled = true;
      };
    }, []);

    /**
     * Generic change handler for inputs and selects identified by `name`.
     * Fields in STRING_FIELDS keep the raw string; all others are stored as numbers.
     */
    const handleInputChange = (
      e: React.ChangeEvent<
        HTMLInputElement | HTMLTextAreaElement | HTMLSelectElement
      >,
    ) => {
      const { name, value } = e.target;
      setFormData((prev) => ({
        ...prev,
        [name]: STRING_FIELDS.has(name) ? value : Number(value),
      }));
    };

    /**
     * Starts polling the status of `jobId`. The first request is issued immediately;
     * later ones are spaced POLL_INTERVAL_MS apart.
     *
     * @returns A function that stops polling; safe to call more than once.
     */
    const pollJobStatus = (jobId: string): (() => void) => {
      let attempts = 0;
      let isPolling = true;
      let timeoutId: ReturnType<typeof setTimeout> | null = null;

      // Terminal state: clear the spinner and stop scheduling further polls.
      const finish = () => {
        setLoading(false);
        isPolling = false;
      };

      const poll = async () => {
        if (!isPolling) return;
        try {
          const status: JobDetailsResponse = await apiClient.getJobStatus(jobId);
          // Polling may have been stopped (cancel, unmount, new job) while the
          // request was in flight; drop the stale response instead of
          // overwriting newer UI state.
          if (!isPolling) return;
          setJobInfo(status.job);
          if (status.job.status === "completed") {
            setGeneratedImages(getCompletedImageUrls(jobId, status.job));
            finish();
          } else if (status.job.status === "failed") {
            setError(status.job.error || "Generation failed");
            finish();
          } else if (status.job.status === "cancelled") {
            setError("Generation was cancelled");
            finish();
          } else if (attempts < MAX_POLL_ATTEMPTS) {
            attempts++;
            timeoutId = setTimeout(() => void poll(), POLL_INTERVAL_MS);
          } else {
            setError("Job polling timeout");
            finish();
          }
        } catch (err) {
          if (isPolling) {
            setError(
              err instanceof Error ? err.message : "Failed to check job status",
            );
            finish();
          }
        }
      };

      // poll() catches its own errors, so fire-and-forget is safe here.
      void poll();

      return () => {
        isPolling = false;
        if (timeoutId !== null) {
          clearTimeout(timeoutId);
          timeoutId = null;
        }
      };
    };

    /** Submits the current form as a text2img job and starts polling it. */
    const handleGenerate = async (e: React.FormEvent) => {
      e.preventDefault();
      // Never leave a previous job's poll running alongside the new one.
      stopPolling();
      setLoading(true);
      setError(null);
      setGeneratedImages([]);
      setJobInfo(null);
      try {
        const job = await apiClient.text2img({ ...formData });
        setJobInfo(job);
        const jobId = job.request_id || job.id;
        if (jobId) {
          pollCleanupRef.current = pollJobStatus(jobId);
        } else {
          setError("No job ID returned from server");
          setLoading(false);
        }
      } catch (err) {
        setError(err instanceof Error ? err.message : "Failed to generate image");
        setLoading(false);
      }
    };

    /** Asks the server to cancel the current job; on success stops polling. */
    const handleCancel = async () => {
      const jobId = jobInfo?.request_id || jobInfo?.id;
      if (!jobId) return;
      try {
        await apiClient.cancelJob(jobId);
        stopPolling();
        setLoading(false);
        setError("Generation cancelled");
      } catch (err) {
        console.error("Failed to cancel job:", err);
      }
    };

    const handleClearPrompts = () => {
      setFormData((prev) => ({ ...prev, prompt: "", negative_prompt: "" }));
    };

    const handleResetToDefaults = () => {
      setFormData(defaultFormData);
    };

    /** Restarts the backend after confirmation, then reloads the page after 3 s. */
    const handleServerRestart = async () => {
      if (
        !confirm(
          "Are you sure you want to restart the server? This will cancel all running jobs.",
        )
      ) {
        return;
      }
      try {
        setLoading(true);
        await apiClient.restartServer();
        setError("Server restart initiated. Please wait...");
        setTimeout(() => {
          window.location.reload();
        }, 3000);
      } catch (err) {
        setError(err instanceof Error ? err.message : "Failed to restart server");
        setLoading(false);
      }
    };

    /**
     * Downloads one generated image, first with the stored credentials, falling
     * back to a plain download if that fails. Both attempts share one filename.
     */
    const handleDownload = (image: string, index: number) => {
      const filename = `generated-${Date.now()}-${index}.png`;
      const authToken = localStorage.getItem("auth_token");
      const unixUser = localStorage.getItem("unix_user");
      downloadAuthenticatedImage(
        image,
        filename,
        authToken || undefined,
        unixUser || undefined,
      ).catch((err) => {
        console.error("Failed to download image:", err);
        // Fallback to regular download if authenticated download fails
        downloadImage(image, filename);
      });
    };

    return (
      <AppLayout>
        <Header
          title="Text to Image"
          description="Generate images from text prompts"
        />
        <div className="container mx-auto p-6">
          <div className="grid gap-6 lg:grid-cols-2">
            {/* Left Panel - Form */}
            <Card>
              <CardContent className="pt-6">
                <form onSubmit={handleGenerate} className="space-y-4">
                  <div className="space-y-2">
                    <Label htmlFor="prompt">Prompt *</Label>
                    <PromptTextarea
                      value={formData.prompt}
                      onChange={(value) =>
                        setFormData((prev) => ({ ...prev, prompt: value }))
                      }
                      placeholder="a beautiful landscape with mountains and a lake, sunset, highly detailed..."
                      rows={4}
                      loras={loraModels}
                      embeddings={embeddings}
                    />
                    <p className="text-xs text-muted-foreground">
                      Tip: Use &lt;lora:name:weight&gt; for LoRAs (e.g.,
                      &lt;lora:myLora:0.8&gt;) and embedding names directly
                    </p>
                  </div>
                  <div className="space-y-2">
                    <Label htmlFor="negative_prompt">Negative Prompt</Label>
                    <PromptTextarea
                      value={formData.negative_prompt || ""}
                      onChange={(value) =>
                        setFormData((prev) => ({
                          ...prev,
                          negative_prompt: value,
                        }))
                      }
                      placeholder="blurry, low quality, distorted..."
                      rows={2}
                      loras={loraModels}
                      embeddings={embeddings}
                    />
                  </div>
                  {/* Utility Buttons */}
                  <div className="flex gap-2">
                    <Button
                      type="button"
                      variant="outline"
                      size="sm"
                      onClick={handleClearPrompts}
                      disabled={loading}
                      title="Clear both prompts"
                    >
                      <Trash2 className="h-4 w-4 mr-1" />
                      Clear Prompts
                    </Button>
                    <Button
                      type="button"
                      variant="outline"
                      size="sm"
                      onClick={handleResetToDefaults}
                      disabled={loading}
                      title="Reset all fields to defaults"
                    >
                      <RotateCcw className="h-4 w-4 mr-1" />
                      Reset to Defaults
                    </Button>
                    <Button
                      type="button"
                      variant="outline"
                      size="sm"
                      onClick={handleServerRestart}
                      disabled={loading}
                      title="Restart the backend server"
                    >
                      <Power className="h-4 w-4 mr-1" />
                      Restart Server
                    </Button>
                  </div>
                  <div className="grid grid-cols-2 gap-4">
                    <div className="space-y-2">
                      <Label htmlFor="width">Width</Label>
                      <Input
                        id="width"
                        name="width"
                        type="number"
                        value={formData.width}
                        onChange={handleInputChange}
                        step={64}
                        min={256}
                        max={2048}
                      />
                    </div>
                    <div className="space-y-2">
                      <Label htmlFor="height">Height</Label>
                      <Input
                        id="height"
                        name="height"
                        type="number"
                        value={formData.height}
                        onChange={handleInputChange}
                        step={64}
                        min={256}
                        max={2048}
                      />
                    </div>
                  </div>
                  <div className="grid grid-cols-2 gap-4">
                    <div className="space-y-2">
                      <Label htmlFor="steps">Steps</Label>
                      <Input
                        id="steps"
                        name="steps"
                        type="number"
                        value={formData.steps}
                        onChange={handleInputChange}
                        min={1}
                        max={150}
                      />
                    </div>
                    <div className="space-y-2">
                      <Label htmlFor="cfg_scale">CFG Scale</Label>
                      <Input
                        id="cfg_scale"
                        name="cfg_scale"
                        type="number"
                        value={formData.cfg_scale}
                        onChange={handleInputChange}
                        step={0.5}
                        min={1}
                        max={30}
                      />
                    </div>
                  </div>
                  <div className="space-y-2">
                    <Label htmlFor="seed">Seed (optional)</Label>
                    <Input
                      id="seed"
                      name="seed"
                      value={formData.seed}
                      onChange={handleInputChange}
                      placeholder="Leave empty for random"
                    />
                  </div>
                  <div className="space-y-2">
                    <Label htmlFor="sampling_method">Sampling Method</Label>
                    <select
                      id="sampling_method"
                      name="sampling_method"
                      value={formData.sampling_method}
                      onChange={handleInputChange}
                      className="flex h-10 w-full rounded-md border border-input bg-background px-3 py-2 text-sm"
                    >
                      {samplers.length > 0 ? (
                        samplers.map((sampler) => (
                          <option key={sampler.name} value={sampler.name}>
                            {sampler.name.toUpperCase()} - {sampler.description}
                          </option>
                        ))
                      ) : (
                        <option value="euler_a">Loading...</option>
                      )}
                    </select>
                  </div>
                  <div className="space-y-2">
                    <Label htmlFor="scheduler">Scheduler</Label>
                    <select
                      id="scheduler"
                      name="scheduler"
                      value={formData.scheduler}
                      onChange={handleInputChange}
                      className="flex h-10 w-full rounded-md border border-input bg-background px-3 py-2 text-sm"
                    >
                      {schedulers.length > 0 ? (
                        schedulers.map((scheduler) => (
                          <option key={scheduler.name} value={scheduler.name}>
                            {scheduler.name.toUpperCase()} -{" "}
                            {scheduler.description}
                          </option>
                        ))
                      ) : (
                        <option value="default">Loading...</option>
                      )}
                    </select>
                  </div>
                  <div className="space-y-2">
                    <Label htmlFor="batch_count">Batch Count</Label>
                    <Input
                      id="batch_count"
                      name="batch_count"
                      type="number"
                      value={formData.batch_count}
                      onChange={handleInputChange}
                      min={1}
                      max={4}
                    />
                  </div>
                  <div className="flex gap-2">
                    <Button type="submit" disabled={loading} className="flex-1">
                      {loading ? (
                        <>
                          <Loader2 className="h-4 w-4 animate-spin" />
                          Generating...
                        </>
                      ) : (
                        "Generate"
                      )}
                    </Button>
                    {loading && (
                      <Button
                        type="button"
                        variant="destructive"
                        onClick={handleCancel}
                      >
                        <X className="h-4 w-4" />
                        Cancel
                      </Button>
                    )}
                  </div>
                  {error && (
                    <div className="rounded-md bg-destructive/10 p-3 text-sm text-destructive">
                      {error}
                    </div>
                  )}
                </form>
              </CardContent>
            </Card>
            {/* Right Panel - Generated Images */}
            <Card>
              <CardContent className="pt-6">
                <div className="space-y-4">
                  <h3 className="text-lg font-semibold">Generated Images</h3>
                  {generatedImages.length === 0 ? (
                    <div className="flex h-96 items-center justify-center rounded-lg border-2 border-dashed border-border">
                      <p className="text-muted-foreground">
                        {loading
                          ? "Generating..."
                          : "Generated images will appear here"}
                      </p>
                    </div>
                  ) : (
                    <div className="grid gap-4">
                      {generatedImages.map((image, index) => (
                        <div key={index} className="relative group">
                          <img
                            src={image}
                            alt={`Generated ${index + 1}`}
                            className="w-full rounded-lg border border-border"
                          />
                          <Button
                            size="icon"
                            variant="secondary"
                            className="absolute top-2 right-2 opacity-0 group-hover:opacity-100 transition-opacity"
                            onClick={() => handleDownload(image, index)}
                          >
                            <Download className="h-4 w-4" />
                          </Button>
                        </div>
                      ))}
                    </div>
                  )}
                </div>
              </CardContent>
            </Card>
          </div>
        </div>
      </AppLayout>
    );
  }
  /**
   * Default-exported page component for this route (the file is `page.tsx`
   * with a "use client" directive). It takes no props and renders the
   * self-contained text-to-image form.
   */
  const Text2ImgPage = () => <Text2ImgForm />;

  export default Text2ImgPage;