// page.tsx
  1. 'use client';
  2. import { useState, useRef, useEffect } from 'react';
  3. import { Header } from '@/components/header';
  4. import { AppLayout } from '@/components/layout';
  5. import { Button } from '@/components/ui/button';
  6. import { Input } from '@/components/ui/input';
  7. import { Textarea } from '@/components/ui/textarea';
  8. import { PromptTextarea } from '@/components/prompt-textarea';
  9. import { Label } from '@/components/ui/label';
  10. import { Card, CardContent } from '@/components/ui/card';
  11. import { ImageInput } from '@/components/ui/image-input';
  12. import { apiClient, type JobInfo } from '@/lib/api';
  13. import { Loader2, Download, X } from 'lucide-react';
  14. import { downloadImage, downloadAuthenticatedImage, fileToBase64 } from '@/lib/utils';
  15. import { useLocalStorage } from '@/lib/hooks';
  16. import { ModelSelectionProvider, useModelSelection, useCheckpointSelection, useModelTypeSelection } from '@/contexts/model-selection-context';
  17. import { EnhancedModelSelect, EnhancedModelSelectGroup } from '@/components/enhanced-model-select';
  18. import { ModelSelectionWarning, AutoSelectionStatus } from '@/components/model-selection-indicator';
  /**
   * Shape of the image-to-image form state that is persisted through
   * `useLocalStorage` and spread into the generation request.
   */
  type Img2ImgFormData = {
    /** Text describing the desired transformation. */
    prompt: string;
    /** Text describing what the result should avoid. */
    negative_prompt: string;
    /** Source image as produced by the image input (base64 data or a URL string). */
    image: string;
    /** Slider value between 0 and 1; lower values preserve more of the original image. */
    strength: number;
    /** Number of sampling steps. */
    steps: number;
    /** CFG scale value. */
    cfg_scale: number;
    /** Seed kept as text; an empty string means "random". */
    seed: string;
    /** Identifier of the sampler chosen in the "Sampling Method" select. */
    sampling_method: string;
    /** Target width in pixels (optional). */
    width?: number;
    /** Target height in pixels (optional). */
    height?: number;
  };

  /** Initial form values used when nothing has been stored yet. */
  const defaultFormData: Img2ImgFormData = {
    prompt: '',
    negative_prompt: '',
    image: '',
    strength: 0.75,
    steps: 20,
    cfg_scale: 7.5,
    seed: '',
    sampling_method: 'euler_a',
    width: 512,
    height: 512,
  };
  /** Named inputs whose values stay strings; every other named input is coerced with `Number()`. */
  const TEXT_FIELD_NAMES = new Set(['prompt', 'negative_prompt', 'seed', 'sampling_method']);

  /** Delay between two job-status polls, in milliseconds. */
  const POLL_INTERVAL_MS = 2000;

  /** How many follow-up polls are scheduled before reporting "Job polling timeout". */
  const MAX_POLL_ATTEMPTS = 300;

  /**
   * Maps the image references of a completed job to URLs built by
   * `apiClient.getImageUrl`.
   *
   * Two response shapes are handled: `status.outputs` (entries with a
   * `filename`) and the older `status.result.images` (strings that are either
   * a URL containing `/output/<filename>` or a bare file name).
   *
   * @param jobId - Identifier of the job the images belong to.
   * @param status - Job status returned by `apiClient.getJobStatus`.
   * @returns One URL per image, or an empty array when neither shape has images.
   */
  function resolveJobImageUrls(jobId: string, status: JobInfo): string[] {
    if (status.outputs && status.outputs.length > 0) {
      // NOTE(review): `any` kept from the original; the element type is declared outside this file.
      return status.outputs.map((output: any) => apiClient.getImageUrl(jobId, output.filename));
    }

    if (status.result?.images && status.result.images.length > 0) {
      return status.result.images.map((imageUrl: string) => {
        if (imageUrl.includes('/output/')) {
          const parts = imageUrl.split('/output/');
          if (parts.length === 2) {
            // Strip any query string so only the file name is passed on.
            const filename = parts[1].split('?')[0];
            return apiClient.getImageUrl(jobId, filename);
          }
        }
        // Anything else is treated as a bare file name.
        return apiClient.getImageUrl(jobId, imageUrl);
      });
    }

    return [];
  }

  /**
   * Image-to-image form: lets the user pick a source image, prompts, sampling
   * parameters and models, submits the job through `apiClient.img2img`, polls
   * its status and renders the resulting images with a download button.
   *
   * Form values are persisted under the `img2img-form-data` local-storage key.
   */
  function Img2ImgForm() {
    const { state, actions } = useModelSelection();
    // NOTE(review): `checkpointError` is read here but never rendered — confirm whether
    // it should be passed to <ModelSelectionWarning /> instead of the local `error`.
    const {
      checkpointModels,
      selectedCheckpointModel,
      selectedCheckpoint,
      setSelectedCheckpoint,
      isAutoSelecting,
      warnings,
      error: checkpointError
    } = useCheckpointSelection();
    const {
      availableModels: vaeModels,
      selectedModel: selectedVae,
      isUserOverride: isVaeUserOverride,
      isAutoSelected: isVaeAutoSelected,
      setSelectedModel: setSelectedVae,
      setUserOverride: setVaeUserOverride,
      clearUserOverride: clearVaeUserOverride,
    } = useModelTypeSelection('vae');

    const [formData, setFormData] = useLocalStorage<Img2ImgFormData>(
      'img2img-form-data',
      defaultFormData
    );
    const [loading, setLoading] = useState(false);
    const [jobInfo, setJobInfo] = useState<JobInfo | null>(null);
    const [generatedImages, setGeneratedImages] = useState<string[]>([]);
    const [previewImage, setPreviewImage] = useState<string | null>(null);
    const [loraModels, setLoraModels] = useState<string[]>([]);
    const [embeddings, setEmbeddings] = useState<string[]>([]);
    const [selectedImage, setSelectedImage] = useState<File | string | null>(null);
    const [imageValidation, setImageValidation] = useState<any>(null);
    const [originalImage, setOriginalImage] = useState<string | null>(null);
    const [isResizing, setIsResizing] = useState(false);
    const [error, setError] = useState<string | null>(null);

    // Handle of the pending poll timer, so it can be cleared on cancel/unmount.
    const pollTimerRef = useRef<ReturnType<typeof setTimeout> | null>(null);
    // Id of the job whose poll results may still touch state; null means "none".
    const activeJobIdRef = useRef<string | null>(null);

    /** Clears the pending poll timer and marks that no job is being tracked. */
    const stopPolling = () => {
      if (pollTimerRef.current !== null) {
        clearTimeout(pollTimerRef.current);
        pollTimerRef.current = null;
      }
      activeJobIdRef.current = null;
    };

    // Stop polling when the component unmounts so no timer or state update outlives it.
    useEffect(() => {
      return () => {
        if (pollTimerRef.current !== null) {
          clearTimeout(pollTimerRef.current);
          pollTimerRef.current = null;
        }
        activeJobIdRef.current = null;
      };
    }, []);

    // Load all models, plus the LoRA and embedding name lists used by the prompt fields.
    useEffect(() => {
      const loadModels = async () => {
        try {
          const [modelsData, loras, embeds] = await Promise.all([
            apiClient.getModels(),
            apiClient.getModels('lora'),
            apiClient.getModels('embedding'),
          ]);
          actions.setModels(modelsData.models);
          setLoraModels(loras.models.map(m => m.name));
          setEmbeddings(embeds.models.map(m => m.name));
        } catch (err) {
          console.error('Failed to load models:', err);
        }
      };
      void loadModels();
    }, [actions]);

    // Mirror the selected checkpoint into the persisted form data.
    // NOTE(review): `model` is not declared on Img2ImgFormData; kept as in the original.
    useEffect(() => {
      if (selectedCheckpoint) {
        setFormData(prev => ({
          ...prev,
          model: selectedCheckpoint,
        }));
      }
    }, [selectedCheckpoint, setFormData]);

    /** Stores a named input's value: text fields as-is, everything else as a number. */
    const handleInputChange = (
      e: React.ChangeEvent<HTMLInputElement | HTMLTextAreaElement | HTMLSelectElement>
    ) => {
      const { name, value } = e.target;
      setFormData((prev) => ({
        ...prev,
        [name]: TEXT_FIELD_NAMES.has(name) ? value : Number(value),
      }));
    };

    /**
     * Reacts to a new source image (file, URL string, or null to clear).
     * Files are converted with `fileToBase64`; strings are used unchanged.
     */
    const handleImageChange = async (image: File | string | null) => {
      setSelectedImage(image);
      setError(null);

      if (!image) {
        setFormData(prev => ({ ...prev, image: '' }));
        setPreviewImage(null);
        setImageValidation(null);
        setOriginalImage(null);
        return;
      }

      try {
        const imageBase64 = image instanceof File ? await fileToBase64(image) : image;
        // Keep the untouched source so later resizes always start from it.
        setOriginalImage(imageBase64);
        setFormData(prev => ({ ...prev, image: imageBase64 }));
        setPreviewImage(imageBase64);
      } catch (err) {
        setError('Failed to process image');
        console.error('Image processing error:', err);
      }
    };

    // Resize the source image whenever the target width/height or the source changes.
    // Each run owns a `cancelled` flag that the cleanup sets, so a response that
    // arrives after the inputs changed (or after unmount) is discarded instead of
    // overwriting the newer image.
    useEffect(() => {
      const targetWidth = formData.width;
      const targetHeight = formData.height;
      if (!originalImage || !targetWidth || !targetHeight) {
        // A superseded run never resets the flag itself, so reset it here.
        setIsResizing(false);
        return undefined;
      }

      let cancelled = false;
      const resizeImage = async () => {
        setIsResizing(true);
        try {
          const result = await apiClient.resizeImage(originalImage, targetWidth, targetHeight);
          if (cancelled) {
            return;
          }
          setFormData(prev => ({ ...prev, image: result.image }));
          setPreviewImage(result.image);
        } catch (err) {
          if (cancelled) {
            return;
          }
          console.error('Failed to resize image:', err);
          setError('Failed to resize image');
        } finally {
          if (!cancelled) {
            setIsResizing(false);
          }
        }
      };
      void resizeImage();

      return () => {
        cancelled = true;
      };
    }, [formData.width, formData.height, originalImage]);

    /** Stores the validation result from the image input and mirrors it into `error`. */
    const handleImageValidation = (result: any) => {
      setImageValidation(result);
      if (!result.isValid) {
        setError(result.error || 'Invalid image');
      } else {
        setError(null);
      }
    };

    /**
     * Polls the status of `jobId` every POLL_INTERVAL_MS until it completes,
     * fails, is cancelled, or MAX_POLL_ATTEMPTS follow-up polls have been used.
     *
     * Results are ignored once the job is no longer the active one (user
     * cancelled, a newer job started, or the component unmounted).
     */
    const pollJobStatus = async (jobId: string) => {
      let attempts = 0;
      activeJobIdRef.current = jobId;

      const isStale = () => activeJobIdRef.current !== jobId;
      const finish = () => {
        activeJobIdRef.current = null;
        setLoading(false);
      };

      const poll = async () => {
        pollTimerRef.current = null;
        if (isStale()) {
          return;
        }
        try {
          const status = await apiClient.getJobStatus(jobId);
          if (isStale()) {
            return;
          }
          setJobInfo(status);

          if (status.status === 'completed') {
            // A fresh array instance is handed to React.
            setGeneratedImages([...resolveJobImageUrls(jobId, status)]);
            finish();
          } else if (status.status === 'failed') {
            setError(status.error || 'Generation failed');
            finish();
          } else if (status.status === 'cancelled') {
            setError('Generation was cancelled');
            finish();
          } else if (attempts < MAX_POLL_ATTEMPTS) {
            attempts++;
            pollTimerRef.current = setTimeout(poll, POLL_INTERVAL_MS);
          } else {
            setError('Job polling timeout');
            finish();
          }
        } catch (err) {
          if (isStale()) {
            return;
          }
          setError(err instanceof Error ? err.message : 'Failed to check job status');
          finish();
        }
      };

      // Not awaited: `poll` handles its own errors and reschedules itself.
      void poll();
    };

    /** Validates the form, submits the img2img job and starts polling its status. */
    const handleGenerate = async (e: React.FormEvent) => {
      e.preventDefault();

      if (!formData.image) {
        setError('Please upload or select an image first');
        return;
      }

      if (imageValidation && !imageValidation.isValid) {
        setError('Please fix the image validation errors before generating');
        return;
      }

      setLoading(true);
      setError(null);
      setGeneratedImages([]);
      setJobInfo(null);

      try {
        if (selectedCheckpointModel) {
          const validation = actions.validateSelection(selectedCheckpointModel);
          if (!validation.isValid) {
            setError(`Missing required models: ${validation.missingRequired.join(', ')}`);
            setLoading(false);
            return;
          }
        }

        const requestData = {
          ...formData,
          model: selectedCheckpoint || undefined,
          vae: selectedVae || undefined,
        };
        const job = await apiClient.img2img(requestData);
        setJobInfo(job);

        const jobId = job.request_id || job.id;
        if (jobId) {
          await pollJobStatus(jobId);
        } else {
          setError('No job ID returned from server');
          setLoading(false);
        }
      } catch (err) {
        setError(err instanceof Error ? err.message : 'Failed to generate image');
        setLoading(false);
      }
    };

    /** Asks the server to cancel the current job and, on success, stops polling it. */
    const handleCancel = async () => {
      const jobId = jobInfo?.request_id || jobInfo?.id;
      if (jobId) {
        try {
          await apiClient.cancelJob(jobId);
          // Only stop tracking once the server accepted the cancel request.
          stopPolling();
          setLoading(false);
          setError('Generation cancelled');
        } catch (err) {
          console.error('Failed to cancel job:', err);
        }
      }
    };

    return (
      <AppLayout>
        <Header title="Image to Image" description="Transform images with AI using text prompts" />
        <div className="container mx-auto p-6">
          <div className="grid gap-6 lg:grid-cols-2">
            {/* Left Panel - Form */}
            <Card>
              <CardContent className="pt-6">
                <form onSubmit={handleGenerate} className="space-y-4">
                  <div className="space-y-2">
                    <Label>Source Image *</Label>
                    <ImageInput
                      value={selectedImage}
                      onChange={handleImageChange}
                      onValidation={handleImageValidation}
                      disabled={loading}
                      maxSize={10 * 1024 * 1024} // 10MB
                      accept="image/*"
                      placeholder="Enter image URL or select a file"
                      showPreview={true}
                    />
                  </div>

                  <div className="space-y-2">
                    <Label htmlFor="prompt">Prompt *</Label>
                    <PromptTextarea
                      value={formData.prompt}
                      onChange={(value) => setFormData((prev) => ({ ...prev, prompt: value }))}
                      placeholder="Describe the transformation you want..."
                      rows={3}
                      loras={loraModels}
                      embeddings={embeddings}
                    />
                    <p className="text-xs text-muted-foreground">
                      Tip: Use &lt;lora:name:weight&gt; for LoRAs and embedding names directly
                    </p>
                  </div>

                  <div className="space-y-2">
                    <Label htmlFor="negative_prompt">Negative Prompt</Label>
                    <PromptTextarea
                      value={formData.negative_prompt || ''}
                      onChange={(value) => setFormData((prev) => ({ ...prev, negative_prompt: value }))}
                      placeholder="What to avoid..."
                      rows={2}
                      loras={loraModels}
                      embeddings={embeddings}
                    />
                  </div>

                  <div className="space-y-2">
                    <Label htmlFor="strength">
                      Strength: {formData.strength.toFixed(2)}
                    </Label>
                    <Input
                      id="strength"
                      name="strength"
                      type="range"
                      value={formData.strength}
                      onChange={handleInputChange}
                      min={0}
                      max={1}
                      step={0.05}
                    />
                    <p className="text-xs text-muted-foreground">
                      Lower values preserve more of the original image
                    </p>
                  </div>

                  <div className="grid grid-cols-2 gap-4">
                    <div className="space-y-2">
                      <Label htmlFor="width">Width</Label>
                      <Input
                        id="width"
                        name="width"
                        type="number"
                        value={formData.width}
                        onChange={handleInputChange}
                        step={64}
                        min={256}
                        max={2048}
                        disabled={isResizing}
                      />
                    </div>
                    <div className="space-y-2">
                      <Label htmlFor="height">Height</Label>
                      <Input
                        id="height"
                        name="height"
                        type="number"
                        value={formData.height}
                        onChange={handleInputChange}
                        step={64}
                        min={256}
                        max={2048}
                        disabled={isResizing}
                      />
                    </div>
                  </div>

                  {isResizing && (
                    <div className="text-sm text-muted-foreground flex items-center gap-2">
                      <Loader2 className="h-4 w-4 animate-spin" />
                      Resizing image...
                    </div>
                  )}

                  <div className="grid grid-cols-2 gap-4">
                    <div className="space-y-2">
                      <Label htmlFor="steps">Steps</Label>
                      <Input
                        id="steps"
                        name="steps"
                        type="number"
                        value={formData.steps}
                        onChange={handleInputChange}
                        min={1}
                        max={150}
                      />
                    </div>
                    <div className="space-y-2">
                      <Label htmlFor="cfg_scale">CFG Scale</Label>
                      <Input
                        id="cfg_scale"
                        name="cfg_scale"
                        type="number"
                        value={formData.cfg_scale}
                        onChange={handleInputChange}
                        step={0.5}
                        min={1}
                        max={30}
                      />
                    </div>
                  </div>

                  <div className="space-y-2">
                    <Label htmlFor="seed">Seed (optional)</Label>
                    <Input
                      id="seed"
                      name="seed"
                      value={formData.seed}
                      onChange={handleInputChange}
                      placeholder="Leave empty for random"
                    />
                  </div>

                  <div className="space-y-2">
                    <Label htmlFor="sampling_method">Sampling Method</Label>
                    <select
                      id="sampling_method"
                      name="sampling_method"
                      value={formData.sampling_method}
                      onChange={handleInputChange}
                      className="flex h-10 w-full rounded-md border border-input bg-background px-3 py-2 text-sm"
                    >
                      <option value="euler">Euler</option>
                      <option value="euler_a">Euler A</option>
                      <option value="heun">Heun</option>
                      <option value="dpm2">DPM2</option>
                      <option value="dpm++2s_a">DPM++ 2S A</option>
                      <option value="dpm++2m">DPM++ 2M</option>
                      <option value="dpm++2mv2">DPM++ 2M V2</option>
                      <option value="lcm">LCM</option>
                    </select>
                  </div>

                  {/* Model Selection Section */}
                  <EnhancedModelSelectGroup
                    title="Model Selection"
                    description="Select the checkpoint and additional models for generation"
                  >
                    {/* Checkpoint Selection */}
                    <div className="space-y-2">
                      <Label htmlFor="checkpoint">Checkpoint Model *</Label>
                      <select
                        id="checkpoint"
                        value={selectedCheckpoint || ''}
                        onChange={(e) => setSelectedCheckpoint(e.target.value || null)}
                        className="flex h-10 w-full rounded-md border border-input bg-background px-3 py-2 text-sm"
                        disabled={isAutoSelecting}
                      >
                        <option value="">Select a checkpoint model...</option>
                        {checkpointModels.map((model) => (
                          <option key={model.id} value={model.name}>
                            {model.name} {model.loaded ? '(Loaded)' : ''}
                          </option>
                        ))}
                      </select>
                    </div>

                    {/* VAE Selection */}
                    <EnhancedModelSelect
                      modelType="vae"
                      label="VAE Model"
                      description="Optional VAE model for improved image quality"
                      value={selectedVae}
                      availableModels={vaeModels}
                      isAutoSelected={isVaeAutoSelected}
                      isUserOverride={isVaeUserOverride}
                      isLoading={isAutoSelecting}
                      onValueChange={setSelectedVae}
                      onSetUserOverride={setVaeUserOverride}
                      onClearOverride={clearVaeUserOverride}
                      placeholder="Use default VAE"
                    />

                    {/* Auto-selection Status */}
                    <div className="pt-2">
                      <AutoSelectionStatus
                        isAutoSelecting={isAutoSelecting}
                        hasAutoSelection={Object.keys(state.autoSelectedModels).length > 0}
                      />
                    </div>

                    {/* Warnings and Errors */}
                    <ModelSelectionWarning
                      warnings={warnings}
                      errors={error ? [error] : []}
                      onClearWarnings={actions.clearWarnings}
                    />
                  </EnhancedModelSelectGroup>

                  <div className="flex gap-2">
                    <Button type="submit" disabled={loading || !formData.image || (imageValidation && !imageValidation.isValid)} className="flex-1">
                      {loading ? (
                        <>
                          <Loader2 className="h-4 w-4 animate-spin" />
                          Generating...
                        </>
                      ) : (
                        'Generate'
                      )}
                    </Button>
                    {loading && (
                      <Button type="button" variant="destructive" onClick={handleCancel}>
                        <X className="h-4 w-4" />
                        Cancel
                      </Button>
                    )}
                  </div>

                  {error && (
                    <div className="rounded-md bg-destructive/10 p-3 text-sm text-destructive">
                      {error}
                    </div>
                  )}

                  {loading && jobInfo && (
                    <div className="rounded-md bg-muted p-3 text-sm">
                      <p>Job ID: {jobInfo.id || jobInfo.request_id || 'N/A'}</p>
                      <p>Status: {jobInfo.status}</p>
                      {jobInfo.progress !== undefined && (
                        <p>Progress: {Math.round(jobInfo.progress * 100)}%</p>
                      )}
                    </div>
                  )}
                </form>
              </CardContent>
            </Card>

            {/* Right Panel - Generated Images */}
            <Card>
              <CardContent className="pt-6">
                <div className="space-y-4">
                  <h3 className="text-lg font-semibold">Generated Images</h3>
                  {generatedImages.length === 0 ? (
                    <div className="flex h-96 items-center justify-center rounded-lg border-2 border-dashed border-border">
                      <p className="text-muted-foreground">
                        {loading ? 'Generating...' : 'Generated images will appear here'}
                      </p>
                    </div>
                  ) : (
                    <div className="grid gap-4">
                      {generatedImages.map((image, index) => (
                        <div key={index} className="relative group">
                          <img
                            src={image}
                            alt={`Generated ${index + 1}`}
                            className="w-full rounded-lg border border-border"
                          />
                          <Button
                            size="icon"
                            variant="secondary"
                            className="absolute top-2 right-2 opacity-0 group-hover:opacity-100 transition-opacity"
                            onClick={() => {
                              const authToken = localStorage.getItem('auth_token');
                              const unixUser = localStorage.getItem('unix_user');
                              // Computed once so the fallback download uses the same file name.
                              const downloadName = `img2img-${Date.now()}-${index}.png`;
                              downloadAuthenticatedImage(image, downloadName, authToken || undefined, unixUser || undefined)
                                .catch(err => {
                                  console.error('Failed to download image:', err);
                                  // Fallback to regular download if authenticated download fails
                                  downloadImage(image, downloadName);
                                });
                            }}
                          >
                            <Download className="h-4 w-4" />
                          </Button>
                        </div>
                      ))}
                    </div>
                  )}
                </div>
              </CardContent>
            </Card>
          </div>
        </div>
      </AppLayout>
    );
  }
  /**
   * Page entry point for the image-to-image screen.
   *
   * Renders the form inside `ModelSelectionProvider`, because the form reads
   * model-selection state through `useModelSelection`, `useCheckpointSelection`
   * and `useModelTypeSelection`.
   *
   * NOTE(review): assumes the provider needs no props besides `children` and is
   * not already supplied by a parent layout — confirm against the context module.
   */
  export default function Img2ImgPage() {
    return (
      <ModelSelectionProvider>
        <Img2ImgForm />
      </ModelSelectionProvider>
    );
  }