page.tsx 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385
  1. "use client";
  2. import { useState, useEffect } from "react";
  3. import { Header, AppLayout } from "@/components/layout";
  4. import { Button } from "@/components/ui/button";
  5. import { Input } from "@/components/ui/input";
  6. import { Label } from "@/components/ui/label";
  7. import { Card, CardContent, CardDescription, CardHeader, CardTitle } from "@/components/ui/card";
  8. import { Badge } from "@/components/ui/badge";
  9. import { Alert, AlertDescription } from "@/components/ui/alert";
  10. import { Tabs, TabsContent, TabsList, TabsTrigger } from "@/components/ui/tabs";
  11. import {
  12. Select,
  13. SelectContent,
  14. SelectItem,
  15. SelectTrigger,
  16. SelectValue,
  17. } from "@/components/ui/select";
  18. import {
  19. apiClient,
  20. type ModelInfo,
  21. type EnhancedModelsResponse,
  22. } from "@/lib/api";
  23. import {
  24. Loader2,
  25. Download,
  26. Upload,
  27. CheckCircle,
  28. XCircle,
  29. AlertCircle,
  30. Copy,
  31. RefreshCw,
  32. Package,
  33. Hash
  34. } from "lucide-react";
  /** Returns a copy of `set` with `item` added; the input is not mutated (React state must stay immutable). */
  const withItem = (set: Set<string>, item: string): Set<string> => new Set(set).add(item);

  /** Returns a copy of `set` with `item` removed; the input is not mutated. */
  const withoutItem = (set: Set<string>, item: string): Set<string> => {
    const next = new Set(set);
    next.delete(item);
    return next;
  };

  /**
   * Identifier sent to the backend for a model. Hashes are preferred over the
   * name: short hash, then full hash, then id, and the display name only as a
   * last resort.
   */
  const getModelIdentifier = (model: ModelInfo) =>
    model.sha256_short || model.sha256 || model.id || model.name;

  /**
   * True when `modelId` refers to `model` under any identifier this page may
   * have sent to the backend, including the fallback chosen by
   * `getModelIdentifier` (which can be the model name).
   */
  const matchesModelIdentifier = (model: ModelInfo, modelId: string) =>
    model.id === modelId ||
    model.sha256 === modelId ||
    model.sha256_short === modelId ||
    getModelIdentifier(model) === modelId;

  /** Transient success message shown after a clipboard copy. */
  const COPIED_MESSAGE = "Copied to clipboard";

  /**
   * Model management page: lists models from `apiClient`, lets the user filter
   * them by tab (all / loaded / unloaded), type and a name/hash search, and
   * load, unload, or compute a hash for each model.
   */
  export default function ModelsPage() {
    const [models, setModels] = useState<ModelInfo[]>([]);
    const [loading, setLoading] = useState(false);
    const [error, setError] = useState<string | null>(null);
    const [success, setSuccess] = useState<string | null>(null);
    const [activeTab, setActiveTab] = useState("all");
    const [modelTypeFilter, setModelTypeFilter] = useState("all");
    const [searchQuery, setSearchQuery] = useState("");
    // Identifiers of models with a load/unload request in flight.
    const [loadingModels, setLoadingModels] = useState<Set<string>>(new Set());
    // Identifiers of models with a hash-computation request in flight.
    const [computingHashes, setComputingHashes] = useState<Set<string>>(new Set());

    /** Fetches the model list and replaces local state; failures are reported via `error`. */
    const loadModels = async () => {
      setLoading(true);
      setError(null);
      try {
        const response = await apiClient.getModels();
        setModels(response.models);
      } catch (err) {
        setError(err instanceof Error ? err.message : "Failed to load models");
      } finally {
        setLoading(false);
      }
    };

    // Load models once on mount. loadModels handles its own errors, so the
    // returned promise is intentionally not awaited.
    useEffect(() => {
      void loadModels();
    }, []);

    /** Asks the backend to rescan for models, then reloads the list. */
    const refreshModels = async () => {
      setLoading(true);
      setError(null);
      // Clear any stale success banner so it cannot sit next to a new error.
      setSuccess(null);
      try {
        await apiClient.scanModels();
        // After scanning, reload the models list
        const response = await apiClient.getModels();
        setModels(response.models);
        setSuccess("Models refreshed successfully");
      } catch (err) {
        setError(err instanceof Error ? err.message : "Failed to refresh models");
      } finally {
        setLoading(false);
      }
    };

    /** Updates the local `loaded` flag of every model matching `modelId`. */
    const setModelLoaded = (modelId: string, loaded: boolean) => {
      setModels(prev =>
        prev.map(model => (matchesModelIdentifier(model, modelId) ? { ...model, loaded } : model)),
      );
    };

    /**
     * Loads a model on the backend and marks it loaded locally.
     * @param modelId identifier produced by `getModelIdentifier`
     * @param modelName display name used in status messages
     */
    const loadModel = async (modelId: string, modelName: string) => {
      setLoadingModels(prev => withItem(prev, modelId));
      setError(null);
      setSuccess(null);
      try {
        await apiClient.loadModel(modelId);
        setSuccess(`Model "${modelName}" loaded successfully`);
        setModelLoaded(modelId, true);
      } catch (err) {
        const errorMessage = err instanceof Error ? err.message : "Failed to load model";
        // These error codes indicate the identifier was rejected; hint that a
        // hash should be used rather than the name.
        if (errorMessage.includes("INVALID_MODEL_IDENTIFIER") || errorMessage.includes("MODEL_NOT_FOUND")) {
          setError(`Failed to load model: ${errorMessage}. Please ensure you're using the model hash instead of the name.`);
        } else {
          setError(`Failed to load model "${modelName}": ${errorMessage}`);
        }
      } finally {
        setLoadingModels(prev => withoutItem(prev, modelId));
      }
    };

    /**
     * Unloads a model on the backend and marks it unloaded locally.
     * @param modelId identifier produced by `getModelIdentifier`
     * @param modelName display name used in status messages
     */
    const unloadModel = async (modelId: string, modelName: string) => {
      setLoadingModels(prev => withItem(prev, modelId));
      setError(null);
      setSuccess(null);
      try {
        await apiClient.unloadModel(modelId);
        setSuccess(`Model "${modelName}" unloaded successfully`);
        setModelLoaded(modelId, false);
      } catch (err) {
        setError(err instanceof Error ? err.message : "Failed to unload model");
      } finally {
        setLoadingModels(prev => withoutItem(prev, modelId));
      }
    };

    /**
     * Starts hash computation for a model, then reloads the list after a short
     * delay so any newly computed hash can be picked up.
     */
    const computeModelHash = async (modelId: string, modelName: string) => {
      setComputingHashes(prev => withItem(prev, modelId));
      setError(null);
      setSuccess(null);
      try {
        const result = await apiClient.computeModelHash(modelId);
        setSuccess(`Hash computation started for "${modelName}". Request ID: ${result.request_id}`);
        // Refresh models after a delay to get updated hash information
        setTimeout(() => {
          void loadModels();
        }, 2000);
      } catch (err) {
        setError(err instanceof Error ? err.message : "Failed to compute model hash");
      } finally {
        setComputingHashes(prev => withoutItem(prev, modelId));
      }
    };

    /**
     * Copies `text` to the clipboard and shows a transient confirmation.
     * Failures (API unavailable outside secure contexts, permission denied)
     * are reported via `error` instead of becoming unhandled rejections.
     */
    const copyToClipboard = async (text: string) => {
      try {
        if (!navigator.clipboard) {
          throw new Error("Clipboard API is unavailable in this context");
        }
        await navigator.clipboard.writeText(text);
        setSuccess(COPIED_MESSAGE);
        // Only clear our own message; a newer success message must survive.
        setTimeout(() => {
          setSuccess(current => (current === COPIED_MESSAGE ? null : current));
        }, 2000);
      } catch (err) {
        setError(err instanceof Error ? err.message : "Failed to copy to clipboard");
      }
    };

    // Filter models based on active tab, type filter and search query.
    // The query is normalized once rather than per model.
    const normalizedQuery = searchQuery.trim().toLowerCase();
    const filteredModels = models.filter(model => {
      const matchesTab =
        activeTab === "all" ||
        (activeTab === "loaded" && model.loaded) ||
        (activeTab === "unloaded" && !model.loaded);
      const matchesType = modelTypeFilter === "all" || model.type === modelTypeFilter;
      const matchesSearch =
        normalizedQuery === "" ||
        model.name.toLowerCase().includes(normalizedQuery) ||
        (model.sha256_short && model.sha256_short.toLowerCase().includes(normalizedQuery));
      return matchesTab && matchesType && matchesSearch;
    });

    // Unique model types for the filter dropdown, sorted for a stable order.
    const modelTypes = Array.from(new Set(models.map(model => model.type))).sort();
    const loadedCount = models.filter(m => m.loaded).length;
    const unloadedCount = models.length - loadedCount;

    return (
      <AppLayout>
        <Header
          title="Model Management"
          description="Load, unload, and manage your AI models"
        />
        <div className="container mx-auto p-6 space-y-6">
          {error && (
            <Alert variant="destructive">
              <AlertCircle className="h-4 w-4" />
              <AlertDescription>{error}</AlertDescription>
            </Alert>
          )}
          {success && (
            <Alert>
              <CheckCircle className="h-4 w-4" />
              <AlertDescription>{success}</AlertDescription>
            </Alert>
          )}
          <Card>
            <CardHeader>
              <div className="flex items-center justify-between">
                <div>
                  <CardTitle>Models</CardTitle>
                  <CardDescription>
                    Manage your AI models - use hashes for reliable identification
                  </CardDescription>
                </div>
                <Button onClick={refreshModels} disabled={loading}>
                  <RefreshCw className={`h-4 w-4 mr-2 ${loading ? 'animate-spin' : ''}`} />
                  Refresh
                </Button>
              </div>
            </CardHeader>
            <CardContent>
              <div className="flex gap-4 mb-4">
                <div className="flex-1 max-w-md">
                  <Label htmlFor="search" className="sr-only">Search models</Label>
                  <Input
                    id="search"
                    placeholder="Search models by name or hash..."
                    value={searchQuery}
                    onChange={(e) => setSearchQuery(e.target.value)}
                  />
                </div>
                <div className="w-64">
                  <Label htmlFor="type-filter" className="sr-only">Filter by type</Label>
                  <Select value={modelTypeFilter} onValueChange={setModelTypeFilter}>
                    <SelectTrigger id="type-filter">
                      <SelectValue placeholder="Filter by type" />
                    </SelectTrigger>
                    <SelectContent>
                      <SelectItem value="all">All Types</SelectItem>
                      {modelTypes.map(type => (
                        <SelectItem key={type} value={type}>{type}</SelectItem>
                      ))}
                    </SelectContent>
                  </Select>
                </div>
              </div>
              <Tabs value={activeTab} onValueChange={setActiveTab}>
                <TabsList>
                  <TabsTrigger value="all">All Models ({models.length})</TabsTrigger>
                  <TabsTrigger value="loaded">Loaded ({loadedCount})</TabsTrigger>
                  <TabsTrigger value="unloaded">Unloaded ({unloadedCount})</TabsTrigger>
                </TabsList>
                {/* A single content panel keyed to the active tab; filtering happens above. */}
                <TabsContent value={activeTab} className="mt-4">
                  {loading ? (
                    <div className="flex justify-center py-8">
                      <Loader2 className="h-8 w-8 animate-spin" />
                    </div>
                  ) : filteredModels.length === 0 ? (
                    <div className="text-center py-8 text-muted-foreground">
                      No models found matching your criteria
                    </div>
                  ) : (
                    <div className="grid gap-4">
                      {filteredModels.map((model) => {
                        // Computed once per card instead of once per button/prop.
                        const modelKey = getModelIdentifier(model);
                        const shortHash = model.sha256_short;
                        const isBusy = loadingModels.has(modelKey);
                        const isHashing = computingHashes.has(modelKey);
                        return (
                          <Card key={modelKey} className="p-4">
                            <div className="flex items-start justify-between">
                              <div className="flex-1 space-y-2">
                                <div className="flex items-center gap-2">
                                  <h3 className="font-semibold">{model.name}</h3>
                                  <Badge variant={model.loaded ? "default" : "secondary"}>
                                    {model.loaded ? "Loaded" : "Unloaded"}
                                  </Badge>
                                  <Badge variant="outline">{model.type}</Badge>
                                </div>
                                <div className="space-y-1 text-sm text-muted-foreground">
                                  <div className="flex items-center gap-2">
                                    <Hash className="h-3 w-3" />
                                    <span className="font-mono">
                                      {shortHash || "No hash"}
                                    </span>
                                    {shortHash && (
                                      <Button
                                        variant="ghost"
                                        size="sm"
                                        aria-label="Copy hash"
                                        title="Copy hash"
                                        onClick={() => void copyToClipboard(shortHash)}
                                      >
                                        <Copy className="h-3 w-3" />
                                      </Button>
                                    )}
                                  </div>
                                  {/* `!= null` so a size of 0 renders "0.00 MB" instead of a bare "0". */}
                                  {model.file_size_mb != null && (
                                    <div>Size: {model.file_size_mb.toFixed(2)} MB</div>
                                  )}
                                  {model.architecture && (
                                    <div>Architecture: {model.architecture}</div>
                                  )}
                                </div>
                                {!shortHash && (
                                  <Alert>
                                    <AlertCircle className="h-4 w-4" />
                                    <AlertDescription>
                                      This model doesn't have a computed hash. Compute the hash to use it with the hash-only loading system.
                                    </AlertDescription>
                                  </Alert>
                                )}
                              </div>
                              <div className="flex gap-2 ml-4">
                                {!shortHash && (
                                  <Button
                                    variant="outline"
                                    size="sm"
                                    onClick={() => computeModelHash(modelKey, model.name)}
                                    disabled={isHashing}
                                  >
                                    {isHashing ? (
                                      <Loader2 className="h-4 w-4 animate-spin" />
                                    ) : (
                                      <Package className="h-4 w-4 mr-1" />
                                    )}
                                    Compute Hash
                                  </Button>
                                )}
                                {model.loaded ? (
                                  <Button
                                    variant="destructive"
                                    size="sm"
                                    onClick={() => unloadModel(modelKey, model.name)}
                                    disabled={isBusy}
                                  >
                                    {isBusy ? (
                                      <Loader2 className="h-4 w-4 animate-spin" />
                                    ) : (
                                      <XCircle className="h-4 w-4 mr-1" />
                                    )}
                                    Unload
                                  </Button>
                                ) : (
                                  <Button
                                    variant="default"
                                    size="sm"
                                    onClick={() => loadModel(modelKey, model.name)}
                                    // Loading requires a hash identifier.
                                    disabled={isBusy || !shortHash}
                                  >
                                    {isBusy ? (
                                      <Loader2 className="h-4 w-4 animate-spin" />
                                    ) : (
                                      <Download className="h-4 w-4 mr-1" />
                                    )}
                                    Load
                                  </Button>
                                )}
                              </div>
                            </div>
                          </Card>
                        );
                      })}
                    </div>
                  )}
                </TabsContent>
              </Tabs>
            </CardContent>
          </Card>
        </div>
      </AppLayout>
    );
  }