// page.tsx (22 KB)
  'use client';

import { useEffect, useRef, useState } from 'react';
import { Download, Loader2, Power, RotateCcw, Trash2, X } from 'lucide-react';

import { EnhancedModelSelect, EnhancedModelSelectGroup } from '@/components/enhanced-model-select';
import { Header } from '@/components/header';
import { AppLayout } from '@/components/layout';
import { AutoSelectionStatus, ModelSelectionWarning } from '@/components/model-selection-indicator';
import { PromptTextarea } from '@/components/prompt-textarea';
import { Button } from '@/components/ui/button';
import { Card, CardContent } from '@/components/ui/card';
import { Input } from '@/components/ui/input';
import { Label } from '@/components/ui/label';
import { Textarea } from '@/components/ui/textarea';
import {
  ModelSelectionProvider,
  useCheckpointSelection,
  useModelSelection,
  useModelTypeSelection,
} from '@/contexts/model-selection-context';
import { apiClient, type GenerationRequest, type JobInfo, type ModelInfo } from '@/lib/api';
import { useLocalStorage } from '@/lib/hooks';
import { downloadAuthenticatedImage, downloadImage } from '@/lib/utils';

  // Initial values for the text-to-image form. Used as the `useLocalStorage`
// default for the `text2img-form-data` entry and restored by "Reset to Defaults".
const defaultFormData: GenerationRequest = {
  // Prompts start empty; the prompt field is marked required in the UI.
  prompt: '',
  negative_prompt: '',
  // Output size in pixels (the inputs step by 64 within 256..2048).
  width: 512,
  height: 512,
  steps: 20,
  cfg_scale: 7.5,
  // Empty string: the seed input is labelled "Leave empty for random".
  seed: '',
  // Fallback values also used by the "Loading..." <option>s before the lists load.
  sampling_method: 'euler_a',
  scheduler: 'default',
  batch_count: 1,
};
  // --- Job polling configuration and pure helpers for Text2ImgForm ----------

/** Delay between consecutive job-status requests, in milliseconds. */
const POLL_INTERVAL_MS = 2000;

/**
 * Number of re-polls allowed before reporting "Job polling timeout".
 * 300 polls x 2 s is roughly 10 minutes of waiting.
 */
const MAX_POLL_ATTEMPTS = 300;

/** Form fields stored as strings; every other field edited via `handleInputChange` is coerced with `Number()`. */
const STRING_FIELDS: ReadonlySet<string> = new Set([
  'prompt',
  'negative_prompt',
  'seed',
  'sampling_method',
  'scheduler',
]);

/** Result type of `apiClient.getJobStatus`, derived so the helpers below track the API client. */
type JobStatus = Awaited<ReturnType<typeof apiClient.getJobStatus>>;

/**
 * Returns the identifier used to poll, cancel and display a job, preferring
 * `request_id` over `id`. Returns `undefined` when neither is set or `info` is null.
 */
function getJobId(
  info: { request_id?: string | null; id?: string | null } | null
): string | undefined {
  return info?.request_id || info?.id || undefined;
}

/**
 * Reduces a legacy `result.images` entry to the filename passed to
 * `apiClient.getImageUrl`. For a URL containing exactly one `/output/` segment,
 * it returns the part after that segment with any query string removed.
 * Any other entry is treated as a bare filename and returned unchanged.
 */
function extractOutputFilename(imageUrl: string): string {
  if (imageUrl.includes('/output/')) {
    const parts = imageUrl.split('/output/');
    if (parts.length === 2) {
      const tail = parts[1] ?? '';
      return tail.split('?')[0] ?? tail;
    }
  }
  return imageUrl;
}

/**
 * Builds authenticated image URLs for a completed job.
 * Prefers the `outputs` format and falls back to the legacy `result.images`
 * format. Returns an empty array when neither is present.
 */
