page.tsx 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609
  1. "use client";
  2. import { useState, useEffect } from "react";
  3. import { Header } from "@/components/layout";
  4. import { AppLayout } from "@/components/layout";
  5. import { Button } from "@/components/ui/button";
  6. import { Input } from "@/components/ui/input";
  7. import { PromptTextarea } from "@/components/forms";
  8. import { Label } from "@/components/ui/label";
  9. import {
  10. Card,
  11. CardContent,
  12. CardHeader,
  13. CardTitle,
  14. CardDescription,
  15. } from "@/components/ui/card";
  16. import { InpaintingCanvas } from "@/components/features/image-generation";
  17. import { apiClient, type JobInfo, type JobDetailsResponse } from "@/lib/api";
  18. import { Loader2, X, Download } from "lucide-react";
  19. import { downloadAuthenticatedImage } from "@/lib/utils";
  20. import { useLocalStorage, useGeneratedImages } from "@/lib/storage";
  21. import {
  22. useModelTypeSelection,
  23. } from "@/contexts/model-selection-context";
  24. import {
  25. Select,
  26. SelectContent,
  27. SelectItem,
  28. SelectTrigger,
  29. SelectValue,
  30. } from "@/components/ui/select";
  31. // import { AutoSelectionStatus } from '@/components/features/models';
  32. type InpaintingFormData = {
  33. prompt: string;
  34. negative_prompt: string;
  35. steps: number;
  36. cfg_scale: number;
  37. seed: string;
  38. sampling_method: string;
  39. strength: number;
  40. width?: number;
  41. height?: number;
  42. };
  43. const defaultFormData: InpaintingFormData = {
  44. prompt: "",
  45. negative_prompt: "",
  46. steps: 20,
  47. cfg_scale: 7.5,
  48. seed: "",
  49. sampling_method: "euler_a",
  50. strength: 0.75,
  51. width: 512,
  52. height: 512,
  53. };
  54. function InpaintingForm() {
  55. const {
  56. availableModels: vaeModels,
  57. selectedModel: selectedVae,
  58. isAutoSelected: isVaeAutoSelected,
  59. setSelectedModel: setSelectedVae,
  60. setUserOverride: setVaeUserOverride,
  61. clearUserOverride: clearVaeUserOverride,
  62. } = useModelTypeSelection("vae");
  63. const {
  64. availableModels: taesdModels,
  65. selectedModel: selectedTaesd,
  66. setSelectedModel: setSelectedTaesd,
  67. } = useModelTypeSelection("taesd");
  68. const [formData, setFormData] = useLocalStorage<InpaintingFormData>(
  69. "inpainting-form-data",
  70. defaultFormData,
  71. { excludeLargeData: true, maxSize: 512 * 1024 },
  72. );
  73. // Separate state for image data (not stored in localStorage)
  74. const [sourceImage, setSourceImage] = useState<string>("");
  75. const [maskImage, setMaskImage] = useState<string>("");
  76. const [loading, setLoading] = useState(false);
  77. const [jobInfo, setJobInfo] = useState<JobInfo | null>(null);
  78. const { images: storedImages, addImages, getLatestImages } = useGeneratedImages('inpainting');
  79. const [generatedImages, setGeneratedImages] = useState<string[]>(() => getLatestImages());
  80. const [loraModels, setLoraModels] = useState<string[]>([]);
  81. const [embeddings, setEmbeddings] = useState<string[]>([]);
  82. const [error, setError] = useState<string | null>(null);
  83. const [pollCleanup, setPollCleanup] = useState<(() => void) | null>(null);
  84. // Cleanup polling on unmount
  85. useEffect(() => {
  86. return () => {
  87. if (pollCleanup) {
  88. pollCleanup();
  89. }
  90. };
  91. }, [pollCleanup]);
  92. useEffect(() => {
  93. const loadModels = async () => {
  94. try {
  95. const [loras, embeds] = await Promise.all([
  96. apiClient.getModels("lora"),
  97. apiClient.getModels("embedding"),
  98. ]);
  99. setLoraModels(loras.models.map((m) => m.name));
  100. setEmbeddings(embeds.models.map((m) => m.name));
  101. } catch (err) {
  102. console.error("Failed to load models:", err);
  103. }
  104. };
  105. loadModels();
  106. }, []);
  107. const handleInputChange = (
  108. e: React.ChangeEvent<
  109. HTMLInputElement | HTMLTextAreaElement | HTMLSelectElement
  110. >,
  111. ) => {
  112. const { name, value } = e.target;
  113. setFormData((prev) => ({
  114. ...prev,
  115. [name]:
  116. name === "prompt" ||
  117. name === "negative_prompt" ||
  118. name === "seed" ||
  119. name === "sampling_method"
  120. ? value
  121. : Number(value),
  122. }));
  123. };
  124. const handleSourceImageChange = (image: string) => {
  125. setSourceImage(image);
  126. setError(null);
  127. };
  128. const handleMaskImageChange = (image: string) => {
  129. setMaskImage(image);
  130. setError(null);
  131. };
  132. const pollJobStatus = async (jobId: string) => {
  133. const maxAttempts = 300; // 5 minutes with 2 second interval
  134. let attempts = 0;
  135. let isPolling = true;
  136. let timeoutId: NodeJS.Timeout | null = null;
  137. const poll = async () => {
  138. if (!isPolling) return;
  139. try {
  140. const status: JobDetailsResponse = await apiClient.getJobStatus(jobId);
  141. setJobInfo(status.job);
  142. if (status.job.status === "completed") {
  143. let imageUrls: string[] = [];
  144. // Handle both old format (result.images) and new format (outputs)
  145. if (status.job.outputs && status.job.outputs.length > 0) {
  146. // New format: convert output URLs to authenticated image URLs with cache-busting
  147. imageUrls = status.job.outputs.map((output: { filename: string }) => {
  148. const filename = output.filename;
  149. return apiClient.getImageUrl(jobId, filename);
  150. });
  151. } else if (
  152. status.job.result?.images &&
  153. status.job.result.images.length > 0
  154. ) {
  155. // Old format: convert image URLs to authenticated URLs
  156. imageUrls = status.job.result.images.map((imageUrl: string) => {
  157. // Extract filename from URL if it's already a full URL
  158. if (imageUrl.includes("/output/")) {
  159. const parts = imageUrl.split("/output/");
  160. if (parts.length === 2) {
  161. const filename = parts[1].split("?")[0]; // Remove query params
  162. return apiClient.getImageUrl(jobId, filename);
  163. }
  164. }
  165. return imageUrl; // Already a full URL
  166. });
  167. }
  168. // Create a new array to trigger React re-render
  169. setGeneratedImages([...imageUrls]);
  170. addImages(imageUrls, jobId);
  171. setLoading(false);
  172. isPolling = false;
  173. } else if (status.job.status === "failed") {
  174. setError(status.job.error || "Generation failed");
  175. setLoading(false);
  176. isPolling = false;
  177. } else if (status.job.status === "cancelled") {
  178. setError("Generation was cancelled");
  179. setLoading(false);
  180. isPolling = false;
  181. } else if (attempts < maxAttempts) {
  182. attempts++;
  183. timeoutId = setTimeout(poll, 2000);
  184. } else {
  185. setError("Job polling timeout");
  186. setLoading(false);
  187. isPolling = false;
  188. }
  189. } catch (err) {
  190. if (isPolling) {
  191. setError(
  192. err instanceof Error ? err.message : "Failed to check job status",
  193. );
  194. setLoading(false);
  195. isPolling = false;
  196. }
  197. }
  198. };
  199. poll();
  200. // Return cleanup function
  201. return () => {
  202. isPolling = false;
  203. if (timeoutId) {
  204. clearTimeout(timeoutId);
  205. }
  206. };
  207. };
  208. const handleGenerate = async (e: React.FormEvent) => {
  209. e.preventDefault();
  210. if (!sourceImage) {
  211. setError("Please upload a source image first");
  212. return;
  213. }
  214. if (!maskImage) {
  215. setError("Please create a mask first");
  216. return;
  217. }
  218. setLoading(true);
  219. setError(null);
  220. setGeneratedImages([]);
  221. setJobInfo(null);
  222. try {
  223. const requestData = {
  224. ...formData,
  225. source_image: sourceImage,
  226. mask_image: maskImage,
  227. vae: selectedVae || undefined,
  228. taesd: selectedTaesd || undefined,
  229. };
  230. const job = await apiClient.inpainting(requestData);
  231. setJobInfo(job);
  232. const jobId = job.request_id || job.id;
  233. if (jobId) {
  234. const cleanup = pollJobStatus(jobId);
  235. setPollCleanup(() => cleanup);
  236. } else {
  237. setError("No job ID returned from server");
  238. setLoading(false);
  239. }
  240. } catch (err) {
  241. setError(err instanceof Error ? err.message : "Failed to generate image");
  242. setLoading(false);
  243. }
  244. };
  245. const handleCancel = async () => {
  246. const jobId = jobInfo?.request_id || jobInfo?.id;
  247. if (jobId) {
  248. try {
  249. await apiClient.cancelJob(jobId);
  250. setLoading(false);
  251. setError("Generation cancelled");
  252. // Cleanup polling
  253. if (pollCleanup) {
  254. pollCleanup();
  255. setPollCleanup(null);
  256. }
  257. } catch (err) {
  258. console.error("Failed to cancel job:", err);
  259. }
  260. }
  261. };
  262. return (
  263. <AppLayout>
  264. <Header
  265. title="Inpainting"
  266. description="Edit images by masking areas and regenerating with AI"
  267. />
  268. <div className="container mx-auto p-6">
  269. <div className="grid gap-6 lg:grid-cols-2">
  270. {/* Left Panel - Form Parameters */}
  271. <div className="space-y-6">
  272. <Card>
  273. <CardContent className="pt-6">
  274. <form onSubmit={handleGenerate} className="space-y-4">
  275. <div className="space-y-2">
  276. <Label htmlFor="prompt">Prompt *</Label>
  277. <PromptTextarea
  278. value={formData.prompt}
  279. onChange={(value) =>
  280. setFormData({ ...formData, prompt: value })
  281. }
  282. placeholder="Describe what to generate in the masked areas..."
  283. rows={3}
  284. loras={loraModels}
  285. embeddings={embeddings}
  286. />
  287. <p className="text-xs text-muted-foreground">
  288. Tip: Use {"<lora:name:weight>"} for LoRAs and embedding
  289. names directly
  290. </p>
  291. </div>
  292. <div className="space-y-2">
  293. <Label htmlFor="negative_prompt">Negative Prompt</Label>
  294. <PromptTextarea
  295. value={formData.negative_prompt || ""}
  296. onChange={(value) =>
  297. setFormData({ ...formData, negative_prompt: value })
  298. }
  299. placeholder="What to avoid in the generated areas..."
  300. rows={2}
  301. loras={loraModels}
  302. embeddings={embeddings}
  303. />
  304. </div>
  305. <div className="space-y-2">
  306. <Label htmlFor="strength">
  307. Strength *
  308. <span className="text-xs text-muted-foreground ml-1">
  309. ({formData.strength})
  310. </span>
  311. </Label>
  312. <input
  313. id="strength"
  314. name="strength"
  315. type="range"
  316. min="0.1"
  317. max="1.0"
  318. step="0.1"
  319. value={formData.strength}
  320. onChange={handleInputChange}
  321. className="w-full h-2 bg-gray-200 rounded-lg appearance-none cursor-pointer dark:bg-gray-700"
  322. />
  323. </div>
  324. <div className="grid grid-cols-2 gap-4">
  325. <div className="space-y-2">
  326. <Label htmlFor="steps">Steps</Label>
  327. <input
  328. id="steps"
  329. name="steps"
  330. type="number"
  331. min="1"
  332. max="100"
  333. value={formData.steps}
  334. onChange={handleInputChange}
  335. className="flex h-10 w-full rounded-md border border-input bg-background px-3 py-2 text-sm"
  336. />
  337. </div>
  338. <div className="space-y-2">
  339. <Label htmlFor="cfg_scale">CFG Scale</Label>
  340. <input
  341. id="cfg_scale"
  342. name="cfg_scale"
  343. type="number"
  344. min="1"
  345. max="30"
  346. step="0.5"
  347. value={formData.cfg_scale}
  348. onChange={handleInputChange}
  349. className="flex h-10 w-full rounded-md border border-input bg-background px-3 py-2 text-sm"
  350. />
  351. </div>
  352. </div>
  353. <div className="grid grid-cols-2 gap-4">
  354. <div className="space-y-2">
  355. <Label htmlFor="width">Width</Label>
  356. <input
  357. id="width"
  358. name="width"
  359. type="number"
  360. min="64"
  361. max="2048"
  362. step="64"
  363. value={formData.width || 512}
  364. onChange={handleInputChange}
  365. className="flex h-10 w-full rounded-md border border-input bg-background px-3 py-2 text-sm"
  366. />
  367. </div>
  368. <div className="space-y-2">
  369. <Label htmlFor="height">Height</Label>
  370. <input
  371. id="height"
  372. name="height"
  373. type="number"
  374. min="64"
  375. max="2048"
  376. step="64"
  377. value={formData.height || 512}
  378. onChange={handleInputChange}
  379. className="flex h-10 w-full rounded-md border border-input bg-background px-3 py-2 text-sm"
  380. />
  381. </div>
  382. </div>
  383. <div className="space-y-2">
  384. <Label htmlFor="seed">Seed</Label>
  385. <input
  386. id="seed"
  387. name="seed"
  388. type="text"
  389. value={formData.seed}
  390. onChange={handleInputChange}
  391. placeholder="Random"
  392. className="flex h-10 w-full rounded-md border border-input bg-background px-3 py-2 text-sm"
  393. />
  394. </div>
  395. <div className="space-y-2">
  396. <Label htmlFor="sampling_method">Sampling Method</Label>
  397. <select
  398. id="sampling_method"
  399. name="sampling_method"
  400. value={formData.sampling_method}
  401. onChange={handleInputChange}
  402. className="flex h-10 w-full rounded-md border border-input bg-background px-3 py-2 text-sm"
  403. >
  404. <option value="euler">Euler</option>
  405. <option value="euler_a">Euler A</option>
  406. <option value="heun">Heun</option>
  407. <option value="dpm2">DPM2</option>
  408. <option value="dpm++2s_a">DPM++ 2S A</option>
  409. <option value="dpm++2m">DPM++ 2M</option>
  410. <option value="dpm++2mv2">DPM++ 2M V2</option>
  411. <option value="lcm">LCM</option>
  412. </select>
  413. </div>
  414. {/* Additional Models Section */}
  415. <Card>
  416. <CardHeader>
  417. <CardTitle>Additional Models</CardTitle>
  418. <CardDescription>
  419. Select additional models for generation
  420. </CardDescription>
  421. </CardHeader>
  422. <CardContent className="space-y-4">
  423. <div className="space-y-2">
  424. <Label>VAE Model (Optional)</Label>
  425. <Select
  426. value={selectedVae || "none"}
  427. onValueChange={(value) => setSelectedVae(value === "none" ? undefined : value)}
  428. >
  429. <SelectTrigger>
  430. <SelectValue placeholder="Select VAE model" />
  431. </SelectTrigger>
  432. <SelectContent>
  433. <SelectItem value="none">None</SelectItem>
  434. {vaeModels.map((model) => (
  435. <SelectItem key={model.name} value={model.name}>
  436. {model.name}
  437. </SelectItem>
  438. ))}
  439. </SelectContent>
  440. </Select>
  441. </div>
  442. <div className="space-y-2">
  443. <Label>TAESD Model (Optional)</Label>
  444. <Select
  445. value={selectedTaesd || "none"}
  446. onValueChange={(value) => setSelectedTaesd(value === "none" ? undefined : value)}
  447. >
  448. <SelectTrigger>
  449. <SelectValue placeholder="Select TAESD model" />
  450. </SelectTrigger>
  451. <SelectContent>
  452. <SelectItem value="none">None</SelectItem>
  453. {taesdModels.map((model) => (
  454. <SelectItem key={model.name} value={model.name}>
  455. {model.name}
  456. </SelectItem>
  457. ))}
  458. </SelectContent>
  459. </Select>
  460. </div>
  461. </CardContent>
  462. </Card>
  463. <div className="flex gap-2">
  464. <Button type="submit" disabled={loading} className="flex-1">
  465. {loading ? (
  466. <>
  467. <Loader2 className="mr-2 h-4 w-4 animate-spin" />
  468. Generating...
  469. </>
  470. ) : (
  471. "Generate"
  472. )}
  473. </Button>
  474. {jobInfo && (
  475. <Button
  476. type="button"
  477. variant="outline"
  478. onClick={handleCancel}
  479. disabled={!loading}
  480. >
  481. <X className="h-4 w-4" />
  482. </Button>
  483. )}
  484. </div>
  485. {error && (
  486. <div className="rounded-md bg-destructive/10 p-3 text-sm text-destructive">
  487. {error}
  488. </div>
  489. )}
  490. </form>
  491. </CardContent>
  492. </Card>
  493. </div>
  494. {/* Right Panel - Source Image and Results */}
  495. <div className="space-y-6">
  496. <InpaintingCanvas
  497. onSourceImageChange={handleSourceImageChange}
  498. onMaskImageChange={handleMaskImageChange}
  499. targetWidth={formData.width}
  500. targetHeight={formData.height}
  501. />
  502. <Card>
  503. <CardContent className="pt-6">
  504. <div className="space-y-4">
  505. <h3 className="text-lg font-semibold">Generated Images</h3>
  506. {generatedImages.length === 0 ? (
  507. <div className="flex h-96 items-center justify-center rounded-lg border-2 border-dashed border-border">
  508. <p className="text-muted-foreground">
  509. {loading
  510. ? "Generating..."
  511. : "Generated images will appear here"}
  512. </p>
  513. </div>
  514. ) : (
  515. <div className="grid gap-4">
  516. {generatedImages.map((image, index) => (
  517. <div key={index} className="relative group">
  518. <img
  519. src={image}
  520. alt={`Generated ${index + 1}`}
  521. className="w-full rounded-lg border border-border"
  522. />
  523. <Button
  524. size="icon"
  525. variant="secondary"
  526. className="absolute top-2 right-2 opacity-0 group-hover:opacity-100 transition-opacity"
  527. onClick={() => {
  528. const authToken =
  529. localStorage.getItem("auth_token");
  530. const unixUser = localStorage.getItem("unix_user");
  531. downloadAuthenticatedImage(
  532. image,
  533. `inpainting-${Date.now()}-${index}.png`,
  534. authToken || undefined,
  535. unixUser || undefined,
  536. ).catch((err) => {
  537. console.error("Failed to download image:", err);
  538. });
  539. }}
  540. >
  541. <Download className="h-4 w-4" />
  542. </Button>
  543. </div>
  544. ))}
  545. </div>
  546. )}
  547. </div>
  548. </CardContent>
  549. </Card>
  550. </div>
  551. </div>
  552. </div>
  553. </AppLayout>
  554. );
  555. }
  556. export default function InpaintingPage() {
  557. return <InpaintingForm />;
  558. }