page.tsx 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567
  1. 'use client';
  2. import { useState, useEffect } from 'react';
  3. import { Header } from '@/components/layout';
  4. import { AppLayout } from '@/components/layout';
  5. import { Button } from '@/components/ui/button';
  6. import { Input } from '@/components/ui/input';
  7. import { Textarea } from '@/components/ui/textarea';
  8. import { PromptTextarea } from '@/components/forms';
  9. import { Label } from '@/components/ui/label';
  10. import { Card, CardContent, CardHeader, CardTitle, CardDescription } from '@/components/ui/card';
  11. import { InpaintingCanvas } from '@/components/features/image-generation';
  12. import { apiClient, type JobInfo } from '@/lib/api';
  13. import { Loader2, X, Download } from 'lucide-react';
  14. import { downloadAuthenticatedImage } from '@/lib/utils';
  15. import { useLocalStorage } from '@/lib/storage';
  16. import { ModelSelectionProvider, useModelSelection, useCheckpointSelection, useModelTypeSelection } from '@/contexts/model-selection-context';
  17. import { EnhancedModelSelect } from '@/components/features/models';
  18. // import { AutoSelectionStatus } from '@/components/features/models';
  19. type InpaintingFormData = {
  20. prompt: string;
  21. negative_prompt: string;
  22. steps: number;
  23. cfg_scale: number;
  24. seed: string;
  25. sampling_method: string;
  26. strength: number;
  27. width?: number;
  28. height?: number;
  29. };
  30. const defaultFormData: InpaintingFormData = {
  31. prompt: '',
  32. negative_prompt: '',
  33. steps: 20,
  34. cfg_scale: 7.5,
  35. seed: '',
  36. sampling_method: 'euler_a',
  37. strength: 0.75,
  38. width: 512,
  39. height: 512,
  40. };
  41. function InpaintingForm() {
  42. const { state, actions } = useModelSelection();
  43. const {
  44. checkpointModels,
  45. selectedCheckpointModel,
  46. selectedCheckpoint,
  47. setSelectedCheckpoint,
  48. isAutoSelecting,
  49. warnings,
  50. error: checkpointError
  51. } = useCheckpointSelection();
  52. const {
  53. availableModels: vaeModels,
  54. selectedModel: selectedVae,
  55. isUserOverride: isVaeUserOverride,
  56. isAutoSelected: isVaeAutoSelected,
  57. setSelectedModel: setSelectedVae,
  58. setUserOverride: setVaeUserOverride,
  59. clearUserOverride: clearVaeUserOverride,
  60. } = useModelTypeSelection('vae');
  61. const [formData, setFormData] = useLocalStorage<InpaintingFormData>(
  62. 'inpainting-form-data',
  63. defaultFormData,
  64. { excludeLargeData: true, maxSize: 512 * 1024 }
  65. );
  66. // Separate state for image data (not stored in localStorage)
  67. const [sourceImage, setSourceImage] = useState<string>('');
  68. const [maskImage, setMaskImage] = useState<string>('');
  69. const [loading, setLoading] = useState(false);
  70. const [jobInfo, setJobInfo] = useState<JobInfo | null>(null);
  71. const [generatedImages, setGeneratedImages] = useState<string[]>([]);
  72. const [loraModels, setLoraModels] = useState<string[]>([]);
  73. const [embeddings, setEmbeddings] = useState<string[]>([]);
  74. const [error, setError] = useState<string | null>(null);
  75. const [pollCleanup, setPollCleanup] = useState<(() => void) | null>(null);
  76. // Cleanup polling on unmount
  77. useEffect(() => {
  78. return () => {
  79. if (pollCleanup) {
  80. pollCleanup();
  81. }
  82. };
  83. }, [pollCleanup]);
  84. useEffect(() => {
  85. const loadModels = async () => {
  86. try {
  87. const [modelsData, loras, embeds] = await Promise.all([
  88. apiClient.getModels(), // Get all models with enhanced info
  89. apiClient.getModels('lora'),
  90. apiClient.getModels('embedding'),
  91. ]);
  92. actions.setModels(modelsData.models);
  93. setLoraModels(loras.models.map(m => m.name));
  94. setEmbeddings(embeds.models.map(m => m.name));
  95. } catch (err) {
  96. console.error('Failed to load models:', err);
  97. }
  98. };
  99. loadModels();
  100. }, [actions]);
  101. // Update form data when checkpoint changes
  102. useEffect(() => {
  103. if (selectedCheckpoint) {
  104. setFormData(prev => ({
  105. ...prev,
  106. model: selectedCheckpoint,
  107. }));
  108. }
  109. }, [selectedCheckpoint, setFormData]);
  110. const handleInputChange = (
  111. e: React.ChangeEvent<HTMLInputElement | HTMLTextAreaElement | HTMLSelectElement>
  112. ) => {
  113. const { name, value } = e.target;
  114. setFormData((prev) => ({
  115. ...prev,
  116. [name]: name === 'prompt' || name === 'negative_prompt' || name === 'seed' || name === 'sampling_method'
  117. ? value
  118. : Number(value),
  119. }));
  120. };
  121. const handleSourceImageChange = (image: string) => {
  122. setSourceImage(image);
  123. setError(null);
  124. };
  125. const handleMaskImageChange = (image: string) => {
  126. setMaskImage(image);
  127. setError(null);
  128. };
  129. const pollJobStatus = async (jobId: string) => {
  130. const maxAttempts = 300; // 5 minutes with 2 second interval
  131. let attempts = 0;
  132. let isPolling = true;
  133. const poll = async () => {
  134. if (!isPolling) return;
  135. try {
  136. const status = await apiClient.getJobStatus(jobId);
  137. setJobInfo(status);
  138. if (status.status === 'completed') {
  139. let imageUrls: string[] = [];
  140. // Handle both old format (result.images) and new format (outputs)
  141. if (status.outputs && status.outputs.length > 0) {
  142. // New format: convert output URLs to authenticated image URLs with cache-busting
  143. imageUrls = status.outputs.map((output: any) => {
  144. const filename = output.filename;
  145. return apiClient.getImageUrl(jobId, filename);
  146. });
  147. } else if (status.result?.images && status.result.images.length > 0) {
  148. // Old format: convert image URLs to authenticated URLs
  149. imageUrls = status.result.images.map((imageUrl: string) => {
  150. // Extract filename from URL if it's already a full URL
  151. if (imageUrl.includes('/output/')) {
  152. const parts = imageUrl.split('/output/');
  153. if (parts.length === 2) {
  154. const filename = parts[1].split('?')[0]; // Remove query params
  155. return apiClient.getImageUrl(jobId, filename);
  156. }
  157. }
  158. // If it's just a filename, convert it directly
  159. return apiClient.getImageUrl(jobId, imageUrl);
  160. });
  161. }
  162. // Create a new array to trigger React re-render
  163. setGeneratedImages([...imageUrls]);
  164. setLoading(false);
  165. isPolling = false;
  166. } else if (status.status === 'failed') {
  167. setError(status.error || 'Generation failed');
  168. setLoading(false);
  169. isPolling = false;
  170. } else if (status.status === 'cancelled') {
  171. setError('Generation was cancelled');
  172. setLoading(false);
  173. isPolling = false;
  174. } else if (attempts < maxAttempts) {
  175. attempts++;
  176. setTimeout(poll, 2000);
  177. } else {
  178. setError('Job polling timeout');
  179. setLoading(false);
  180. isPolling = false;
  181. }
  182. } catch (err) {
  183. if (isPolling) {
  184. setError(err instanceof Error ? err.message : 'Failed to check job status');
  185. setLoading(false);
  186. isPolling = false;
  187. }
  188. }
  189. };
  190. poll();
  191. // Return cleanup function
  192. return () => {
  193. isPolling = false;
  194. };
  195. };
  196. const handleGenerate = async (e: React.FormEvent) => {
  197. e.preventDefault();
  198. if (!sourceImage) {
  199. setError('Please upload a source image first');
  200. return;
  201. }
  202. if (!maskImage) {
  203. setError('Please create a mask first');
  204. return;
  205. }
  206. setLoading(true);
  207. setError(null);
  208. setGeneratedImages([]);
  209. setJobInfo(null);
  210. try {
  211. // Validate model selection
  212. if (selectedCheckpointModel) {
  213. const validation = actions.validateSelection(selectedCheckpointModel);
  214. if (!validation.isValid) {
  215. setError(`Missing required models: ${validation.missingRequired.join(', ')}`);
  216. setLoading(false);
  217. return;
  218. }
  219. }
  220. const requestData = {
  221. ...formData,
  222. source_image: sourceImage,
  223. mask_image: maskImage,
  224. model: selectedCheckpoint || undefined,
  225. vae: selectedVae || undefined,
  226. };
  227. const job = await apiClient.inpainting(requestData);
  228. setJobInfo(job);
  229. const jobId = job.request_id || job.id;
  230. if (jobId) {
  231. const cleanup = pollJobStatus(jobId);
  232. setPollCleanup(() => cleanup);
  233. } else {
  234. setError('No job ID returned from server');
  235. setLoading(false);
  236. }
  237. } catch (err) {
  238. setError(err instanceof Error ? err.message : 'Failed to generate image');
  239. setLoading(false);
  240. }
  241. };
  242. const handleCancel = async () => {
  243. const jobId = jobInfo?.request_id || jobInfo?.id;
  244. if (jobId) {
  245. try {
  246. await apiClient.cancelJob(jobId);
  247. setLoading(false);
  248. setError('Generation cancelled');
  249. // Cleanup polling
  250. if (pollCleanup) {
  251. pollCleanup();
  252. setPollCleanup(null);
  253. }
  254. } catch (err) {
  255. console.error('Failed to cancel job:', err);
  256. }
  257. }
  258. };
  259. return (
  260. <AppLayout>
  261. <Header title="Inpainting" description="Edit images by masking areas and regenerating with AI" />
  262. <div className="container mx-auto p-6">
  263. <div className="grid gap-6 lg:grid-cols-2">
  264. {/* Left Panel - Canvas and Form */}
  265. <div className="space-y-6">
  266. <InpaintingCanvas
  267. onSourceImageChange={handleSourceImageChange}
  268. onMaskImageChange={handleMaskImageChange}
  269. targetWidth={formData.width}
  270. targetHeight={formData.height}
  271. />
  272. <Card>
  273. <CardContent className="pt-6">
  274. <form onSubmit={handleGenerate} className="space-y-4">
  275. <div className="space-y-2">
  276. <Label htmlFor="prompt">Prompt *</Label>
  277. <PromptTextarea
  278. value={formData.prompt}
  279. onChange={(value) => setFormData({ ...formData, prompt: value })}
  280. placeholder="Describe what to generate in the masked areas..."
  281. rows={3}
  282. loras={loraModels}
  283. embeddings={embeddings}
  284. />
  285. <p className="text-xs text-muted-foreground">
  286. Tip: Use {'<lora:name:weight>'} for LoRAs and embedding names directly
  287. </p>
  288. </div>
  289. <div className="space-y-2">
  290. <Label htmlFor="negative_prompt">Negative Prompt</Label>
  291. <PromptTextarea
  292. value={formData.negative_prompt || ''}
  293. onChange={(value) => setFormData({ ...formData, negative_prompt: value })}
  294. placeholder="What to avoid in the generated areas..."
  295. rows={2}
  296. loras={loraModels}
  297. embeddings={embeddings}
  298. />
  299. </div>
  300. <div className="space-y-2">
  301. <Label htmlFor="strength">
  302. Strength: {formData.strength.toFixed(2)}
  303. </Label>
  304. <Input
  305. id="strength"
  306. name="strength"
  307. type="range"
  308. value={formData.strength}
  309. onChange={handleInputChange}
  310. min={0}
  311. max={1}
  312. step={0.05}
  313. />
  314. <p className="text-xs text-muted-foreground">
  315. Lower values preserve more of the original image
  316. </p>
  317. </div>
  318. <div className="grid grid-cols-2 gap-4">
  319. <div className="space-y-2">
  320. <Label htmlFor="width">Width</Label>
  321. <Input
  322. id="width"
  323. name="width"
  324. type="number"
  325. value={formData.width}
  326. onChange={handleInputChange}
  327. step={64}
  328. min={256}
  329. max={2048}
  330. />
  331. </div>
  332. <div className="space-y-2">
  333. <Label htmlFor="height">Height</Label>
  334. <Input
  335. id="height"
  336. name="height"
  337. type="number"
  338. value={formData.height}
  339. onChange={handleInputChange}
  340. step={64}
  341. min={256}
  342. max={2048}
  343. />
  344. </div>
  345. </div>
  346. <div className="grid grid-cols-2 gap-4">
  347. <div className="space-y-2">
  348. <Label htmlFor="steps">Steps</Label>
  349. <Input
  350. id="steps"
  351. name="steps"
  352. type="number"
  353. value={formData.steps}
  354. onChange={handleInputChange}
  355. min={1}
  356. max={150}
  357. />
  358. </div>
  359. <div className="space-y-2">
  360. <Label htmlFor="cfg_scale">CFG Scale</Label>
  361. <Input
  362. id="cfg_scale"
  363. name="cfg_scale"
  364. type="number"
  365. value={formData.cfg_scale}
  366. onChange={handleInputChange}
  367. step={0.5}
  368. min={1}
  369. max={30}
  370. />
  371. </div>
  372. </div>
  373. <div className="space-y-2">
  374. <Label htmlFor="seed">Seed (optional)</Label>
  375. <Input
  376. id="seed"
  377. name="seed"
  378. value={formData.seed}
  379. onChange={handleInputChange}
  380. placeholder="Leave empty for random"
  381. />
  382. </div>
  383. <div className="space-y-2">
  384. <Label htmlFor="sampling_method">Sampling Method</Label>
  385. <select
  386. id="sampling_method"
  387. name="sampling_method"
  388. value={formData.sampling_method}
  389. onChange={handleInputChange}
  390. className="flex h-10 w-full rounded-md border border-input bg-background px-3 py-2 text-sm"
  391. >
  392. <option value="euler">Euler</option>
  393. <option value="euler_a">Euler A</option>
  394. <option value="heun">Heun</option>
  395. <option value="dpm2">DPM2</option>
  396. <option value="dpm++2s_a">DPM++ 2S A</option>
  397. <option value="dpm++2m">DPM++ 2M</option>
  398. <option value="dpm++2mv2">DPM++ 2M V2</option>
  399. <option value="lcm">LCM</option>
  400. </select>
  401. </div>
  402. {/* Model Selection Section */}
  403. <Card>
  404. <CardHeader>
  405. <CardTitle>Model Selection</CardTitle>
  406. <CardDescription>Select checkpoint and additional models for generation</CardDescription>
  407. </CardHeader>
  408. {/* Checkpoint Selection */}
  409. <div className="space-y-2">
  410. <Label htmlFor="checkpoint">Checkpoint Model *</Label>
  411. <select
  412. id="checkpoint"
  413. value={selectedCheckpoint || ''}
  414. onChange={(e) => setSelectedCheckpoint(e.target.value || null)}
  415. className="flex h-10 w-full rounded-md border border-input bg-background px-3 py-2 text-sm"
  416. disabled={isAutoSelecting}
  417. >
  418. <option value="">Select a checkpoint model...</option>
  419. {checkpointModels.map((model) => (
  420. <option key={model.id} value={model.name}>
  421. {model.name} {model.loaded ? '(Loaded)' : ''}
  422. </option>
  423. ))}
  424. </select>
  425. </div>
  426. {/* VAE Selection */}
  427. <EnhancedModelSelect
  428. modelType="vae"
  429. label="VAE Model"
  430. description="Optional VAE model for improved image quality"
  431. value={selectedVae}
  432. availableModels={vaeModels}
  433. isAutoSelected={isVaeAutoSelected}
  434. isUserOverride={isVaeUserOverride}
  435. isLoading={isAutoSelecting}
  436. onValueChange={setSelectedVae}
  437. onSetUserOverride={setVaeUserOverride}
  438. onClearOverride={clearVaeUserOverride}
  439. placeholder="Use default VAE"
  440. />
  441. </Card>
  442. <div className="flex gap-2">
  443. <Button
  444. type="submit"
  445. disabled={loading || !sourceImage || !maskImage}
  446. className="flex-1"
  447. >
  448. {loading ? (
  449. <>
  450. <Loader2 className="h-4 w-4 animate-spin" />
  451. Generating...
  452. </>
  453. ) : (
  454. 'Generate'
  455. )}
  456. </Button>
  457. {loading && (
  458. <Button type="button" variant="destructive" onClick={handleCancel}>
  459. <X className="h-4 w-4" />
  460. Cancel
  461. </Button>
  462. )}
  463. </div>
  464. {error && (
  465. <div className="rounded-md bg-destructive/10 p-3 text-sm text-destructive">
  466. {error}
  467. </div>
  468. )}
  469. </form>
  470. </CardContent>
  471. </Card>
  472. </div>
  473. {/* Right Panel - Generated Images */}
  474. <Card>
  475. <CardContent className="pt-6">
  476. <div className="space-y-4">
  477. <h3 className="text-lg font-semibold">Generated Images</h3>
  478. {generatedImages.length === 0 ? (
  479. <div className="flex h-96 items-center justify-center rounded-lg border-2 border-dashed border-border">
  480. <p className="text-muted-foreground">
  481. {loading ? 'Generating...' : 'Generated images will appear here'}
  482. </p>
  483. </div>
  484. ) : (
  485. <div className="grid gap-4">
  486. {generatedImages.map((image, index) => (
  487. <div key={index} className="relative group">
  488. <img
  489. src={image}
  490. alt={`Generated ${index + 1}`}
  491. className="w-full rounded-lg border border-border"
  492. />
  493. <Button
  494. size="icon"
  495. variant="secondary"
  496. className="absolute top-2 right-2 opacity-0 group-hover:opacity-100 transition-opacity"
  497. onClick={() => {
  498. const authToken = localStorage.getItem('auth_token');
  499. const unixUser = localStorage.getItem('unix_user');
  500. downloadAuthenticatedImage(image, `inpainting-${Date.now()}-${index}.png`, authToken || undefined, unixUser || undefined)
  501. .catch(err => {
  502. console.error('Failed to download image:', err);
  503. });
  504. }}
  505. >
  506. <Download className="h-4 w-4" />
  507. </Button>
  508. </div>
  509. ))}
  510. </div>
  511. )}
  512. </div>
  513. </CardContent>
  514. </Card>
  515. </div>
  516. </div>
  517. </AppLayout>
  518. );
  519. }
  520. export default function InpaintingPage() {
  521. return <InpaintingForm />;
  522. }