page.tsx 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495
  1. "use client";
  2. import { useState, useRef, useEffect } from "react";
  3. import { Header } from "@/components/layout";
  4. import { AppLayout } from "@/components/layout";
  5. import { Button } from "@/components/ui/button";
  6. import { Input } from "@/components/ui/input";
  7. import { Label } from "@/components/ui/label";
  8. import {
  9. Card,
  10. CardContent,
  11. CardHeader,
  12. CardTitle,
  13. CardDescription,
  14. } from "@/components/ui/card";
  15. import {
  16. apiClient,
  17. type JobInfo,
  18. type ModelInfo,
  19. type JobDetailsResponse,
  20. } from "@/lib/api";
  21. import { Loader2, Download, X, Upload } from "lucide-react";
  22. import {
  23. downloadImage,
  24. downloadAuthenticatedImage,
  25. fileToBase64,
  26. } from "@/lib/utils";
  27. import { useLocalStorage } from "@/lib/storage";
  28. import {
  29. ModelSelectionProvider,
  30. useModelSelection,
  31. useModelTypeSelection,
  32. } from "@/contexts/model-selection-context";
  33. import {
  34. Select,
  35. SelectContent,
  36. SelectItem,
  37. SelectTrigger,
  38. SelectValue,
  39. } from "@/components/ui/select";
  40. // import { AutoSelectionStatus } from '@/components/features/models';
  41. type UpscalerFormData = {
  42. upscale_factor: number;
  43. model: string;
  44. };
  45. const defaultFormData: UpscalerFormData = {
  46. upscale_factor: 2,
  47. model: "",
  48. };
  49. function UpscalerForm() {
  50. const { state, actions } = useModelSelection();
  51. const {
  52. availableModels: upscalerModels,
  53. selectedModel: selectedUpscalerModel,
  54. isUserOverride: isUpscalerUserOverride,
  55. isAutoSelected: isUpscalerAutoSelected,
  56. setSelectedModel: setSelectedUpscalerModel,
  57. setUserOverride: setUpscalerUserOverride,
  58. clearUserOverride: clearUpscalerUserOverride,
  59. } = useModelTypeSelection("upscaler");
  60. const [formData, setFormData] = useLocalStorage<UpscalerFormData>(
  61. "upscaler-form-data",
  62. defaultFormData,
  63. { excludeLargeData: true, maxSize: 512 * 1024 },
  64. );
  65. // Separate state for image data (not stored in localStorage)
  66. const [uploadedImage, setUploadedImage] = useState<string>("");
  67. const [previewImage, setPreviewImage] = useState<string | null>(null);
  68. const [loading, setLoading] = useState(false);
  69. const [error, setError] = useState<string | null>(null);
  70. const [jobInfo, setJobInfo] = useState<JobInfo | null>(null);
  71. const [generatedImages, setGeneratedImages] = useState<string[]>([]);
  72. const [pollCleanup, setPollCleanup] = useState<(() => void) | null>(null);
  73. const fileInputRef = useRef<HTMLInputElement>(null);
  74. // Cleanup polling on unmount
  75. useEffect(() => {
  76. return () => {
  77. if (pollCleanup) {
  78. pollCleanup();
  79. }
  80. };
  81. }, [pollCleanup]);
  82. useEffect(() => {
  83. const loadModels = async () => {
  84. try {
  85. // Fetch all models with enhanced info
  86. const modelsData = await apiClient.getModels();
  87. // Filter for upscaler models (ESRGAN and upscaler types)
  88. const allUpscalerModels = [
  89. ...modelsData.models.filter((m) => m.type.toLowerCase() === "esrgan"),
  90. ...modelsData.models.filter(
  91. (m) => m.type.toLowerCase() === "upscaler",
  92. ),
  93. ];
  94. actions.setModels(modelsData.models);
  95. // Set first model as default if none selected
  96. if (allUpscalerModels.length > 0 && !formData.model) {
  97. setFormData((prev) => ({
  98. ...prev,
  99. model: allUpscalerModels[0].name,
  100. }));
  101. }
  102. } catch (err) {
  103. console.error("Failed to load upscaler models:", err);
  104. }
  105. };
  106. loadModels();
  107. }, [actions, formData.model, setFormData]);
  108. // Update form data when upscaler model changes
  109. useEffect(() => {
  110. if (selectedUpscalerModel) {
  111. setFormData((prev) => ({
  112. ...prev,
  113. model: selectedUpscalerModel,
  114. }));
  115. }
  116. }, [selectedUpscalerModel, setFormData]);
  117. const handleInputChange = (
  118. e: React.ChangeEvent<HTMLInputElement | HTMLSelectElement>,
  119. ) => {
  120. const { name, value } = e.target;
  121. setFormData((prev) => ({
  122. ...prev,
  123. [name]: name === "upscale_factor" ? Number(value) : value,
  124. }));
  125. };
  126. const handleImageUpload = async (e: React.ChangeEvent<HTMLInputElement>) => {
  127. const file = e.target.files?.[0];
  128. if (!file) return;
  129. try {
  130. const base64 = await fileToBase64(file);
  131. setUploadedImage(base64);
  132. setPreviewImage(base64);
  133. setError(null);
  134. } catch (err) {
  135. setError("Failed to load image");
  136. }
  137. };
  138. const pollJobStatus = async (jobId: string) => {
  139. const maxAttempts = 300;
  140. let attempts = 0;
  141. let isPolling = true;
  142. let timeoutId: NodeJS.Timeout | null = null;
  143. const poll = async () => {
  144. if (!isPolling) return;
  145. try {
  146. const status: JobDetailsResponse = await apiClient.getJobStatus(jobId);
  147. setJobInfo(status.job);
  148. if (status.job.status === "completed") {
  149. let imageUrls: string[] = [];
  150. // Handle both old format (result.images) and new format (outputs)
  151. if (status.job.outputs && status.job.outputs.length > 0) {
  152. // New format: convert output URLs to authenticated image URLs with cache-busting
  153. imageUrls = status.job.outputs.map((output: any) => {
  154. const filename = output.filename;
  155. return apiClient.getImageUrl(jobId, filename);
  156. });
  157. } else if (
  158. status.job.result?.images &&
  159. status.job.result.images.length > 0
  160. ) {
  161. // Old format: convert image URLs to authenticated URLs
  162. imageUrls = status.job.result.images.map((imageUrl: string) => {
  163. // Extract filename from URL if it's already a full URL
  164. if (imageUrl.includes("/output/")) {
  165. const parts = imageUrl.split("/output/");
  166. if (parts.length === 2) {
  167. const filename = parts[1].split("?")[0]; // Remove query params
  168. return apiClient.getImageUrl(jobId, filename);
  169. }
  170. }
  171. // If it's just a filename, convert it directly
  172. return apiClient.getImageUrl(jobId, imageUrl);
  173. });
  174. }
  175. // Create a new array to trigger React re-render
  176. setGeneratedImages([...imageUrls]);
  177. setLoading(false);
  178. isPolling = false;
  179. } else if (status.job.status === "failed") {
  180. setError(status.job.error || "Upscaling failed");
  181. setLoading(false);
  182. isPolling = false;
  183. } else if (status.job.status === "cancelled") {
  184. setError("Upscaling was cancelled");
  185. setLoading(false);
  186. isPolling = false;
  187. } else if (attempts < maxAttempts) {
  188. attempts++;
  189. timeoutId = setTimeout(poll, 2000);
  190. } else {
  191. setError("Job polling timeout");
  192. setLoading(false);
  193. isPolling = false;
  194. }
  195. } catch (err) {
  196. if (isPolling) {
  197. setError(
  198. err instanceof Error ? err.message : "Failed to check job status",
  199. );
  200. setLoading(false);
  201. isPolling = false;
  202. }
  203. }
  204. };
  205. poll();
  206. // Return cleanup function
  207. return () => {
  208. isPolling = false;
  209. if (timeoutId) {
  210. clearTimeout(timeoutId);
  211. }
  212. };
  213. };
  214. const handleUpscale = async (e: React.FormEvent) => {
  215. e.preventDefault();
  216. if (!uploadedImage) {
  217. setError("Please upload an image first");
  218. return;
  219. }
  220. setLoading(true);
  221. setError(null);
  222. setGeneratedImages([]);
  223. setJobInfo(null);
  224. try {
  225. // Validate model selection
  226. if (!selectedUpscalerModel) {
  227. setError("Please select an upscaler model");
  228. setLoading(false);
  229. return;
  230. }
  231. // Note: You may need to adjust the API endpoint based on your backend implementation
  232. const job = await apiClient.generateImage({
  233. prompt: `upscale ${formData.upscale_factor}x`,
  234. model: selectedUpscalerModel,
  235. image: uploadedImage,
  236. // Add upscale-specific parameters here based on your API
  237. } as any);
  238. setJobInfo(job);
  239. const jobId = job.request_id || job.id;
  240. if (jobId) {
  241. const cleanup = pollJobStatus(jobId);
  242. setPollCleanup(() => cleanup);
  243. } else {
  244. setError("No job ID returned from server");
  245. setLoading(false);
  246. }
  247. } catch (err) {
  248. setError(err instanceof Error ? err.message : "Failed to upscale image");
  249. setLoading(false);
  250. }
  251. };
  252. const handleCancel = async () => {
  253. const jobId = jobInfo?.request_id || jobInfo?.id;
  254. if (jobId) {
  255. try {
  256. await apiClient.cancelJob(jobId);
  257. setLoading(false);
  258. setError("Upscaling cancelled");
  259. // Cleanup polling
  260. if (pollCleanup) {
  261. pollCleanup();
  262. setPollCleanup(null);
  263. }
  264. } catch (err) {
  265. console.error("Failed to cancel job:", err);
  266. }
  267. }
  268. };
  269. return (
  270. <AppLayout>
  271. <Header
  272. title="Upscaler"
  273. description="Enhance and upscale your images with AI"
  274. />
  275. <div className="container mx-auto p-6">
  276. <div className="grid gap-6 lg:grid-cols-2">
  277. {/* Left Panel - Form */}
  278. <Card>
  279. <CardContent className="pt-6">
  280. <form onSubmit={handleUpscale} className="space-y-4">
  281. <div className="space-y-2">
  282. <Label>Source Image *</Label>
  283. <div className="space-y-4">
  284. {previewImage && (
  285. <div className="relative">
  286. <img
  287. src={previewImage}
  288. alt="Preview"
  289. className="h-64 w-full rounded-lg object-cover"
  290. />
  291. <Button
  292. type="button"
  293. variant="destructive"
  294. size="icon"
  295. className="absolute top-2 right-2 h-8 w-8"
  296. onClick={() => {
  297. setPreviewImage(null);
  298. setUploadedImage("");
  299. }}
  300. >
  301. <X className="h-4 w-4" />
  302. </Button>
  303. </div>
  304. )}
  305. <div className="space-y-2">
  306. <div className="flex items-center justify-center">
  307. <input
  308. ref={fileInputRef}
  309. type="file"
  310. accept="image/*"
  311. onChange={handleImageUpload}
  312. className="hidden"
  313. />
  314. <Button
  315. type="button"
  316. variant="outline"
  317. onClick={() => fileInputRef.current?.click()}
  318. >
  319. <Upload className="mr-2 h-4 w-4" />
  320. Choose Image
  321. </Button>
  322. </div>
  323. </div>
  324. </div>
  325. </div>
  326. <div className="space-y-2">
  327. <Label>Upscaling Factor</Label>
  328. <select
  329. value={formData.upscale_factor}
  330. onChange={(e) =>
  331. setFormData((prev) => ({
  332. ...prev,
  333. upscale_factor: Number(e.target.value),
  334. }))
  335. }
  336. className="w-full rounded-md border border-input bg-background px-3 py-2 text-sm"
  337. >
  338. <option value={2}>2x (Double)</option>
  339. <option value={3}>3x (Triple)</option>
  340. <option value={4}>4x (Quadruple)</option>
  341. </select>
  342. </div>
  343. <div className="space-y-2">
  344. <Label>Upscaler Model</Label>
  345. <Select
  346. value={formData.model}
  347. onValueChange={(value) => {
  348. setFormData((prev) => ({ ...prev, model: value }));
  349. setSelectedUpscalerModel(value);
  350. setUpscalerUserOverride(value);
  351. }}
  352. >
  353. <SelectTrigger>
  354. <SelectValue placeholder="Select an upscaler model" />
  355. </SelectTrigger>
  356. <SelectContent>
  357. {upscalerModels.map((model) => (
  358. <SelectItem key={model.name} value={model.name}>
  359. {model.name}
  360. </SelectItem>
  361. ))}
  362. </SelectContent>
  363. </Select>
  364. </div>
  365. <div className="flex gap-2">
  366. <Button
  367. type="submit"
  368. disabled={loading || !uploadedImage}
  369. className="flex-1"
  370. >
  371. {loading ? (
  372. <>
  373. <Loader2 className="h-4 w-4 animate-spin" />
  374. Upscaling...
  375. </>
  376. ) : (
  377. "Upscale"
  378. )}
  379. </Button>
  380. {loading && (
  381. <Button
  382. type="button"
  383. variant="destructive"
  384. onClick={handleCancel}
  385. >
  386. <X className="h-4 w-4" />
  387. Cancel
  388. </Button>
  389. )}
  390. </div>
  391. {error && (
  392. <div className="rounded-md bg-destructive/10 p-3 text-sm text-destructive">
  393. {error}
  394. </div>
  395. )}
  396. </form>
  397. </CardContent>
  398. </Card>
  399. <Card>
  400. <CardContent className="pt-6">
  401. <div className="space-y-4">
  402. <h3 className="text-lg font-semibold">Upscaled Image</h3>
  403. {generatedImages.length === 0 ? (
  404. <div className="flex h-96 items-center justify-center rounded-lg border-2 border-dashed border-border">
  405. <p className="text-muted-foreground">
  406. {loading
  407. ? "Upscaling..."
  408. : "Upscaled image will appear here"}
  409. </p>
  410. </div>
  411. ) : (
  412. <div className="grid gap-4">
  413. {generatedImages.map((image, index) => (
  414. <div key={index} className="relative group">
  415. <img
  416. src={image}
  417. alt={`Upscaled ${index + 1}`}
  418. className="w-full rounded-lg border border-border"
  419. />
  420. <Button
  421. size="icon"
  422. variant="secondary"
  423. className="absolute top-2 right-2 opacity-0 group-hover:opacity-100 transition-opacity"
  424. onClick={() => {
  425. const authToken =
  426. localStorage.getItem("auth_token");
  427. const unixUser = localStorage.getItem("unix_user");
  428. downloadAuthenticatedImage(
  429. image,
  430. `upscaled-${Date.now()}-${formData.upscale_factor}x.png`,
  431. authToken || undefined,
  432. unixUser || undefined,
  433. ).catch((err) => {
  434. console.error("Failed to download image:", err);
  435. // Fallback to regular download if authenticated download fails
  436. downloadImage(
  437. image,
  438. `upscaled-${Date.now()}-${formData.upscale_factor}x.png`,
  439. );
  440. });
  441. }}
  442. >
  443. <Download className="h-4 w-4" />
  444. </Button>
  445. </div>
  446. ))}
  447. </div>
  448. )}
  449. </div>
  450. </CardContent>
  451. </Card>
  452. </div>
  453. </div>
  454. </AppLayout>
  455. );
  456. }
  457. export default function UpscalerPage() {
  458. return <UpscalerForm />;
  459. }