page.tsx 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637
  1. "use client";
  2. import { useState, useEffect } from "react";
  3. import { Header } from "@/components/layout";
  4. import { AppLayout } from "@/components/layout";
  5. import { Button } from "@/components/ui/button";
  6. import { Input } from "@/components/ui/input";
  7. import { PromptTextarea } from "@/components/forms";
  8. import { Label } from "@/components/ui/label";
  9. import {
  10. Card,
  11. CardContent,
  12. CardHeader,
  13. CardTitle,
  14. CardDescription,
  15. } from "@/components/ui/card";
  16. import { InpaintingCanvas } from "@/components/features/image-generation";
  17. import { apiClient, type JobInfo, type JobDetailsResponse } from "@/lib/api";
  18. import { Loader2, X, Download } from "lucide-react";
  19. import { downloadAuthenticatedImage } from "@/lib/utils";
  20. import { useLocalStorage } from "@/lib/storage";
  21. import {
  22. useModelSelection,
  23. useCheckpointSelection,
  24. useModelTypeSelection,
  25. } from "@/contexts/model-selection-context";
  26. import {
  27. Select,
  28. SelectContent,
  29. SelectItem,
  30. SelectTrigger,
  31. SelectValue,
  32. } from "@/components/ui/select";
  33. // import { AutoSelectionStatus } from '@/components/features/models';
  34. type InpaintingFormData = {
  35. prompt: string;
  36. negative_prompt: string;
  37. steps: number;
  38. cfg_scale: number;
  39. seed: string;
  40. sampling_method: string;
  41. strength: number;
  42. width?: number;
  43. height?: number;
  44. };
  45. const defaultFormData: InpaintingFormData = {
  46. prompt: "",
  47. negative_prompt: "",
  48. steps: 20,
  49. cfg_scale: 7.5,
  50. seed: "",
  51. sampling_method: "euler_a",
  52. strength: 0.75,
  53. width: 512,
  54. height: 512,
  55. };
  56. function InpaintingForm() {
  57. const { actions } = useModelSelection();
  58. const {
  59. checkpointModels,
  60. selectedCheckpointModel,
  61. selectedCheckpoint,
  62. setSelectedCheckpoint,
  63. isAutoSelecting,
  64. } = useCheckpointSelection();
  65. const {
  66. availableModels: vaeModels,
  67. selectedModel: selectedVae,
  68. isAutoSelected: isVaeAutoSelected,
  69. setSelectedModel: setSelectedVae,
  70. setUserOverride: setVaeUserOverride,
  71. clearUserOverride: clearVaeUserOverride,
  72. } = useModelTypeSelection("vae");
  73. const [formData, setFormData] = useLocalStorage<InpaintingFormData>(
  74. "inpainting-form-data",
  75. defaultFormData,
  76. { excludeLargeData: true, maxSize: 512 * 1024 },
  77. );
  78. // Separate state for image data (not stored in localStorage)
  79. const [sourceImage, setSourceImage] = useState<string>("");
  80. const [maskImage, setMaskImage] = useState<string>("");
  81. const [loading, setLoading] = useState(false);
  82. const [jobInfo, setJobInfo] = useState<JobInfo | null>(null);
  83. const [generatedImages, setGeneratedImages] = useState<string[]>([]);
  84. const [loraModels, setLoraModels] = useState<string[]>([]);
  85. const [embeddings, setEmbeddings] = useState<string[]>([]);
  86. const [error, setError] = useState<string | null>(null);
  87. const [pollCleanup, setPollCleanup] = useState<(() => void) | null>(null);
  88. // Cleanup polling on unmount
  89. useEffect(() => {
  90. return () => {
  91. if (pollCleanup) {
  92. pollCleanup();
  93. }
  94. };
  95. }, [pollCleanup]);
  96. useEffect(() => {
  97. const loadModels = async () => {
  98. try {
  99. const [modelsData, loras, embeds] = await Promise.all([
  100. apiClient.getModels(), // Get all models with enhanced info
  101. apiClient.getModels("lora"),
  102. apiClient.getModels("embedding"),
  103. ]);
  104. actions.setModels(modelsData.models);
  105. setLoraModels(loras.models.map((m) => m.name));
  106. setEmbeddings(embeds.models.map((m) => m.name));
  107. } catch (err) {
  108. console.error("Failed to load models:", err);
  109. }
  110. };
  111. loadModels();
  112. }, [actions]);
  113. // Update form data when checkpoint changes
  114. useEffect(() => {
  115. if (selectedCheckpoint) {
  116. setFormData((prev) => ({
  117. ...prev,
  118. model: selectedCheckpoint,
  119. }));
  120. }
  121. }, [selectedCheckpoint, setFormData]);
  122. const handleInputChange = (
  123. e: React.ChangeEvent<
  124. HTMLInputElement | HTMLTextAreaElement | HTMLSelectElement
  125. >,
  126. ) => {
  127. const { name, value } = e.target;
  128. setFormData((prev) => ({
  129. ...prev,
  130. [name]:
  131. name === "prompt" ||
  132. name === "negative_prompt" ||
  133. name === "seed" ||
  134. name === "sampling_method"
  135. ? value
  136. : Number(value),
  137. }));
  138. };
  139. const handleSourceImageChange = (image: string) => {
  140. setSourceImage(image);
  141. setError(null);
  142. };
  143. const handleMaskImageChange = (image: string) => {
  144. setMaskImage(image);
  145. setError(null);
  146. };
  147. const pollJobStatus = async (jobId: string) => {
  148. const maxAttempts = 300; // 5 minutes with 2 second interval
  149. let attempts = 0;
  150. let isPolling = true;
  151. let timeoutId: NodeJS.Timeout | null = null;
  152. const poll = async () => {
  153. if (!isPolling) return;
  154. try {
  155. const status: JobDetailsResponse = await apiClient.getJobStatus(jobId);
  156. setJobInfo(status.job);
  157. if (status.job.status === "completed") {
  158. let imageUrls: string[] = [];
  159. // Handle both old format (result.images) and new format (outputs)
  160. if (status.job.outputs && status.job.outputs.length > 0) {
  161. // New format: convert output URLs to authenticated image URLs with cache-busting
  162. imageUrls = status.job.outputs.map((output: { filename: string }) => {
  163. const filename = output.filename;
  164. return apiClient.getImageUrl(jobId, filename);
  165. });
  166. } else if (
  167. status.job.result?.images &&
  168. status.job.result.images.length > 0
  169. ) {
  170. // Old format: convert image URLs to authenticated URLs
  171. imageUrls = status.job.result.images.map((imageUrl: string) => {
  172. // Extract filename from URL if it's already a full URL
  173. if (imageUrl.includes("/output/")) {
  174. const parts = imageUrl.split("/output/");
  175. if (parts.length === 2) {
  176. const filename = parts[1].split("?")[0]; // Remove query params
  177. return apiClient.getImageUrl(jobId, filename);
  178. }
  179. }
  180. return imageUrl; // Already a full URL
  181. });
  182. }
  183. // Create a new array to trigger React re-render
  184. setGeneratedImages([...imageUrls]);
  185. setLoading(false);
  186. isPolling = false;
  187. } else if (status.job.status === "failed") {
  188. setError(status.job.error || "Generation failed");
  189. setLoading(false);
  190. isPolling = false;
  191. } else if (status.job.status === "cancelled") {
  192. setError("Generation was cancelled");
  193. setLoading(false);
  194. isPolling = false;
  195. } else if (attempts < maxAttempts) {
  196. attempts++;
  197. timeoutId = setTimeout(poll, 2000);
  198. } else {
  199. setError("Job polling timeout");
  200. setLoading(false);
  201. isPolling = false;
  202. }
  203. } catch (err) {
  204. if (isPolling) {
  205. setError(
  206. err instanceof Error ? err.message : "Failed to check job status",
  207. );
  208. setLoading(false);
  209. isPolling = false;
  210. }
  211. }
  212. };
  213. poll();
  214. // Return cleanup function
  215. return () => {
  216. isPolling = false;
  217. if (timeoutId) {
  218. clearTimeout(timeoutId);
  219. }
  220. };
  221. };
  222. const handleGenerate = async (e: React.FormEvent) => {
  223. e.preventDefault();
  224. if (!sourceImage) {
  225. setError("Please upload a source image first");
  226. return;
  227. }
  228. if (!maskImage) {
  229. setError("Please create a mask first");
  230. return;
  231. }
  232. setLoading(true);
  233. setError(null);
  234. setGeneratedImages([]);
  235. setJobInfo(null);
  236. try {
  237. // Validate model selection
  238. if (selectedCheckpointModel) {
  239. const validation = actions.validateSelection(selectedCheckpointModel);
  240. if (!validation.isValid) {
  241. setError(
  242. `Missing required models: ${validation.missingRequired.join(", ")}`,
  243. );
  244. setLoading(false);
  245. return;
  246. }
  247. }
  248. const requestData = {
  249. ...formData,
  250. source_image: sourceImage,
  251. mask_image: maskImage,
  252. model: selectedCheckpoint || undefined,
  253. vae: selectedVae || undefined,
  254. };
  255. const job = await apiClient.inpainting(requestData);
  256. setJobInfo(job);
  257. const jobId = job.request_id || job.id;
  258. if (jobId) {
  259. const cleanup = pollJobStatus(jobId);
  260. setPollCleanup(() => cleanup);
  261. } else {
  262. setError("No job ID returned from server");
  263. setLoading(false);
  264. }
  265. } catch (err) {
  266. setError(err instanceof Error ? err.message : "Failed to generate image");
  267. setLoading(false);
  268. }
  269. };
  270. const handleCancel = async () => {
  271. const jobId = jobInfo?.request_id || jobInfo?.id;
  272. if (jobId) {
  273. try {
  274. await apiClient.cancelJob(jobId);
  275. setLoading(false);
  276. setError("Generation cancelled");
  277. // Cleanup polling
  278. if (pollCleanup) {
  279. pollCleanup();
  280. setPollCleanup(null);
  281. }
  282. } catch (err) {
  283. console.error("Failed to cancel job:", err);
  284. }
  285. }
  286. };
  287. return (
  288. <AppLayout>
  289. <Header
  290. title="Inpainting"
  291. description="Edit images by masking areas and regenerating with AI"
  292. />
  293. <div className="container mx-auto p-6">
  294. <div className="grid gap-6 lg:grid-cols-2">
  295. {/* Left Panel - Canvas and Form */}
  296. <div className="space-y-6">
  297. <InpaintingCanvas
  298. onSourceImageChange={handleSourceImageChange}
  299. onMaskImageChange={handleMaskImageChange}
  300. targetWidth={formData.width}
  301. targetHeight={formData.height}
  302. />
  303. <Card>
  304. <CardContent className="pt-6">
  305. <form onSubmit={handleGenerate} className="space-y-4">
  306. <div className="space-y-2">
  307. <Label htmlFor="prompt">Prompt *</Label>
  308. <PromptTextarea
  309. value={formData.prompt}
  310. onChange={(value) =>
  311. setFormData({ ...formData, prompt: value })
  312. }
  313. placeholder="Describe what to generate in the masked areas..."
  314. rows={3}
  315. loras={loraModels}
  316. embeddings={embeddings}
  317. />
  318. <p className="text-xs text-muted-foreground">
  319. Tip: Use {"<lora:name:weight>"} for LoRAs and embedding
  320. names directly
  321. </p>
  322. </div>
  323. <div className="space-y-2">
  324. <Label htmlFor="negative_prompt">Negative Prompt</Label>
  325. <PromptTextarea
  326. value={formData.negative_prompt || ""}
  327. onChange={(value) =>
  328. setFormData({ ...formData, negative_prompt: value })
  329. }
  330. placeholder="What to avoid in the generated areas..."
  331. rows={2}
  332. loras={loraModels}
  333. embeddings={embeddings}
  334. />
  335. </div>
  336. <div className="space-y-2">
  337. <Label htmlFor="strength">
  338. Strength: {formData.strength.toFixed(2)}
  339. </Label>
  340. <Input
  341. id="strength"
  342. name="strength"
  343. type="range"
  344. value={formData.strength}
  345. onChange={handleInputChange}
  346. min={0}
  347. max={1}
  348. step={0.05}
  349. />
  350. <p className="text-xs text-muted-foreground">
  351. Lower values preserve more of the original image
  352. </p>
  353. </div>
  354. <div className="grid grid-cols-2 gap-4">
  355. <div className="space-y-2">
  356. <Label htmlFor="width">Width</Label>
  357. <Input
  358. id="width"
  359. name="width"
  360. type="number"
  361. value={formData.width}
  362. onChange={handleInputChange}
  363. step={64}
  364. min={256}
  365. max={2048}
  366. />
  367. </div>
  368. <div className="space-y-2">
  369. <Label htmlFor="height">Height</Label>
  370. <Input
  371. id="height"
  372. name="height"
  373. type="number"
  374. value={formData.height}
  375. onChange={handleInputChange}
  376. step={64}
  377. min={256}
  378. max={2048}
  379. />
  380. </div>
  381. </div>
  382. <div className="grid grid-cols-2 gap-4">
  383. <div className="space-y-2">
  384. <Label htmlFor="steps">Steps</Label>
  385. <Input
  386. id="steps"
  387. name="steps"
  388. type="number"
  389. value={formData.steps}
  390. onChange={handleInputChange}
  391. min={1}
  392. max={150}
  393. />
  394. </div>
  395. <div className="space-y-2">
  396. <Label htmlFor="cfg_scale">CFG Scale</Label>
  397. <Input
  398. id="cfg_scale"
  399. name="cfg_scale"
  400. type="number"
  401. value={formData.cfg_scale}
  402. onChange={handleInputChange}
  403. step={0.5}
  404. min={1}
  405. max={30}
  406. />
  407. </div>
  408. </div>
  409. <div className="space-y-2">
  410. <Label htmlFor="seed">Seed (optional)</Label>
  411. <Input
  412. id="seed"
  413. name="seed"
  414. value={formData.seed}
  415. onChange={handleInputChange}
  416. placeholder="Leave empty for random"
  417. />
  418. </div>
  419. <div className="space-y-2">
  420. <Label htmlFor="sampling_method">Sampling Method</Label>
  421. <select
  422. id="sampling_method"
  423. name="sampling_method"
  424. value={formData.sampling_method}
  425. onChange={handleInputChange}
  426. className="flex h-10 w-full rounded-md border border-input bg-background px-3 py-2 text-sm"
  427. >
  428. <option value="euler">Euler</option>
  429. <option value="euler_a">Euler A</option>
  430. <option value="heun">Heun</option>
  431. <option value="dpm2">DPM2</option>
  432. <option value="dpm++2s_a">DPM++ 2S A</option>
  433. <option value="dpm++2m">DPM++ 2M</option>
  434. <option value="dpm++2mv2">DPM++ 2M V2</option>
  435. <option value="lcm">LCM</option>
  436. </select>
  437. </div>
  438. {/* Model Selection Section */}
  439. <Card>
  440. <CardHeader>
  441. <CardTitle>Model Selection</CardTitle>
  442. <CardDescription>
  443. Select checkpoint and additional models for generation
  444. </CardDescription>
  445. </CardHeader>
  446. {/* Checkpoint Selection */}
  447. <div className="space-y-2">
  448. <Label htmlFor="checkpoint">Checkpoint Model *</Label>
  449. <select
  450. id="checkpoint"
  451. value={selectedCheckpoint || ""}
  452. onChange={(e) =>
  453. setSelectedCheckpoint(e.target.value || null)
  454. }
  455. className="flex h-10 w-full rounded-md border border-input bg-background px-3 py-2 text-sm"
  456. disabled={isAutoSelecting}
  457. >
  458. <option value="">Select a checkpoint model...</option>
  459. {checkpointModels.map((model) => (
  460. <option key={model.id} value={model.name}>
  461. {model.name} {model.loaded ? "(Loaded)" : ""}
  462. </option>
  463. ))}
  464. </select>
  465. </div>
  466. {/* VAE Selection */}
  467. <div className="space-y-2">
  468. <Label>VAE Model (Optional)</Label>
  469. <Select
  470. value={selectedVae || ""}
  471. onValueChange={(value) => {
  472. if (value) {
  473. setSelectedVae(value);
  474. setVaeUserOverride(value);
  475. } else {
  476. clearVaeUserOverride();
  477. }
  478. }}
  479. disabled={isAutoSelecting}
  480. >
  481. <SelectTrigger>
  482. <SelectValue placeholder="Use default VAE" />
  483. </SelectTrigger>
  484. <SelectContent>
  485. <SelectItem value="">Use default VAE</SelectItem>
  486. {vaeModels.map((model) => (
  487. <SelectItem key={model.name} value={model.name}>
  488. {model.name}
  489. </SelectItem>
  490. ))}
  491. </SelectContent>
  492. </Select>
  493. {isVaeAutoSelected && (
  494. <p className="text-xs text-muted-foreground">
  495. Auto-selected VAE model
  496. </p>
  497. )}
  498. </div>
  499. </Card>
  500. <div className="flex gap-2">
  501. <Button
  502. type="submit"
  503. disabled={loading || !sourceImage || !maskImage}
  504. className="flex-1"
  505. >
  506. {loading ? (
  507. <>
  508. <Loader2 className="h-4 w-4 animate-spin" />
  509. Generating...
  510. </>
  511. ) : (
  512. "Generate"
  513. )}
  514. </Button>
  515. {loading && (
  516. <Button
  517. type="button"
  518. variant="destructive"
  519. onClick={handleCancel}
  520. >
  521. <X className="h-4 w-4" />
  522. Cancel
  523. </Button>
  524. )}
  525. </div>
  526. {error && (
  527. <div className="rounded-md bg-destructive/10 p-3 text-sm text-destructive">
  528. {error}
  529. </div>
  530. )}
  531. </form>
  532. </CardContent>
  533. </Card>
  534. </div>
  535. {/* Right Panel - Generated Images */}
  536. <Card>
  537. <CardContent className="pt-6">
  538. <div className="space-y-4">
  539. <h3 className="text-lg font-semibold">Generated Images</h3>
  540. {generatedImages.length === 0 ? (
  541. <div className="flex h-96 items-center justify-center rounded-lg border-2 border-dashed border-border">
  542. <p className="text-muted-foreground">
  543. {loading
  544. ? "Generating..."
  545. : "Generated images will appear here"}
  546. </p>
  547. </div>
  548. ) : (
  549. <div className="grid gap-4">
  550. {generatedImages.map((image, index) => (
  551. <div key={index} className="relative group">
  552. <img
  553. src={image}
  554. alt={`Generated ${index + 1}`}
  555. className="w-full rounded-lg border border-border"
  556. />
  557. <Button
  558. size="icon"
  559. variant="secondary"
  560. className="absolute top-2 right-2 opacity-0 group-hover:opacity-100 transition-opacity"
  561. onClick={() => {
  562. const authToken =
  563. localStorage.getItem("auth_token");
  564. const unixUser = localStorage.getItem("unix_user");
  565. downloadAuthenticatedImage(
  566. image,
  567. `inpainting-${Date.now()}-${index}.png`,
  568. authToken || undefined,
  569. unixUser || undefined,
  570. ).catch((err) => {
  571. console.error("Failed to download image:", err);
  572. });
  573. }}
  574. >
  575. <Download className="h-4 w-4" />
  576. </Button>
  577. </div>
  578. ))}
  579. </div>
  580. )}
  581. </div>
  582. </CardContent>
  583. </Card>
  584. </div>
  585. </div>
  586. </AppLayout>
  587. );
  588. }
  589. export default function InpaintingPage() {
  590. return <InpaintingForm />;
  591. }