page.tsx 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639
  1. "use client";
  2. import { useState, useEffect } from "react";
  3. import { Header } from "@/components/layout";
  4. import { AppLayout } from "@/components/layout";
  5. import { Button } from "@/components/ui/button";
  6. import { Input } from "@/components/ui/input";
  7. import { PromptTextarea } from "@/components/forms";
  8. import { Label } from "@/components/ui/label";
  9. import {
  10. Card,
  11. CardContent,
  12. CardHeader,
  13. CardTitle,
  14. CardDescription,
  15. } from "@/components/ui/card";
  16. import { InpaintingCanvas } from "@/components/features/image-generation";
  17. import { apiClient, type JobInfo, type JobDetailsResponse } from "@/lib/api";
  18. import { Loader2, X, Download } from "lucide-react";
  19. import { downloadAuthenticatedImage } from "@/lib/utils";
  20. import { useLocalStorage, useGeneratedImages } from "@/lib/storage";
  21. import {
  22. useModelSelection,
  23. useCheckpointSelection,
  24. useModelTypeSelection,
  25. } from "@/contexts/model-selection-context";
  26. import {
  27. Select,
  28. SelectContent,
  29. SelectItem,
  30. SelectTrigger,
  31. SelectValue,
  32. } from "@/components/ui/select";
  33. // import { AutoSelectionStatus } from '@/components/features/models';
  34. type InpaintingFormData = {
  35. prompt: string;
  36. negative_prompt: string;
  37. steps: number;
  38. cfg_scale: number;
  39. seed: string;
  40. sampling_method: string;
  41. strength: number;
  42. width?: number;
  43. height?: number;
  44. };
  45. const defaultFormData: InpaintingFormData = {
  46. prompt: "",
  47. negative_prompt: "",
  48. steps: 20,
  49. cfg_scale: 7.5,
  50. seed: "",
  51. sampling_method: "euler_a",
  52. strength: 0.75,
  53. width: 512,
  54. height: 512,
  55. };
  56. function InpaintingForm() {
  57. const { actions } = useModelSelection();
  58. const {
  59. checkpointModels,
  60. selectedCheckpointModel,
  61. selectedCheckpoint,
  62. setSelectedCheckpoint,
  63. isAutoSelecting,
  64. } = useCheckpointSelection();
  65. const {
  66. availableModels: vaeModels,
  67. selectedModel: selectedVae,
  68. isAutoSelected: isVaeAutoSelected,
  69. setSelectedModel: setSelectedVae,
  70. setUserOverride: setVaeUserOverride,
  71. clearUserOverride: clearVaeUserOverride,
  72. } = useModelTypeSelection("vae");
  73. const [formData, setFormData] = useLocalStorage<InpaintingFormData>(
  74. "inpainting-form-data",
  75. defaultFormData,
  76. { excludeLargeData: true, maxSize: 512 * 1024 },
  77. );
  78. // Separate state for image data (not stored in localStorage)
  79. const [sourceImage, setSourceImage] = useState<string>("");
  80. const [maskImage, setMaskImage] = useState<string>("");
  81. const [loading, setLoading] = useState(false);
  82. const [jobInfo, setJobInfo] = useState<JobInfo | null>(null);
  83. const { images: storedImages, addImages, getLatestImages } = useGeneratedImages('inpainting');
  84. const [generatedImages, setGeneratedImages] = useState<string[]>(() => getLatestImages());
  85. const [loraModels, setLoraModels] = useState<string[]>([]);
  86. const [embeddings, setEmbeddings] = useState<string[]>([]);
  87. const [error, setError] = useState<string | null>(null);
  88. const [pollCleanup, setPollCleanup] = useState<(() => void) | null>(null);
  89. // Cleanup polling on unmount
  90. useEffect(() => {
  91. return () => {
  92. if (pollCleanup) {
  93. pollCleanup();
  94. }
  95. };
  96. }, [pollCleanup]);
  97. useEffect(() => {
  98. const loadModels = async () => {
  99. try {
  100. const [modelsData, loras, embeds] = await Promise.all([
  101. apiClient.getModels(), // Get all models with enhanced info
  102. apiClient.getModels("lora"),
  103. apiClient.getModels("embedding"),
  104. ]);
  105. actions.setModels(modelsData.models);
  106. setLoraModels(loras.models.map((m) => m.name));
  107. setEmbeddings(embeds.models.map((m) => m.name));
  108. } catch (err) {
  109. console.error("Failed to load models:", err);
  110. }
  111. };
  112. loadModels();
  113. }, [actions]);
  114. // Update form data when checkpoint changes
  115. useEffect(() => {
  116. if (selectedCheckpoint) {
  117. setFormData((prev) => ({
  118. ...prev,
  119. model: selectedCheckpoint,
  120. }));
  121. }
  122. }, [selectedCheckpoint, setFormData]);
  123. const handleInputChange = (
  124. e: React.ChangeEvent<
  125. HTMLInputElement | HTMLTextAreaElement | HTMLSelectElement
  126. >,
  127. ) => {
  128. const { name, value } = e.target;
  129. setFormData((prev) => ({
  130. ...prev,
  131. [name]:
  132. name === "prompt" ||
  133. name === "negative_prompt" ||
  134. name === "seed" ||
  135. name === "sampling_method"
  136. ? value
  137. : Number(value),
  138. }));
  139. };
  140. const handleSourceImageChange = (image: string) => {
  141. setSourceImage(image);
  142. setError(null);
  143. };
  144. const handleMaskImageChange = (image: string) => {
  145. setMaskImage(image);
  146. setError(null);
  147. };
  148. const pollJobStatus = async (jobId: string) => {
  149. const maxAttempts = 300; // 5 minutes with 2 second interval
  150. let attempts = 0;
  151. let isPolling = true;
  152. let timeoutId: NodeJS.Timeout | null = null;
  153. const poll = async () => {
  154. if (!isPolling) return;
  155. try {
  156. const status: JobDetailsResponse = await apiClient.getJobStatus(jobId);
  157. setJobInfo(status.job);
  158. if (status.job.status === "completed") {
  159. let imageUrls: string[] = [];
  160. // Handle both old format (result.images) and new format (outputs)
  161. if (status.job.outputs && status.job.outputs.length > 0) {
  162. // New format: convert output URLs to authenticated image URLs with cache-busting
  163. imageUrls = status.job.outputs.map((output: { filename: string }) => {
  164. const filename = output.filename;
  165. return apiClient.getImageUrl(jobId, filename);
  166. });
  167. } else if (
  168. status.job.result?.images &&
  169. status.job.result.images.length > 0
  170. ) {
  171. // Old format: convert image URLs to authenticated URLs
  172. imageUrls = status.job.result.images.map((imageUrl: string) => {
  173. // Extract filename from URL if it's already a full URL
  174. if (imageUrl.includes("/output/")) {
  175. const parts = imageUrl.split("/output/");
  176. if (parts.length === 2) {
  177. const filename = parts[1].split("?")[0]; // Remove query params
  178. return apiClient.getImageUrl(jobId, filename);
  179. }
  180. }
  181. return imageUrl; // Already a full URL
  182. });
  183. }
  184. // Create a new array to trigger React re-render
  185. setGeneratedImages([...imageUrls]);
  186. addImages(imageUrls, jobId);
  187. setLoading(false);
  188. isPolling = false;
  189. } else if (status.job.status === "failed") {
  190. setError(status.job.error || "Generation failed");
  191. setLoading(false);
  192. isPolling = false;
  193. } else if (status.job.status === "cancelled") {
  194. setError("Generation was cancelled");
  195. setLoading(false);
  196. isPolling = false;
  197. } else if (attempts < maxAttempts) {
  198. attempts++;
  199. timeoutId = setTimeout(poll, 2000);
  200. } else {
  201. setError("Job polling timeout");
  202. setLoading(false);
  203. isPolling = false;
  204. }
  205. } catch (err) {
  206. if (isPolling) {
  207. setError(
  208. err instanceof Error ? err.message : "Failed to check job status",
  209. );
  210. setLoading(false);
  211. isPolling = false;
  212. }
  213. }
  214. };
  215. poll();
  216. // Return cleanup function
  217. return () => {
  218. isPolling = false;
  219. if (timeoutId) {
  220. clearTimeout(timeoutId);
  221. }
  222. };
  223. };
  224. const handleGenerate = async (e: React.FormEvent) => {
  225. e.preventDefault();
  226. if (!sourceImage) {
  227. setError("Please upload a source image first");
  228. return;
  229. }
  230. if (!maskImage) {
  231. setError("Please create a mask first");
  232. return;
  233. }
  234. setLoading(true);
  235. setError(null);
  236. setGeneratedImages([]);
  237. setJobInfo(null);
  238. try {
  239. // Validate model selection
  240. if (selectedCheckpointModel) {
  241. const validation = actions.validateSelection(selectedCheckpointModel);
  242. if (!validation.isValid) {
  243. setError(
  244. `Missing required models: ${validation.missingRequired.join(", ")}`,
  245. );
  246. setLoading(false);
  247. return;
  248. }
  249. }
  250. const requestData = {
  251. ...formData,
  252. source_image: sourceImage,
  253. mask_image: maskImage,
  254. model: selectedCheckpoint || undefined,
  255. vae: selectedVae || undefined,
  256. };
  257. const job = await apiClient.inpainting(requestData);
  258. setJobInfo(job);
  259. const jobId = job.request_id || job.id;
  260. if (jobId) {
  261. const cleanup = pollJobStatus(jobId);
  262. setPollCleanup(() => cleanup);
  263. } else {
  264. setError("No job ID returned from server");
  265. setLoading(false);
  266. }
  267. } catch (err) {
  268. setError(err instanceof Error ? err.message : "Failed to generate image");
  269. setLoading(false);
  270. }
  271. };
  272. const handleCancel = async () => {
  273. const jobId = jobInfo?.request_id || jobInfo?.id;
  274. if (jobId) {
  275. try {
  276. await apiClient.cancelJob(jobId);
  277. setLoading(false);
  278. setError("Generation cancelled");
  279. // Cleanup polling
  280. if (pollCleanup) {
  281. pollCleanup();
  282. setPollCleanup(null);
  283. }
  284. } catch (err) {
  285. console.error("Failed to cancel job:", err);
  286. }
  287. }
  288. };
  289. return (
  290. <AppLayout>
  291. <Header
  292. title="Inpainting"
  293. description="Edit images by masking areas and regenerating with AI"
  294. />
  295. <div className="container mx-auto p-6">
  296. <div className="grid gap-6 lg:grid-cols-2">
  297. {/* Left Panel - Canvas and Form */}
  298. <div className="space-y-6">
  299. <InpaintingCanvas
  300. onSourceImageChange={handleSourceImageChange}
  301. onMaskImageChange={handleMaskImageChange}
  302. targetWidth={formData.width}
  303. targetHeight={formData.height}
  304. />
  305. <Card>
  306. <CardContent className="pt-6">
  307. <form onSubmit={handleGenerate} className="space-y-4">
  308. <div className="space-y-2">
  309. <Label htmlFor="prompt">Prompt *</Label>
  310. <PromptTextarea
  311. value={formData.prompt}
  312. onChange={(value) =>
  313. setFormData({ ...formData, prompt: value })
  314. }
  315. placeholder="Describe what to generate in the masked areas..."
  316. rows={3}
  317. loras={loraModels}
  318. embeddings={embeddings}
  319. />
  320. <p className="text-xs text-muted-foreground">
  321. Tip: Use {"<lora:name:weight>"} for LoRAs and embedding
  322. names directly
  323. </p>
  324. </div>
  325. <div className="space-y-2">
  326. <Label htmlFor="negative_prompt">Negative Prompt</Label>
  327. <PromptTextarea
  328. value={formData.negative_prompt || ""}
  329. onChange={(value) =>
  330. setFormData({ ...formData, negative_prompt: value })
  331. }
  332. placeholder="What to avoid in the generated areas..."
  333. rows={2}
  334. loras={loraModels}
  335. embeddings={embeddings}
  336. />
  337. </div>
  338. <div className="space-y-2">
  339. <Label htmlFor="strength">
  340. Strength: {formData.strength.toFixed(2)}
  341. </Label>
  342. <Input
  343. id="strength"
  344. name="strength"
  345. type="range"
  346. value={formData.strength}
  347. onChange={handleInputChange}
  348. min={0}
  349. max={1}
  350. step={0.05}
  351. />
  352. <p className="text-xs text-muted-foreground">
  353. Lower values preserve more of the original image
  354. </p>
  355. </div>
  356. <div className="grid grid-cols-2 gap-4">
  357. <div className="space-y-2">
  358. <Label htmlFor="width">Width</Label>
  359. <Input
  360. id="width"
  361. name="width"
  362. type="number"
  363. value={formData.width}
  364. onChange={handleInputChange}
  365. step={64}
  366. min={256}
  367. max={2048}
  368. />
  369. </div>
  370. <div className="space-y-2">
  371. <Label htmlFor="height">Height</Label>
  372. <Input
  373. id="height"
  374. name="height"
  375. type="number"
  376. value={formData.height}
  377. onChange={handleInputChange}
  378. step={64}
  379. min={256}
  380. max={2048}
  381. />
  382. </div>
  383. </div>
  384. <div className="grid grid-cols-2 gap-4">
  385. <div className="space-y-2">
  386. <Label htmlFor="steps">Steps</Label>
  387. <Input
  388. id="steps"
  389. name="steps"
  390. type="number"
  391. value={formData.steps}
  392. onChange={handleInputChange}
  393. min={1}
  394. max={150}
  395. />
  396. </div>
  397. <div className="space-y-2">
  398. <Label htmlFor="cfg_scale">CFG Scale</Label>
  399. <Input
  400. id="cfg_scale"
  401. name="cfg_scale"
  402. type="number"
  403. value={formData.cfg_scale}
  404. onChange={handleInputChange}
  405. step={0.5}
  406. min={1}
  407. max={30}
  408. />
  409. </div>
  410. </div>
  411. <div className="space-y-2">
  412. <Label htmlFor="seed">Seed (optional)</Label>
  413. <Input
  414. id="seed"
  415. name="seed"
  416. value={formData.seed}
  417. onChange={handleInputChange}
  418. placeholder="Leave empty for random"
  419. />
  420. </div>
  421. <div className="space-y-2">
  422. <Label htmlFor="sampling_method">Sampling Method</Label>
  423. <select
  424. id="sampling_method"
  425. name="sampling_method"
  426. value={formData.sampling_method}
  427. onChange={handleInputChange}
  428. className="flex h-10 w-full rounded-md border border-input bg-background px-3 py-2 text-sm"
  429. >
  430. <option value="euler">Euler</option>
  431. <option value="euler_a">Euler A</option>
  432. <option value="heun">Heun</option>
  433. <option value="dpm2">DPM2</option>
  434. <option value="dpm++2s_a">DPM++ 2S A</option>
  435. <option value="dpm++2m">DPM++ 2M</option>
  436. <option value="dpm++2mv2">DPM++ 2M V2</option>
  437. <option value="lcm">LCM</option>
  438. </select>
  439. </div>
  440. {/* Model Selection Section */}
  441. <Card>
  442. <CardHeader>
  443. <CardTitle>Model Selection</CardTitle>
  444. <CardDescription>
  445. Select checkpoint and additional models for generation
  446. </CardDescription>
  447. </CardHeader>
  448. {/* Checkpoint Selection */}
  449. <div className="space-y-2">
  450. <Label htmlFor="checkpoint">Checkpoint Model *</Label>
  451. <select
  452. id="checkpoint"
  453. value={selectedCheckpoint || ""}
  454. onChange={(e) =>
  455. setSelectedCheckpoint(e.target.value || null)
  456. }
  457. className="flex h-10 w-full rounded-md border border-input bg-background px-3 py-2 text-sm"
  458. disabled={isAutoSelecting}
  459. >
  460. <option value="">Select a checkpoint model...</option>
  461. {checkpointModels.map((model) => (
  462. <option key={model.id} value={model.name}>
  463. {model.name} {model.loaded ? "(Loaded)" : ""}
  464. </option>
  465. ))}
  466. </select>
  467. </div>
  468. {/* VAE Selection */}
  469. <div className="space-y-2">
  470. <Label>VAE Model (Optional)</Label>
  471. <Select
  472. value={selectedVae || "default"}
  473. onValueChange={(value) => {
  474. if (value && value !== "default") {
  475. setSelectedVae(value);
  476. setVaeUserOverride(value);
  477. } else {
  478. clearVaeUserOverride();
  479. }
  480. }}
  481. disabled={isAutoSelecting}
  482. >
  483. <SelectTrigger>
  484. <SelectValue placeholder="Use default VAE" />
  485. </SelectTrigger>
  486. <SelectContent>
  487. <SelectItem value="default">Use default VAE</SelectItem>
  488. {vaeModels.map((model) => (
  489. <SelectItem key={model.name} value={model.name}>
  490. {model.name}
  491. </SelectItem>
  492. ))}
  493. </SelectContent>
  494. </Select>
  495. {isVaeAutoSelected && (
  496. <p className="text-xs text-muted-foreground">
  497. Auto-selected VAE model
  498. </p>
  499. )}
  500. </div>
  501. </Card>
  502. <div className="flex gap-2">
  503. <Button
  504. type="submit"
  505. disabled={loading || !sourceImage || !maskImage}
  506. className="flex-1"
  507. >
  508. {loading ? (
  509. <>
  510. <Loader2 className="h-4 w-4 animate-spin" />
  511. Generating...
  512. </>
  513. ) : (
  514. "Generate"
  515. )}
  516. </Button>
  517. {loading && (
  518. <Button
  519. type="button"
  520. variant="destructive"
  521. onClick={handleCancel}
  522. >
  523. <X className="h-4 w-4" />
  524. Cancel
  525. </Button>
  526. )}
  527. </div>
  528. {error && (
  529. <div className="rounded-md bg-destructive/10 p-3 text-sm text-destructive">
  530. {error}
  531. </div>
  532. )}
  533. </form>
  534. </CardContent>
  535. </Card>
  536. </div>
  537. {/* Right Panel - Generated Images */}
  538. <Card>
  539. <CardContent className="pt-6">
  540. <div className="space-y-4">
  541. <h3 className="text-lg font-semibold">Generated Images</h3>
  542. {generatedImages.length === 0 ? (
  543. <div className="flex h-96 items-center justify-center rounded-lg border-2 border-dashed border-border">
  544. <p className="text-muted-foreground">
  545. {loading
  546. ? "Generating..."
  547. : "Generated images will appear here"}
  548. </p>
  549. </div>
  550. ) : (
  551. <div className="grid gap-4">
  552. {generatedImages.map((image, index) => (
  553. <div key={index} className="relative group">
  554. <img
  555. src={image}
  556. alt={`Generated ${index + 1}`}
  557. className="w-full rounded-lg border border-border"
  558. />
  559. <Button
  560. size="icon"
  561. variant="secondary"
  562. className="absolute top-2 right-2 opacity-0 group-hover:opacity-100 transition-opacity"
  563. onClick={() => {
  564. const authToken =
  565. localStorage.getItem("auth_token");
  566. const unixUser = localStorage.getItem("unix_user");
  567. downloadAuthenticatedImage(
  568. image,
  569. `inpainting-${Date.now()}-${index}.png`,
  570. authToken || undefined,
  571. unixUser || undefined,
  572. ).catch((err) => {
  573. console.error("Failed to download image:", err);
  574. });
  575. }}
  576. >
  577. <Download className="h-4 w-4" />
  578. </Button>
  579. </div>
  580. ))}
  581. </div>
  582. )}
  583. </div>
  584. </CardContent>
  585. </Card>
  586. </div>
  587. </div>
  588. </AppLayout>
  589. );
  590. }
  591. export default function InpaintingPage() {
  592. return <InpaintingForm />;
  593. }