inpainting-canvas.tsx 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497
  1. 'use client';
  2. import { useRef, useEffect, useState, useCallback } from 'react';
  3. import { Button } from '@/components/ui/button';
  4. import { Card, CardContent } from '@/components/ui/card';
  5. import { Label } from '@/components/ui/label';
  6. import { Input } from '@/components/ui/input';
  7. import { Tabs, TabsContent, TabsList, TabsTrigger } from '@/components/ui/tabs';
  8. import { Upload, Download, Eraser, Brush, RotateCcw, Link as LinkIcon, Loader2 } from 'lucide-react';
  9. import { fileToBase64 } from '@/lib/utils';
  10. import { validateImageUrlWithBase64 } from '@/lib/image-validation';
  11. import { apiClient } from '@/lib/api';
  /** Props accepted by the `InpaintingCanvas` component. */
  interface InpaintingCanvasProps {
    /** Receives the source image string (data URL, or URL) each time one is loaded. */
    onSourceImageChange: (image: string) => void;
    /** Receives the mask as a data URL; white marks areas to inpaint, black areas to preserve. */
    onMaskImageChange: (image: string) => void;
    /** Desired canvas width in pixels; used only when `targetHeight` is also given. */
    targetWidth?: number;
    /** Desired canvas height in pixels; used only when `targetWidth` is also given. */
    targetHeight?: number;
    /** Additional CSS classes for the component's root element. */
    className?: string;
  }
  /** Longest edge, in canvas pixels, used when no explicit target size is given. */
  const DEFAULT_MAX_CANVAS_SIZE = 512;

  /** A decoded source image together with the canvas size it is rendered at. */
  interface LoadedImage {
    /** Decoded image, reused for every redraw so painting never waits on a reload. */
    element: HTMLImageElement;
    /** The string (data URL or URL) the image was decoded from. */
    data: string;
    width: number;
    height: number;
  }

  /**
   * Scale `width` x `height` down (never up) so neither edge exceeds `maxSize`,
   * preserving aspect ratio. Results are rounded to whole pixels (minimum 1) because
   * canvas `width`/`height` are integer attributes. Images that report no intrinsic
   * size (0 on either axis) fall back to a `maxSize` square.
   */
  function fitWithin(
    width: number,
    height: number,
    maxSize: number,
  ): { width: number; height: number } {
    if (width <= 0 || height <= 0) {
      return { width: maxSize, height: maxSize };
    }
    let w = width;
    let h = height;
    if (w > maxSize || h > maxSize) {
      const aspectRatio = w / h;
      if (w > h) {
        w = maxSize;
        h = maxSize / aspectRatio;
      } else {
        h = maxSize;
        w = maxSize * aspectRatio;
      }
    }
    return { width: Math.max(1, Math.round(w)), height: Math.max(1, Math.round(h)) };
  }

  /**
   * Image + mask editor for inpainting.
   *
   * The user loads a source image (file upload or URL) and paints over the regions to
   * regenerate. Two canvases are kept in sync:
   * - the visible canvas shows the image with a pink overlay on painted regions;
   * - an off-DOM mask canvas holds the mask itself (white = inpaint, black = keep),
   *   reported to the parent as a data URL through `onMaskImageChange`.
   *
   * When both `targetWidth` and `targetHeight` are set, the canvas uses that size and the
   * original image is sent to `apiClient.resizeImage`; the result replaces it on the
   * canvas. Otherwise the image is scaled to fit within 512x512.
   */
  export function InpaintingCanvas({
    onSourceImageChange,
    onMaskImageChange,
    className,
    targetWidth,
    targetHeight,
  }: InpaintingCanvasProps) {
    const canvasRef = useRef<HTMLCanvasElement>(null);
    // The mask canvas is created lazily via document.createElement (see getMaskCanvas)
    // and is never mounted: nothing in the JSX may take this ref.
    const maskCanvasRef = useRef<HTMLCanvasElement | null>(null);
    const fileInputRef = useRef<HTMLInputElement>(null);
    // Latest prop callbacks, so async work and effects never call a stale callback and
    // do not need to re-run when the parent passes new inline functions every render.
    const onSourceImageChangeRef = useRef(onSourceImageChange);
    const onMaskImageChangeRef = useRef(onMaskImageChange);
    // Bumped on every load; a decode that finishes after a newer load began is dropped.
    const loadIdRef = useRef(0);
    // A ref rather than state so the first dab of a stroke already sees `true`.
    const isDrawingRef = useRef(false);

    const [loadedImage, setLoadedImage] = useState<LoadedImage | null>(null);
    // Wrapped in an object so loading identical data again still re-triggers the resize.
    const [originalSource, setOriginalSource] = useState<{ data: string } | null>(null);
    const [brushSize, setBrushSize] = useState(20);
    const [isEraser, setIsEraser] = useState(false);
    const [inputMode, setInputMode] = useState<'file' | 'url'>('file');
    const [urlInput, setUrlInput] = useState('');
    const [isLoadingUrl, setIsLoadingUrl] = useState(false);
    const [urlError, setUrlError] = useState<string | null>(null);
    const [isResizing, setIsResizing] = useState(false);

    useEffect(() => {
      onSourceImageChangeRef.current = onSourceImageChange;
      onMaskImageChangeRef.current = onMaskImageChange;
    }, [onSourceImageChange, onMaskImageChange]);

    /** Return the off-DOM mask canvas, creating it on first use. */
    const getMaskCanvas = useCallback((): HTMLCanvasElement => {
      if (!maskCanvasRef.current) {
        maskCanvasRef.current = document.createElement('canvas');
      }
      return maskCanvasRef.current;
    }, []);

    /** Publish the current mask to the parent as a data URL. */
    const updateMaskImage = useCallback(() => {
      const maskCanvas = maskCanvasRef.current;
      if (!maskCanvas) return;
      onMaskImageChangeRef.current(maskCanvas.toDataURL());
    }, []);

    // Whenever a new image has been decoded (and its canvas is mounted), size both
    // canvases, paint the image, reset the mask to all-black and publish it. Doing this
    // after commit avoids drawing into a canvas that is not mounted yet, and sizing and
    // drawing happen together because assigning canvas.width clears the bitmap.
    useEffect(() => {
      const canvas = canvasRef.current;
      if (!loadedImage || !canvas) return;
      const maskCanvas = getMaskCanvas();
      const ctx = canvas.getContext('2d');
      const maskCtx = maskCanvas.getContext('2d');
      if (!ctx || !maskCtx) return;
      const { element, width, height } = loadedImage;
      canvas.width = width;
      canvas.height = height;
      maskCanvas.width = width;
      maskCanvas.height = height;
      ctx.drawImage(element, 0, 0, width, height);
      maskCtx.fillStyle = 'black';
      maskCtx.fillRect(0, 0, width, height);
      updateMaskImage();
    }, [loadedImage, getMaskCanvas, updateMaskImage]);

    /**
     * Decode `base64Image` and show it on the canvas (painting happens in the effect
     * above). The parent is notified only once the image decodes successfully.
     */
    const loadImageToCanvas = useCallback(
      (base64Image: string) => {
        const loadId = ++loadIdRef.current;
        const img = new Image();
        img.onload = () => {
          if (loadId !== loadIdRef.current) return; // superseded by a newer load
          const size =
            targetWidth && targetHeight
              ? { width: Math.round(targetWidth), height: Math.round(targetHeight) }
              : fitWithin(img.naturalWidth, img.naturalHeight, DEFAULT_MAX_CANVAS_SIZE);
          setLoadedImage({ element: img, data: base64Image, ...size });
          onSourceImageChangeRef.current(base64Image);
        };
        img.onerror = () => {
          if (loadId === loadIdRef.current) {
            console.error('Failed to decode image for the inpainting canvas');
          }
        };
        img.src = base64Image;
      },
      [targetWidth, targetHeight],
    );

    /** Load a user-chosen image; it becomes the source for any later resizes. */
    const loadOriginalImage = (data: string) => {
      setOriginalSource({ data });
      loadImageToCanvas(data);
    };

    // Re-fetch a resized copy of the ORIGINAL image whenever it or the target size
    // changes. Resized results are loaded without touching `originalSource`, so this
    // effect does not re-trigger itself; the cleanup discards out-of-date responses.
    useEffect(() => {
      if (!originalSource || !targetWidth || !targetHeight) return;
      const source = originalSource.data;
      const width = targetWidth;
      const height = targetHeight;
      let cancelled = false;

      const resize = async () => {
        setIsResizing(true);
        try {
          const result = await apiClient.resizeImage(source, width, height);
          if (!cancelled) loadImageToCanvas(result.image);
        } catch (err) {
          if (!cancelled) console.error('Failed to resize image:', err);
        } finally {
          if (!cancelled) setIsResizing(false);
        }
      };
      void resize();

      return () => {
        cancelled = true;
        // A re-run sets this back to true; if targets were cleared, it must not stick.
        setIsResizing(false);
      };
    }, [originalSource, targetWidth, targetHeight, loadImageToCanvas]);

    const handleImageUpload = async (e: React.ChangeEvent<HTMLInputElement>) => {
      const input = e.target;
      const file = input.files?.[0];
      if (!file) return;
      try {
        loadOriginalImage(await fileToBase64(file));
      } catch (err) {
        console.error('Failed to load image:', err);
      } finally {
        // Reset so selecting the same file again still fires onChange.
        input.value = '';
      }
    };

    const handleUrlLoad = async () => {
      const url = urlInput.trim();
      if (!url) {
        setUrlError('Please enter a URL');
        return;
      }
      setIsLoadingUrl(true);
      setUrlError(null);
      try {
        const result = await validateImageUrlWithBase64(url);
        if (!result.isValid) {
          setUrlError(result.error || 'Failed to load image from URL');
          return;
        }
        // Use base64 data if available, otherwise use the URL directly.
        loadOriginalImage(result.base64Data || url);
      } catch (err) {
        setUrlError(err instanceof Error ? err.message : 'Failed to load image from URL');
      } finally {
        setIsLoadingUrl(false);
      }
    };

    /** Paint (or erase) one brush dab under the pointer on both canvases. */
    const paintAt = (e: React.MouseEvent<HTMLCanvasElement>) => {
      const canvas = canvasRef.current;
      const maskCanvas = maskCanvasRef.current;
      if (!loadedImage || !canvas || !maskCanvas) return;
      const ctx = canvas.getContext('2d');
      const maskCtx = maskCanvas.getContext('2d');
      if (!ctx || !maskCtx) return;

      // The canvas is CSS-scaled (max-width 512px): map client pixels to canvas pixels.
      const rect = canvas.getBoundingClientRect();
      const x = (e.clientX - rect.left) * (canvas.width / rect.width);
      const y = (e.clientY - rect.top) * (canvas.height / rect.height);

      // Mask (sent to the API): white = inpaint, black = preserve.
      maskCtx.globalCompositeOperation = 'source-over';
      maskCtx.fillStyle = isEraser ? 'black' : 'white';
      maskCtx.beginPath();
      maskCtx.arc(x, y, brushSize, 0, Math.PI * 2);
      maskCtx.fill();

      // Visual feedback on the main canvas.
      ctx.save();
      ctx.beginPath();
      ctx.arc(x, y, brushSize, 0, Math.PI * 2);
      if (isEraser) {
        // Restore the untouched source pixels inside the dab (clear first so
        // transparent regions of the image don't keep the pink overlay).
        ctx.clip();
        ctx.clearRect(0, 0, canvas.width, canvas.height);
        ctx.drawImage(loadedImage.element, 0, 0, canvas.width, canvas.height);
      } else {
        ctx.globalAlpha = 0.6;
        ctx.fillStyle = 'rgba(255, 105, 180, 0.8)'; // Bright pink for visibility
        ctx.fill();
        ctx.globalAlpha = 1.0;
        ctx.strokeStyle = 'rgba(255, 0, 0, 0.9)'; // Red border
        ctx.lineWidth = 2;
        ctx.stroke();
      }
      ctx.restore();
    };

    const startDrawing = (e: React.MouseEvent<HTMLCanvasElement>) => {
      if (!loadedImage) return;
      isDrawingRef.current = true;
      paintAt(e);
    };

    const draw = (e: React.MouseEvent<HTMLCanvasElement>) => {
      if (!isDrawingRef.current) return;
      paintAt(e);
    };

    const stopDrawing = () => {
      if (!isDrawingRef.current) return;
      isDrawingRef.current = false;
      // Publish once per stroke instead of re-encoding the mask on every mousemove.
      updateMaskImage();
    };

    /** Reset the mask to all-black and repaint the plain source image. */
    const clearMask = () => {
      const canvas = canvasRef.current;
      const maskCanvas = maskCanvasRef.current;
      if (!canvas || !maskCanvas) return;
      const ctx = canvas.getContext('2d');
      const maskCtx = maskCanvas.getContext('2d');
      if (!ctx || !maskCtx) return;
      maskCtx.fillStyle = 'black';
      maskCtx.fillRect(0, 0, maskCanvas.width, maskCanvas.height);
      if (loadedImage) {
        ctx.clearRect(0, 0, canvas.width, canvas.height);
        ctx.drawImage(loadedImage.element, 0, 0, canvas.width, canvas.height);
      }
      updateMaskImage();
    };

    const downloadMask = () => {
      const canvas = maskCanvasRef.current;
      if (!canvas) return;
      const link = document.createElement('a');
      link.download = 'inpainting-mask.png';
      link.href = canvas.toDataURL();
      link.click();
    };

    return (
      <div className={`space-y-4 ${className ?? ''}`}>
        <Card>
          <CardContent className="pt-6">
            <div className="space-y-4">
              <div className="space-y-2">
                <Label>Source Image</Label>
                <Tabs value={inputMode} onValueChange={(value) => setInputMode(value as 'file' | 'url')}>
                  <TabsList className="grid w-full grid-cols-2">
                    <TabsTrigger value="file">
                      <Upload className="w-4 h-4 mr-2" />
                      Upload File
                    </TabsTrigger>
                    <TabsTrigger value="url">
                      <LinkIcon className="w-4 h-4 mr-2" />
                      From URL
                    </TabsTrigger>
                  </TabsList>
                  <TabsContent value="file" className="space-y-4 mt-4">
                    <Button
                      type="button"
                      variant="outline"
                      onClick={() => fileInputRef.current?.click()}
                      className="w-full"
                    >
                      <Upload className="h-4 w-4 mr-2" />
                      {loadedImage ? 'Change Image' : 'Upload Image'}
                    </Button>
                    <input
                      ref={fileInputRef}
                      type="file"
                      accept="image/*"
                      onChange={handleImageUpload}
                      className="hidden"
                    />
                  </TabsContent>
                  <TabsContent value="url" className="space-y-4 mt-4">
                    <div className="space-y-2">
                      <Input
                        type="url"
                        value={urlInput}
                        onChange={(e) => {
                          setUrlInput(e.target.value);
                          setUrlError(null);
                        }}
                        placeholder="https://example.com/image.png"
                        disabled={isLoadingUrl}
                      />
                      <p className="text-xs text-muted-foreground">
                        Enter a URL that ends with an image extension (.jpg, .png, .gif, etc.)
                      </p>
                    </div>
                    <Button
                      type="button"
                      variant="outline"
                      onClick={handleUrlLoad}
                      disabled={isLoadingUrl || !urlInput.trim()}
                      className="w-full"
                    >
                      {isLoadingUrl ? (
                        <>
                          <Loader2 className="h-4 w-4 mr-2 animate-spin" />
                          Loading...
                        </>
                      ) : (
                        <>
                          <Download className="h-4 w-4 mr-2" />
                          Load from URL
                        </>
                      )}
                    </Button>
                    {urlError && (
                      <p className="text-sm text-destructive">{urlError}</p>
                    )}
                  </TabsContent>
                </Tabs>
              </div>
              {isResizing && (
                <div className="text-sm text-muted-foreground flex items-center gap-2">
                  <Loader2 className="h-4 w-4 animate-spin" />
                  Resizing image...
                </div>
              )}
              {loadedImage && (
                <>
                  <div className="space-y-2">
                    <Label>Mask Editor</Label>
                    <div className="relative">
                      <canvas
                        ref={canvasRef}
                        className="rounded-lg border border-border cursor-crosshair"
                        style={{ maxWidth: '512px', height: 'auto' }}
                        onMouseDown={startDrawing}
                        onMouseUp={stopDrawing}
                        onMouseMove={draw}
                        onMouseLeave={stopDrawing}
                      />
                    </div>
                    <p className="text-xs text-muted-foreground">
                      Draw on the image to mark areas for inpainting. White areas will be inpainted, black areas will be preserved.
                    </p>
                  </div>
                  <div className="space-y-2">
                    <Label htmlFor="brush_size">
                      Brush Size: {brushSize}px
                    </Label>
                    <input
                      id="brush_size"
                      type="range"
                      value={brushSize}
                      onChange={(e) => setBrushSize(Number(e.target.value))}
                      min={1}
                      max={100}
                      className="w-full"
                    />
                  </div>
                  <div className="flex gap-2">
                    <Button
                      type="button"
                      variant={isEraser ? "default" : "outline"}
                      onClick={() => setIsEraser(true)}
                      className="flex-1"
                    >
                      <Eraser className="h-4 w-4 mr-2" />
                      Eraser
                    </Button>
                    <Button
                      type="button"
                      variant={!isEraser ? "default" : "outline"}
                      onClick={() => setIsEraser(false)}
                      className="flex-1"
                    >
                      <Brush className="h-4 w-4 mr-2" />
                      Brush
                    </Button>
                  </div>
                  <div className="flex gap-2">
                    <Button
                      type="button"
                      variant="outline"
                      onClick={clearMask}
                      className="flex-1"
                    >
                      <RotateCcw className="h-4 w-4 mr-2" />
                      Clear Mask
                    </Button>
                    <Button
                      type="button"
                      variant="outline"
                      onClick={downloadMask}
                      className="flex-1"
                    >
                      <Download className="h-4 w-4 mr-2" />
                      Download Mask
                    </Button>
                  </div>
                </>
              )}
            </div>
          </CardContent>
        </Card>
      </div>
    );
  }