page.tsx 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556
  'use client';
  import { useState, useEffect, useRef } from 'react';
  import { Header } from '@/components/header';
  import { AppLayout } from '@/components/layout';
  import { Button } from '@/components/ui/button';
  import { Input } from '@/components/ui/input';
  import { Textarea } from '@/components/ui/textarea';
  import { PromptTextarea } from '@/components/prompt-textarea';
  import { Label } from '@/components/ui/label';
  import { Card, CardContent } from '@/components/ui/card';
  import { InpaintingCanvas } from '@/components/inpainting-canvas';
  import { apiClient, type JobInfo } from '@/lib/api';
  import { Loader2, X, Download } from 'lucide-react';
  import { downloadAuthenticatedImage } from '@/lib/utils';
  import { useLocalStorage } from '@/lib/hooks';
  import { ModelSelectionProvider, useModelSelection, useCheckpointSelection, useModelTypeSelection } from '@/contexts/model-selection-context';
  import { EnhancedModelSelect, EnhancedModelSelectGroup } from '@/components/enhanced-model-select';
  import { ModelSelectionWarning, AutoSelectionStatus } from '@/components/model-selection-indicator';
  16. import { ModelSelectionProvider, useModelSelection, useCheckpointSelection, useModelTypeSelection } from '@/contexts/model-selection-context';
  17. import { EnhancedModelSelect, EnhancedModelSelectGroup } from '@/components/enhanced-model-select';
  18. import { ModelSelectionWarning, AutoSelectionStatus } from '@/components/model-selection-indicator';
  /**
   * Shape of the inpainting form state that InpaintingForm keeps via
   * `useLocalStorage` and spreads into the inpainting request.
   */
  type InpaintingFormData = {
    /** What to generate inside the masked areas. */
    prompt: string;
    /** What to avoid in the generated areas. */
    negative_prompt: string;
    /** Source image string from the canvas (presumably a data URL — TODO confirm). */
    source_image: string;
    /** Mask image string from the canvas (presumably a data URL — TODO confirm). */
    mask_image: string;
    /** Number of sampling steps (the form input allows 1–150). */
    steps: number;
    /** Classifier-free guidance scale (the form input allows 1–30). */
    cfg_scale: number;
    /** Seed kept as text; an empty string means "random" per the form placeholder. */
    seed: string;
    /** Sampler identifier such as 'euler' or 'euler_a'. */
    sampling_method: string;
    /** 0..1 slider; lower values keep more of the original image. */
    strength: number;
    /** Target width handed to the canvas and the request. */
    width?: number;
    /** Target height handed to the canvas and the request. */
    height?: number;
  };
  /**
   * Initial inpainting form values; passed to `useLocalStorage` as the
   * fallback value for the 'inpainting-form-data' key.
   */
  const defaultFormData: InpaintingFormData = {
    // Prompts start empty.
    prompt: '',
    negative_prompt: '',
    // Images are filled in later through the canvas callbacks.
    source_image: '',
    mask_image: '',
    // Sampler settings.
    steps: 20,
    cfg_scale: 7.5,
    seed: '', // empty => random
    sampling_method: 'euler_a',
    strength: 0.75,
    // Target size forwarded to the canvas.
    width: 512,
    height: 512,
  };
  /** Job identifier used for polling, cancelling and display: `request_id`, else `id`. */
  function getJobId(job: JobInfo | null) {
    return job?.request_id || job?.id;
  }

  /**
   * Builds authenticated image URLs for a completed job.
   *
   * Accepts both response formats seen by this page: `outputs` (entries with a
   * `filename`) and the older `result.images` (full URLs or bare filenames).
   * Returns an empty array when neither is present.
   */
  function collectImageUrls(jobId: string, status: JobInfo): string[] {
    if (status.outputs && status.outputs.length > 0) {
      // eslint-disable-next-line @typescript-eslint/no-explicit-any -- element shape of JobInfo.outputs is not pinned down in this file
      return status.outputs.map((output: any) => apiClient.getImageUrl(jobId, output.filename));
    }
    if (status.result?.images && status.result.images.length > 0) {
      return status.result.images.map((imageUrl: string) => {
        // A full URL looks like ".../output/<filename>?query"; keep only <filename>.
        if (imageUrl.includes('/output/')) {
          const parts = imageUrl.split('/output/');
          if (parts.length === 2) {
            return apiClient.getImageUrl(jobId, parts[1].split('?')[0]);
          }
        }
        // Otherwise treat the entry as a bare filename.
        return apiClient.getImageUrl(jobId, imageUrl);
      });
    }
    return [];
  }

  /**
   * Inpainting page body.
   *
   * The user uploads a source image and paints a mask on the canvas, then tunes
   * sampler settings and model selection. Generate submits an inpainting job and
   * polls it every 2 s, for at most 300 attempts, until it completes, fails or is
   * cancelled.
   *
   * Form settings are kept through `useLocalStorage` under 'inpainting-form-data'.
   * The source and mask images live only in component state. The canvas receives
   * no image props, so it cannot be re-hydrated from storage. A persisted image
   * would therefore enable Generate with a picture the user can no longer see.
   */
  function InpaintingForm() {
    const { state, actions } = useModelSelection();
    const {
      checkpointModels,
      selectedCheckpointModel,
      selectedCheckpoint,
      setSelectedCheckpoint,
      isAutoSelecting,
      warnings,
      error: checkpointError,
    } = useCheckpointSelection();
    const {
      availableModels: vaeModels,
      selectedModel: selectedVae,
      isUserOverride: isVaeUserOverride,
      isAutoSelected: isVaeAutoSelected,
      setSelectedModel: setSelectedVae,
      setUserOverride: setVaeUserOverride,
      clearUserOverride: clearVaeUserOverride,
    } = useModelTypeSelection('vae');

    const [storedFormData, setFormData] = useLocalStorage<InpaintingFormData>(
      'inpainting-form-data',
      defaultFormData
    );
    // Fill in fields missing from data saved by an older version of this form,
    // so reads such as `formData.strength.toFixed(2)` never hit `undefined`.
    const formData: InpaintingFormData = { ...defaultFormData, ...storedFormData };

    const [sourceImage, setSourceImage] = useState('');
    const [maskImage, setMaskImage] = useState('');
    const [loading, setLoading] = useState(false);
    const [jobInfo, setJobInfo] = useState<JobInfo | null>(null);
    const [generatedImages, setGeneratedImages] = useState<string[]>([]);
    const [loraModels, setLoraModels] = useState<string[]>([]);
    const [embeddings, setEmbeddings] = useState<string[]>([]);
    const [error, setError] = useState<string | null>(null);

    // Job the poll loop is currently following; null means "stop polling".
    const activeJobIdRef = useRef<string | null>(null);
    // Pending poll timer, so it can be cleared on cancel or unmount.
    const pollTimerRef = useRef<ReturnType<typeof setTimeout> | null>(null);
    // False once the component has unmounted.
    const mountedRef = useRef(true);

    /** Stops following the current job and clears any scheduled poll. */
    const stopPolling = () => {
      activeJobIdRef.current = null;
      if (pollTimerRef.current !== null) {
        clearTimeout(pollTimerRef.current);
        pollTimerRef.current = null;
      }
    };

    // Stop any in-flight polling when the page goes away.
    useEffect(() => {
      mountedRef.current = true;
      return () => {
        mountedRef.current = false;
        activeJobIdRef.current = null;
        if (pollTimerRef.current !== null) {
          clearTimeout(pollTimerRef.current);
          pollTimerRef.current = null;
        }
      };
    }, []);

    // Load all models into the selection context, plus LoRA and embedding names
    // for prompt autocompletion.
    // NOTE(review): assumes `actions` is referentially stable; otherwise this refetches whenever it changes.
    useEffect(() => {
      let ignore = false;
      const loadModels = async () => {
        try {
          const [modelsData, loras, embeds] = await Promise.all([
            apiClient.getModels(), // all models with enhanced info
            apiClient.getModels('lora'),
            apiClient.getModels('embedding'),
          ]);
          if (ignore) return;
          actions.setModels(modelsData.models);
          setLoraModels(loras.models.map((m) => m.name));
          setEmbeddings(embeds.models.map((m) => m.name));
        } catch (err) {
          console.error('Failed to load models:', err);
        }
      };
      void loadModels();
      return () => {
        ignore = true;
      };
    }, [actions]);

    /**
     * Generic change handler for the native inputs and the sampler select.
     * Text fields are stored as-is; every other field is converted with `Number`.
     */
    const handleInputChange = (
      e: React.ChangeEvent<HTMLInputElement | HTMLTextAreaElement | HTMLSelectElement>
    ) => {
      const { name, value } = e.target;
      const isTextField =
        name === 'prompt' || name === 'negative_prompt' || name === 'seed' || name === 'sampling_method';
      setFormData((prev) => ({
        ...prev,
        [name]: isTextField ? value : Number(value),
      }));
    };

    const handleSourceImageChange = (image: string) => {
      setSourceImage(image);
      setError(null);
    };

    const handleMaskImageChange = (image: string) => {
      setMaskImage(image);
      setError(null);
    };

    /**
     * Polls `jobId` every 2 s, for at most 300 attempts, until it reaches a final
     * state. The loop exits silently when a different job becomes active, when
     * the job is cancelled, or when the page unmounts, so stale responses never
     * overwrite current state.
     */
    const pollJobStatus = (jobId: string) => {
      const maxAttempts = 300;
      let attempts = 0;
      activeJobIdRef.current = jobId;

      // Stops polling for this job and releases the loading state.
      const finish = () => {
        stopPolling();
        setLoading(false);
      };

      const poll = async () => {
        pollTimerRef.current = null;
        try {
          const status = await apiClient.getJobStatus(jobId);
          // The job may have been cancelled or replaced, or the page unmounted, while we awaited.
          if (activeJobIdRef.current !== jobId) return;
          setJobInfo(status);
          if (status.status === 'completed') {
            setGeneratedImages(collectImageUrls(jobId, status));
            finish();
          } else if (status.status === 'failed') {
            setError(status.error || 'Generation failed');
            finish();
          } else if (status.status === 'cancelled') {
            setError('Generation was cancelled');
            finish();
          } else if (attempts < maxAttempts) {
            attempts++;
            pollTimerRef.current = setTimeout(() => void poll(), 2000);
          } else {
            setError('Job polling timeout');
            finish();
          }
        } catch (err) {
          if (activeJobIdRef.current !== jobId) return;
          setError(err instanceof Error ? err.message : 'Failed to check job status');
          finish();
        }
      };

      void poll();
    };

    /** Validates inputs, submits the inpainting job and starts polling it. */
    const handleGenerate = async (e: React.FormEvent) => {
      e.preventDefault();
      if (!sourceImage) {
        setError('Please upload a source image first');
        return;
      }
      if (!maskImage) {
        setError('Please create a mask first');
        return;
      }
      // Never let a previous job's poll loop write into this run's state.
      stopPolling();
      setLoading(true);
      setError(null);
      setGeneratedImages([]);
      setJobInfo(null);
      try {
        if (selectedCheckpointModel) {
          const validation = actions.validateSelection(selectedCheckpointModel);
          if (!validation.isValid) {
            setError(`Missing required models: ${validation.missingRequired.join(', ')}`);
            setLoading(false);
            return;
          }
        }
        const requestData = {
          ...formData,
          // Images come from component state, never from the persisted form data.
          source_image: sourceImage,
          mask_image: maskImage,
          model: selectedCheckpoint || undefined,
          vae: selectedVae || undefined,
        };
        const job = await apiClient.inpainting(requestData);
        if (!mountedRef.current) return;
        setJobInfo(job);
        const jobId = getJobId(job);
        if (jobId) {
          pollJobStatus(jobId);
        } else {
          setError('No job ID returned from server');
          setLoading(false);
        }
      } catch (err) {
        setError(err instanceof Error ? err.message : 'Failed to generate image');
        setLoading(false);
      }
    };

    /**
     * Asks the server to cancel the job being polled. Polling stops only after
     * the server accepts. If the request fails, the error is logged and polling
     * continues.
     */
    const handleCancel = async () => {
      const jobId = activeJobIdRef.current;
      if (!jobId) return;
      try {
        await apiClient.cancelJob(jobId);
        // The job may already have finished while the cancel request was in flight.
        if (activeJobIdRef.current !== jobId) return;
        stopPolling();
        setLoading(false);
        setError('Generation cancelled');
      } catch (err) {
        console.error('Failed to cancel job:', err);
      }
    };

    return (
      <AppLayout>
        <Header title="Inpainting" description="Edit images by masking areas and regenerating with AI" />
        <div className="container mx-auto p-6">
          <div className="grid gap-6 lg:grid-cols-2">
            {/* Left Panel - Canvas and Form */}
            <div className="space-y-6">
              <InpaintingCanvas
                onSourceImageChange={handleSourceImageChange}
                onMaskImageChange={handleMaskImageChange}
                targetWidth={formData.width}
                targetHeight={formData.height}
              />
              <Card>
                <CardContent className="pt-6">
                  <form onSubmit={handleGenerate} className="space-y-4">
                    <div className="space-y-2">
                      <Label htmlFor="prompt">Prompt *</Label>
                      <PromptTextarea
                        value={formData.prompt}
                        onChange={(value) => setFormData((prev) => ({ ...prev, prompt: value }))}
                        placeholder="Describe what to generate in the masked areas..."
                        rows={3}
                        loras={loraModels}
                        embeddings={embeddings}
                      />
                      <p className="text-xs text-muted-foreground">
                        Tip: Use {'<lora:name:weight>'} for LoRAs and embedding names directly
                      </p>
                    </div>
                    <div className="space-y-2">
                      <Label htmlFor="negative_prompt">Negative Prompt</Label>
                      <PromptTextarea
                        value={formData.negative_prompt || ''}
                        onChange={(value) => setFormData((prev) => ({ ...prev, negative_prompt: value }))}
                        placeholder="What to avoid in the generated areas..."
                        rows={2}
                        loras={loraModels}
                        embeddings={embeddings}
                      />
                    </div>
                    <div className="space-y-2">
                      <Label htmlFor="strength">
                        Strength: {formData.strength.toFixed(2)}
                      </Label>
                      <Input
                        id="strength"
                        name="strength"
                        type="range"
                        value={formData.strength}
                        onChange={handleInputChange}
                        min={0}
                        max={1}
                        step={0.05}
                      />
                      <p className="text-xs text-muted-foreground">
                        Lower values preserve more of the original image
                      </p>
                    </div>
                    <div className="grid grid-cols-2 gap-4">
                      <div className="space-y-2">
                        <Label htmlFor="width">Width</Label>
                        <Input
                          id="width"
                          name="width"
                          type="number"
                          value={formData.width}
                          onChange={handleInputChange}
                          step={64}
                          min={256}
                          max={2048}
                        />
                      </div>
                      <div className="space-y-2">
                        <Label htmlFor="height">Height</Label>
                        <Input
                          id="height"
                          name="height"
                          type="number"
                          value={formData.height}
                          onChange={handleInputChange}
                          step={64}
                          min={256}
                          max={2048}
                        />
                      </div>
                    </div>
                    <div className="grid grid-cols-2 gap-4">
                      <div className="space-y-2">
                        <Label htmlFor="steps">Steps</Label>
                        <Input
                          id="steps"
                          name="steps"
                          type="number"
                          value={formData.steps}
                          onChange={handleInputChange}
                          min={1}
                          max={150}
                        />
                      </div>
                      <div className="space-y-2">
                        <Label htmlFor="cfg_scale">CFG Scale</Label>
                        <Input
                          id="cfg_scale"
                          name="cfg_scale"
                          type="number"
                          value={formData.cfg_scale}
                          onChange={handleInputChange}
                          step={0.5}
                          min={1}
                          max={30}
                        />
                      </div>
                    </div>
                    <div className="space-y-2">
                      <Label htmlFor="seed">Seed (optional)</Label>
                      <Input
                        id="seed"
                        name="seed"
                        value={formData.seed}
                        onChange={handleInputChange}
                        placeholder="Leave empty for random"
                      />
                    </div>
                    <div className="space-y-2">
                      <Label htmlFor="sampling_method">Sampling Method</Label>
                      <select
                        id="sampling_method"
                        name="sampling_method"
                        value={formData.sampling_method}
                        onChange={handleInputChange}
                        className="flex h-10 w-full rounded-md border border-input bg-background px-3 py-2 text-sm"
                      >
                        <option value="euler">Euler</option>
                        <option value="euler_a">Euler A</option>
                        <option value="heun">Heun</option>
                        <option value="dpm2">DPM2</option>
                        <option value="dpm++2s_a">DPM++ 2S A</option>
                        <option value="dpm++2m">DPM++ 2M</option>
                        <option value="dpm++2mv2">DPM++ 2M V2</option>
                        <option value="lcm">LCM</option>
                      </select>
                    </div>
                    {/* Model Selection Section */}
                    <EnhancedModelSelectGroup
                      title="Model Selection"
                      description="Select the checkpoint and additional models for generation"
                    >
                      {/* Checkpoint Selection */}
                      <div className="space-y-2">
                        <Label htmlFor="checkpoint">Checkpoint Model *</Label>
                        <select
                          id="checkpoint"
                          value={selectedCheckpoint || ''}
                          onChange={(e) => setSelectedCheckpoint(e.target.value || null)}
                          className="flex h-10 w-full rounded-md border border-input bg-background px-3 py-2 text-sm"
                          disabled={isAutoSelecting}
                        >
                          <option value="">Select a checkpoint model...</option>
                          {checkpointModels.map((model) => (
                            <option key={model.id} value={model.name}>
                              {model.name} {model.loaded ? '(Loaded)' : ''}
                            </option>
                          ))}
                        </select>
                      </div>
                      {/* VAE Selection */}
                      <EnhancedModelSelect
                        modelType="vae"
                        label="VAE Model"
                        description="Optional VAE model for improved image quality"
                        value={selectedVae}
                        availableModels={vaeModels}
                        isAutoSelected={isVaeAutoSelected}
                        isUserOverride={isVaeUserOverride}
                        isLoading={isAutoSelecting}
                        onValueChange={setSelectedVae}
                        onSetUserOverride={setVaeUserOverride}
                        onClearOverride={clearVaeUserOverride}
                        placeholder="Use default VAE"
                      />
                      {/* Auto-selection Status */}
                      <div className="pt-2">
                        <AutoSelectionStatus
                          isAutoSelecting={isAutoSelecting}
                          hasAutoSelection={Object.keys(state.autoSelectedModels).length > 0}
                        />
                      </div>
                      {/* Model-selection warnings and errors (generation errors are shown below the buttons) */}
                      <ModelSelectionWarning
                        warnings={warnings}
                        errors={checkpointError ? [checkpointError] : []}
                        onClearWarnings={actions.clearWarnings}
                      />
                    </EnhancedModelSelectGroup>
                    <div className="flex gap-2">
                      <Button
                        type="submit"
                        disabled={loading || !sourceImage || !maskImage}
                        className="flex-1"
                      >
                        {loading ? (
                          <>
                            <Loader2 className="h-4 w-4 animate-spin" />
                            Generating...
                          </>
                        ) : (
                          'Generate'
                        )}
                      </Button>
                      {loading && (
                        <Button type="button" variant="destructive" onClick={handleCancel}>
                          <X className="h-4 w-4" />
                          Cancel
                        </Button>
                      )}
                    </div>
                    {error && (
                      <div className="rounded-md bg-destructive/10 p-3 text-sm text-destructive">
                        {error}
                      </div>
                    )}
                    {loading && jobInfo && (
                      <div className="rounded-md bg-muted p-3 text-sm">
                        <p>Job ID: {getJobId(jobInfo) || 'N/A'}</p>
                        <p>Status: {jobInfo.status}</p>
                        {jobInfo.progress !== undefined && (
                          <p>Progress: {Math.round(jobInfo.progress * 100)}%</p>
                        )}
                      </div>
                    )}
                  </form>
                </CardContent>
              </Card>
            </div>
            {/* Right Panel - Generated Images */}
            <Card>
              <CardContent className="pt-6">
                <div className="space-y-4">
                  <h3 className="text-lg font-semibold">Generated Images</h3>
                  {generatedImages.length === 0 ? (
                    <div className="flex h-96 items-center justify-center rounded-lg border-2 border-dashed border-border">
                      <p className="text-muted-foreground">
                        {loading ? 'Generating...' : 'Generated images will appear here'}
                      </p>
                    </div>
                  ) : (
                    <div className="grid gap-4">
                      {generatedImages.map((image, index) => (
                        <div key={index} className="relative group">
                          <img
                            src={image}
                            alt={`Generated ${index + 1}`}
                            className="w-full rounded-lg border border-border"
                          />
                          <Button
                            size="icon"
                            variant="secondary"
                            className="absolute top-2 right-2 opacity-0 group-hover:opacity-100 transition-opacity"
                            onClick={() => {
                              const authToken = localStorage.getItem('auth_token');
                              const unixUser = localStorage.getItem('unix_user');
                              downloadAuthenticatedImage(image, `inpainting-${Date.now()}-${index}.png`, authToken || undefined, unixUser || undefined)
                                .catch((err) => {
                                  console.error('Failed to download image:', err);
                                });
                            }}
                          >
                            <Download className="h-4 w-4" />
                          </Button>
                        </div>
                      ))}
                    </div>
                  )}
                </div>
              </CardContent>
            </Card>
          </div>
        </div>
      </AppLayout>
    );
  }
  /**
   * Route entry for the inpainting page.
   *
   * InpaintingForm reads model-selection state through useModelSelection,
   * useCheckpointSelection and useModelTypeSelection. It is therefore rendered
   * inside ModelSelectionProvider. AppLayout is rendered *inside* the form, so it
   * cannot be the component supplying that context.
   * NOTE(review): if a root layout already provides this context, this creates a
   * page-local provider. The form loads its own models on mount
   * (actions.setModels), so it stays functional either way.
   */
  export default function InpaintingPage() {
    return (
      <ModelSelectionProvider>
        <InpaintingForm />
      </ModelSelectionProvider>
    );
  }