function collectImageUrls(jobId: string, status: JobStatus): string[] {
  if (status.outputs && status.outputs.length > 0) {
    // eslint-disable-next-line @typescript-eslint/no-explicit-any -- TODO: narrow once the `outputs` element type is confirmed
    return status.outputs.map((output: any) => apiClient.getImageUrl(jobId, output.filename));
  }
  const legacyImages = status.result?.images;
  if (legacyImages && legacyImages.length > 0) {
    return legacyImages.map((imageUrl: string) =>
      apiClient.getImageUrl(jobId, extractOutputFilename(imageUrl))
    );
  }
  return [];
}

/**
 * Text-to-image form. It edits generation parameters (persisted via
 * `useLocalStorage` under `text2img-form-data`), submits a text2img job, polls
 * the job until it reaches a terminal state, and renders the resulting images.
 *
 * Reads model-selection state through `useModelSelection`,
 * `useCheckpointSelection` and `useModelTypeSelection('vae')`. It presumably
 * needs a `ModelSelectionProvider` ancestor; confirm against the context module.
 */
function Text2ImgForm() {
  const { state, actions } = useModelSelection();
  const {
    checkpointModels,
    selectedCheckpointModel,
    selectedCheckpoint,
    setSelectedCheckpoint,
    isAutoSelecting,
    warnings,
    error: checkpointError,
  } = useCheckpointSelection();
  const {
    availableModels: vaeModels,
    selectedModel: selectedVae,
    isUserOverride: isVaeUserOverride,
    isAutoSelected: isVaeAutoSelected,
    setSelectedModel: setSelectedVae,
    setUserOverride: setVaeUserOverride,
    clearUserOverride: clearVaeUserOverride,
  } = useModelTypeSelection('vae');

  const [formData, setFormData] = useLocalStorage<GenerationRequest>(
    'text2img-form-data',
    defaultFormData
  );
  const [loading, setLoading] = useState(false);
  const [jobInfo, setJobInfo] = useState<JobInfo | null>(null);
  const [generatedImages, setGeneratedImages] = useState<string[]>([]);
  const [samplers, setSamplers] = useState<Array<{ name: string; description: string }>>([]);
  const [schedulers, setSchedulers] = useState<Array<{ name: string; description: string }>>([]);
  const [loraModels, setLoraModels] = useState<string[]>([]);
  const [embeddings, setEmbeddings] = useState<string[]>([]);
  const [error, setError] = useState<string | null>(null);

  // Polling bookkeeping lives in refs. Callbacks belonging to a superseded job,
  // a cancelled job, or an unmounted component can then detect that they are
  // stale and skip their state updates.
  const isMountedRef = useRef(true);
  const activeJobIdRef = useRef<string | null>(null);
  const pollTimerRef = useRef<ReturnType<typeof setTimeout> | null>(null);

  /** Ends the current polling loop: pending timers are cleared and in-flight polls become no-ops. */
  const stopPolling = () => {
    activeJobIdRef.current = null;
    if (pollTimerRef.current !== null) {
      clearTimeout(pollTimerRef.current);
      pollTimerRef.current = null;
    }
  };

  // Track mount state, and tear down polling on unmount so no timer or state
  // update outlives the component.
  useEffect(() => {
    isMountedRef.current = true;
    return () => {
      isMountedRef.current = false;
      activeJobIdRef.current = null;
      if (pollTimerRef.current !== null) {
        clearTimeout(pollTimerRef.current);
        pollTimerRef.current = null;
      }
    };
  }, []);

  // Load sampler/scheduler lists, all models (pushed into the selection
  // context), and LoRA/embedding names for prompt autocompletion.
  // NOTE(review): re-runs when `actions` changes identity; this assumes the
  // context memoizes `actions`.
  useEffect(() => {
    let isActive = true;
    const loadOptions = async () => {
      try {
        const [samplersData, schedulersData, modelsData, loras, embeds] = await Promise.all([
          apiClient.getSamplers(),
          apiClient.getSchedulers(),
          apiClient.getModels(), // Get all models with enhanced info
          apiClient.getModels('lora'),
          apiClient.getModels('embedding'),
        ]);
        if (!isActive) {
          return; // Effect was cleaned up while the requests were in flight.
        }
        setSamplers(samplersData);
        setSchedulers(schedulersData);
        actions.setModels(modelsData.models);
        setLoraModels(loras.models.map((m) => m.name));
        setEmbeddings(embeds.models.map((m) => m.name));
      } catch (err) {
        console.error('Failed to load options:', err);
      }
    };
    void loadOptions();
    return () => {
      isActive = false;
    };
  }, [actions]);

  // Keep the persisted form data's `model` in sync with the selected checkpoint.
  useEffect(() => {
    if (selectedCheckpoint) {
      setFormData((prev) => ({
        ...prev,
        model: selectedCheckpoint,
      }));
    }
  }, [selectedCheckpoint, setFormData]);

  /** Generic change handler for named inputs/selects: string fields stay strings, the rest become numbers. */
  const handleInputChange = (
    e: React.ChangeEvent<HTMLInputElement | HTMLTextAreaElement | HTMLSelectElement>
  ) => {
    const { name, value } = e.target;
    setFormData((prev) => ({
      ...prev,
      [name]: STRING_FIELDS.has(name) ? value : Number(value),
    }));
  };

  /**
   * Polls `jobId` every POLL_INTERVAL_MS until the job completes, fails, is
   * cancelled, or exceeds MAX_POLL_ATTEMPTS. It updates `jobInfo`,
   * `generatedImages`, `error` and `loading` as it goes. Starting a new loop
   * (or cancelling/unmounting) supersedes any previous one.
   */
  const pollJobStatus = (jobId: string) => {
    stopPolling();
    if (!isMountedRef.current) {
      return;
    }
    activeJobIdRef.current = jobId;
    let attempts = 0;
    const isCurrent = () => activeJobIdRef.current === jobId;

    const poll = async () => {
      if (!isCurrent()) {
        return;
      }
      pollTimerRef.current = null;
      try {
        const status = await apiClient.getJobStatus(jobId);
        // The job may have been superseded, cancelled or unmounted during the request.
        if (!isCurrent()) {
          return;
        }
        setJobInfo(status);
        if (status.status === 'completed') {
          stopPolling();
          setGeneratedImages(collectImageUrls(jobId, status));
          setLoading(false);
        } else if (status.status === 'failed') {
          stopPolling();
          setError(status.error || 'Generation failed');
          setLoading(false);
        } else if (status.status === 'cancelled') {
          stopPolling();
          setError('Generation was cancelled');
          setLoading(false);
        } else if (attempts < MAX_POLL_ATTEMPTS) {
          attempts++;
          pollTimerRef.current = setTimeout(() => {
            void poll();
          }, POLL_INTERVAL_MS);
        } else {
          stopPolling();
          setError('Job polling timeout');
          setLoading(false);
        }
      } catch (err) {
        if (!isCurrent()) {
          return;
        }
        stopPolling();
        setError(err instanceof Error ? err.message : 'Failed to check job status');
        setLoading(false);
      }
    };

    void poll();
  };

  /** Validates the model selection, submits a text2img job and starts polling it. */
  const handleGenerate = async (e: React.FormEvent) => {
    e.preventDefault();
    stopPolling(); // Never let a previous job's loop write into this run's state.
    setLoading(true);
    setError(null);
    setGeneratedImages([]);
    setJobInfo(null);
    try {
      // Validate model selection
      if (selectedCheckpointModel) {
        const validation = actions.validateSelection(selectedCheckpointModel);
        if (!validation.isValid) {
          setError(`Missing required models: ${validation.missingRequired.join(', ')}`);
          setLoading(false);
          return;
        }
      }
      const requestData = {
        ...formData,
        model: selectedCheckpoint || undefined,
        vae: selectedVae || undefined,
      };
      const job = await apiClient.text2img(requestData);
      setJobInfo(job);
      const jobId = getJobId(job);
      if (jobId) {
        pollJobStatus(jobId);
      } else {
        setError('No job ID returned from server');
        setLoading(false);
      }
    } catch (err) {
      setError(err instanceof Error ? err.message : 'Failed to generate image');
      setLoading(false);
    }
  };

  /** Asks the server to cancel the current job; on success, stops polling and reports the cancellation. */
  const handleCancel = async () => {
    const jobId = getJobId(jobInfo);
    if (!jobId) {
      return;
    }
    try {
      await apiClient.cancelJob(jobId);
      stopPolling();
      setLoading(false);
      setError('Generation cancelled');
    } catch (err) {
      console.error('Failed to cancel job:', err);
    }
  };

  const handleClearPrompts = () => {
    setFormData((prev) => ({ ...prev, prompt: '', negative_prompt: '' }));
  };

  const handleResetToDefaults = () => {
    setFormData(defaultFormData);
    setSelectedCheckpoint(null);
    setSelectedVae('');
    actions.resetSelection();
  };

  /** After confirmation, asks the backend to restart and reloads the page 3 s later. */
  const handleServerRestart = async () => {
    if (!confirm('Are you sure you want to restart the server? This will cancel all running jobs.')) {
      return;
    }
    try {
      setLoading(true);
      await apiClient.restartServer();
      setError('Server restart initiated. Please wait...');
      setTimeout(() => {
        window.location.reload();
      }, 3000);
    } catch (err) {
      setError(err instanceof Error ? err.message : 'Failed to restart server');
      setLoading(false);
    }
  };

  /**
   * Downloads a generated image with the stored credentials, falling back to a
   * plain download if that fails. Both attempts use the same filename.
   */
  const handleDownload = (imageUrl: string, index: number) => {
    const filename = `generated-${Date.now()}-${index}.png`;
    const authToken = localStorage.getItem('auth_token');
    const unixUser = localStorage.getItem('unix_user');
    void downloadAuthenticatedImage(imageUrl, filename, authToken || undefined, unixUser || undefined)
      .catch((err) => {
        console.error('Failed to download image:', err);
        // Fallback to regular download if authenticated download fails
        downloadImage(imageUrl, filename);
      });
  };

  return (
    <AppLayout>
      <Header title="Text to Image" description="Generate images from text prompts" />
      <div className="container mx-auto p-6">
        <div className="grid gap-6 lg:grid-cols-2">
          {/* Left Panel - Form */}
          <Card>
            <CardContent className="pt-6">
              <form onSubmit={handleGenerate} className="space-y-4">
                <div className="space-y-2">
                  <Label htmlFor="prompt">Prompt *</Label>
                  <PromptTextarea
                    value={formData.prompt}
                    onChange={(value) => setFormData((prev) => ({ ...prev, prompt: value }))}
                    placeholder="a beautiful landscape with mountains and a lake, sunset, highly detailed..."
                    rows={4}
                    loras={loraModels}
                    embeddings={embeddings}
                  />
                  <p className="text-xs text-muted-foreground">
                    Tip: Use &lt;lora:name:weight&gt; for LoRAs (e.g., &lt;lora:myLora:0.8&gt;) and embedding names directly
                  </p>
                </div>
                <div className="space-y-2">
                  <Label htmlFor="negative_prompt">Negative Prompt</Label>
                  <PromptTextarea
                    value={formData.negative_prompt || ''}
                    onChange={(value) => setFormData((prev) => ({ ...prev, negative_prompt: value }))}
                    placeholder="blurry, low quality, distorted..."
                    rows={2}
                    loras={loraModels}
                    embeddings={embeddings}
                  />
                </div>
                {/* Utility Buttons */}
                <div className="flex gap-2">
                  <Button
                    type="button"
                    variant="outline"
                    size="sm"
                    onClick={handleClearPrompts}
                    disabled={loading}
                    title="Clear both prompts"
                  >
                    <Trash2 className="h-4 w-4 mr-1" />
                    Clear Prompts
                  </Button>
                  <Button
                    type="button"
                    variant="outline"
                    size="sm"
                    onClick={handleResetToDefaults}
                    disabled={loading}
                    title="Reset all fields to defaults"
                  >
                    <RotateCcw className="h-4 w-4 mr-1" />
                    Reset to Defaults
                  </Button>
                  <Button
                    type="button"
                    variant="outline"
                    size="sm"
                    onClick={handleServerRestart}
                    disabled={loading}
                    title="Restart the backend server"
                  >
                    <Power className="h-4 w-4 mr-1" />
                    Restart Server
                  </Button>
                </div>
                <div className="grid grid-cols-2 gap-4">
                  <div className="space-y-2">
                    <Label htmlFor="width">Width</Label>
                    <Input
                      id="width"
                      name="width"
                      type="number"
                      value={formData.width}
                      onChange={handleInputChange}
                      step={64}
                      min={256}
                      max={2048}
                    />
                  </div>
                  <div className="space-y-2">
                    <Label htmlFor="height">Height</Label>
                    <Input
                      id="height"
                      name="height"
                      type="number"
                      value={formData.height}
                      onChange={handleInputChange}
                      step={64}
                      min={256}
                      max={2048}
                    />
                  </div>
                </div>
                <div className="grid grid-cols-2 gap-4">
                  <div className="space-y-2">
                    <Label htmlFor="steps">Steps</Label>
                    <Input
                      id="steps"
                      name="steps"
                      type="number"
                      value={formData.steps}
                      onChange={handleInputChange}
                      min={1}
                      max={150}
                    />
                  </div>
                  <div className="space-y-2">
                    <Label htmlFor="cfg_scale">CFG Scale</Label>
                    <Input
                      id="cfg_scale"
                      name="cfg_scale"
                      type="number"
                      value={formData.cfg_scale}
                      onChange={handleInputChange}
                      step={0.5}
                      min={1}
                      max={30}
                    />
                  </div>
                </div>
                <div className="space-y-2">
                  <Label htmlFor="seed">Seed (optional)</Label>
                  <Input
                    id="seed"
                    name="seed"
                    value={formData.seed}
                    onChange={handleInputChange}
                    placeholder="Leave empty for random"
                  />
                </div>
                <div className="space-y-2">
                  <Label htmlFor="sampling_method">Sampling Method</Label>
                  <select
                    id="sampling_method"
                    name="sampling_method"
                    value={formData.sampling_method}
                    onChange={handleInputChange}
                    className="flex h-10 w-full rounded-md border border-input bg-background px-3 py-2 text-sm"
                  >
                    {samplers.length > 0 ? (
                      samplers.map((sampler) => (
                        <option key={sampler.name} value={sampler.name}>
                          {sampler.name.toUpperCase()} - {sampler.description}
                        </option>
                      ))
                    ) : (
                      <option value="euler_a">Loading...</option>
                    )}
                  </select>
                </div>
                <div className="space-y-2">
                  <Label htmlFor="scheduler">Scheduler</Label>
                  <select
                    id="scheduler"
                    name="scheduler"
                    value={formData.scheduler}
                    onChange={handleInputChange}
                    className="flex h-10 w-full rounded-md border border-input bg-background px-3 py-2 text-sm"
                  >
                    {schedulers.length > 0 ? (
                      schedulers.map((scheduler) => (
                        <option key={scheduler.name} value={scheduler.name}>
                          {scheduler.name.toUpperCase()} - {scheduler.description}
                        </option>
                      ))
                    ) : (
                      <option value="default">Loading...</option>
                    )}
                  </select>
                </div>
                {/* Model Selection Section */}
                <EnhancedModelSelectGroup
                  title="Model Selection"
                  description="Select the checkpoint and additional models for generation"
                >
                  {/* Checkpoint Selection */}
                  <div className="space-y-2">
                    <Label htmlFor="checkpoint">Checkpoint Model *</Label>
                    <select
                      id="checkpoint"
                      value={selectedCheckpoint || ''}
                      onChange={(e) => setSelectedCheckpoint(e.target.value || null)}
                      className="flex h-10 w-full rounded-md border border-input bg-background px-3 py-2 text-sm"
                      disabled={isAutoSelecting}
                    >
                      <option value="">Select a checkpoint model...</option>
                      {checkpointModels.map((model) => (
                        <option key={model.id} value={model.name}>
                          {model.name} {model.loaded ? '(Loaded)' : ''}
                        </option>
                      ))}
                    </select>
                  </div>
                  {/* VAE Selection */}
                  <EnhancedModelSelect
                    modelType="vae"
                    label="VAE Model"
                    description="Optional VAE model for improved image quality"
                    value={selectedVae}
                    availableModels={vaeModels}
                    isAutoSelected={isVaeAutoSelected}
                    isUserOverride={isVaeUserOverride}
                    isLoading={isAutoSelecting}
                    onValueChange={setSelectedVae}
                    onSetUserOverride={setVaeUserOverride}
                    onClearOverride={clearVaeUserOverride}
                    placeholder="Use default VAE"
                  />
                  {/* Auto-selection Status */}
                  <div className="pt-2">
                    <AutoSelectionStatus
                      isAutoSelecting={isAutoSelecting}
                      hasAutoSelection={Object.keys(state.autoSelectedModels).length > 0}
                    />
                  </div>
                  {/* Model-selection warnings and checkpoint-selection errors
                      (generation errors are shown separately below the buttons) */}
                  <ModelSelectionWarning
                    warnings={warnings}
                    errors={checkpointError ? [checkpointError] : []}
                    onClearWarnings={actions.clearWarnings}
                  />
                </EnhancedModelSelectGroup>
                <div className="space-y-2">
                  <Label htmlFor="batch_count">Batch Count</Label>
                  <Input
                    id="batch_count"
                    name="batch_count"
                    type="number"
                    value={formData.batch_count}
                    onChange={handleInputChange}
                    min={1}
                    max={4}
                  />
                </div>
                <div className="flex gap-2">
                  <Button type="submit" disabled={loading} className="flex-1">
                    {loading ? (
                      <>
                        <Loader2 className="h-4 w-4 animate-spin" />
                        Generating...
                      </>
                    ) : (
                      'Generate'
                    )}
                  </Button>
                  {loading && (
                    <Button type="button" variant="destructive" onClick={handleCancel}>
                      <X className="h-4 w-4" />
                      Cancel
                    </Button>
                  )}
                </div>
                {error && (
                  <div className="rounded-md bg-destructive/10 p-3 text-sm text-destructive">
                    {error}
                  </div>
                )}
                {loading && jobInfo && (
                  <div className="rounded-md bg-muted p-3 text-sm">
                    <p>Job ID: {getJobId(jobInfo) || 'N/A'}</p>
                    <p>Status: {jobInfo.status}</p>
                    {jobInfo.progress !== undefined && (
                      <p>Progress: {Math.round(jobInfo.progress * 100)}%</p>
                    )}
                  </div>
                )}
              </form>
            </CardContent>
          </Card>
          {/* Right Panel - Generated Images */}
          <Card>
            <CardContent className="pt-6">
              <div className="space-y-4">
                <h3 className="text-lg font-semibold">Generated Images</h3>
                {generatedImages.length === 0 ? (
                  <div className="flex h-96 items-center justify-center rounded-lg border-2 border-dashed border-border">
                    <p className="text-muted-foreground">
                      {loading ? 'Generating...' : 'Generated images will appear here'}
                    </p>
                  </div>
                ) : (
                  <div className="grid gap-4">
                    {generatedImages.map((image, index) => (
                      <div key={index} className="relative group">
                        <img
                          src={image}
                          alt={`Generated ${index + 1}`}
                          className="w-full rounded-lg border border-border"
                        />
                        <Button
                          size="icon"
                          variant="secondary"
                          className="absolute top-2 right-2 opacity-0 group-hover:opacity-100 transition-opacity"
                          onClick={() => handleDownload(image, index)}
                        >
                          <Download className="h-4 w-4" />
                        </Button>
                      </div>
                    ))}
                  </div>
                )}
              </div>
            </CardContent>
          </Card>
        </div>
      </div>
    </AppLayout>
  );
}
  // Route entry point for the text-to-image page.
// Text2ImgForm reads model-selection context (useModelSelection and friends)
// and loads the model list into it itself, so the form is wrapped in the
// ModelSelectionProvider that this file already imports.
// NOTE(review): if an ancestor layout already provides ModelSelectionProvider,
// this creates a page-local selection scope. Confirm that is intended.
export default function Text2ImgPage() {
  return (
    <ModelSelectionProvider>
      <Text2ImgForm />
    </ModelSelectionProvider>
  );
}