| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385 |
- "use client";
- import { useState, useEffect } from "react";
- import { Header, AppLayout } from "@/components/layout";
- import { Button } from "@/components/ui/button";
- import { Input } from "@/components/ui/input";
- import { Label } from "@/components/ui/label";
- import { Card, CardContent, CardDescription, CardHeader, CardTitle } from "@/components/ui/card";
- import { Badge } from "@/components/ui/badge";
- import { Alert, AlertDescription } from "@/components/ui/alert";
- import { Tabs, TabsContent, TabsList, TabsTrigger } from "@/components/ui/tabs";
- import {
- Select,
- SelectContent,
- SelectItem,
- SelectTrigger,
- SelectValue,
- } from "@/components/ui/select";
- import {
- apiClient,
- type ModelInfo,
- type EnhancedModelsResponse,
- } from "@/lib/api";
- import {
- Loader2,
- Download,
- Upload,
- CheckCircle,
- XCircle,
- AlertCircle,
- Copy,
- RefreshCw,
- Package,
- Hash
- } from "lucide-react";
/** Success banner text shown after copying a hash; auto-dismissed by an effect in ModelsPage. */
const COPIED_MESSAGE = "Copied to clipboard";

/** How long the clipboard confirmation stays visible, in milliseconds. */
const COPIED_MESSAGE_TTL_MS = 2000;

/** Delay before re-fetching the model list after a hash computation is requested, in milliseconds. */
const HASH_RELOAD_DELAY_MS = 2000;

/**
 * Identifier used to address a model in API calls, busy-state sets, and React keys.
 *
 * Hashes are preferred over names (the page steers users toward hash-based loading);
 * falls back to the full hash, then the model id, then the display name.
 */
function getModelIdentifier(model: ModelInfo) {
  return model.sha256_short || model.sha256 || model.id || model.name;
}

/** True when `model` is addressed by `modelId` via its id, full hash, or short hash. */
function matchesIdentifier(model: ModelInfo, modelId: string): boolean {
  return model.id === modelId || model.sha256 === modelId || model.sha256_short === modelId;
}

/** Returns a new Set that also contains `value` (state Sets are never mutated in place). */
function withItem(set: Set<string>, value: string): Set<string> {
  return new Set(set).add(value);
}

/** Returns a new Set without `value` (state Sets are never mutated in place). */
function withoutItem(set: Set<string>, value: string): Set<string> {
  const next = new Set(set);
  next.delete(value);
  return next;
}

/** Message of a thrown `Error`, or `fallback` for any other thrown value. */
function messageOf(err: unknown, fallback: string): string {
  return err instanceof Error ? err.message : fallback;
}

/**
 * Model management page.
 *
 * Lists the models returned by `apiClient.getModels()`, filters them by load state
 * (tabs), type (select) and a name/short-hash search, and lets the user load, unload,
 * or compute the hash of each model. Models without a short hash cannot be loaded
 * from this page; the user is prompted to compute the hash first.
 */
export default function ModelsPage() {
  const [models, setModels] = useState<ModelInfo[]>([]);
  const [loading, setLoading] = useState(false);
  const [error, setError] = useState<string | null>(null);
  const [success, setSuccess] = useState<string | null>(null);
  const [activeTab, setActiveTab] = useState("all");
  const [modelTypeFilter, setModelTypeFilter] = useState("all");
  const [searchQuery, setSearchQuery] = useState("");
  // Identifiers (see getModelIdentifier) with a load/unload request in flight.
  const [loadingModels, setLoadingModels] = useState<Set<string>>(new Set());
  // Identifiers with a compute-hash request in flight.
  const [computingHashes, setComputingHashes] = useState<Set<string>>(new Set());
  // Bumped each time a hash computation is requested; drives the delayed reload effect.
  const [hashReloadRequests, setHashReloadRequests] = useState(0);

  /** Fetches the model list and replaces local state with it. Throws on API failure. */
  const fetchModels = async () => {
    const response = await apiClient.getModels();
    setModels(response.models);
  };

  /** Loads the model list, showing the spinner and surfacing failures in the error banner. */
  const loadModels = async () => {
    setLoading(true);
    setError(null);

    try {
      await fetchModels();
    } catch (err) {
      setError(messageOf(err, "Failed to load models"));
    } finally {
      setLoading(false);
    }
  };

  /** Asks the backend to rescan for models, then reloads the list. */
  const refreshModels = async () => {
    setLoading(true);
    setError(null);

    try {
      await apiClient.scanModels();
      // After scanning, reload the models list
      await fetchModels();
      setSuccess("Models refreshed successfully");
    } catch (err) {
      setError(messageOf(err, "Failed to refresh models"));
    } finally {
      setLoading(false);
    }
  };

  /** Optimistically flags every model addressed by `modelId` as loaded or unloaded. */
  const setLoadedFlag = (modelId: string, loaded: boolean) => {
    setModels(prev =>
      prev.map(model => (matchesIdentifier(model, modelId) ? { ...model, loaded } : model)),
    );
  };

  /** Loads a model by identifier; `modelName` is only used in user-facing messages. */
  const loadModel = async (modelId: string, modelName: string) => {
    setLoadingModels(prev => withItem(prev, modelId));
    setError(null);
    setSuccess(null);

    try {
      await apiClient.loadModel(modelId);
      setSuccess(`Model "${modelName}" loaded successfully`);
      setLoadedFlag(modelId, true);
    } catch (err) {
      const message = messageOf(err, "Failed to load model");

      // These error codes are treated as an identifier problem, so steer the user to the hash.
      if (message.includes("INVALID_MODEL_IDENTIFIER") || message.includes("MODEL_NOT_FOUND")) {
        setError(`Failed to load model: ${message}. Please ensure you're using the model hash instead of the name.`);
      } else {
        setError(`Failed to load model "${modelName}": ${message}`);
      }
    } finally {
      setLoadingModels(prev => withoutItem(prev, modelId));
    }
  };

  /** Unloads a model by identifier; `modelName` is only used in user-facing messages. */
  const unloadModel = async (modelId: string, modelName: string) => {
    setLoadingModels(prev => withItem(prev, modelId));
    setError(null);
    setSuccess(null);

    try {
      await apiClient.unloadModel(modelId);
      setSuccess(`Model "${modelName}" unloaded successfully`);
      setLoadedFlag(modelId, false);
    } catch (err) {
      setError(messageOf(err, "Failed to unload model"));
    } finally {
      setLoadingModels(prev => withoutItem(prev, modelId));
    }
  };

  /**
   * Requests a hash computation for a model.
   *
   * The busy flag only covers the request that starts the computation; the model list
   * is re-fetched HASH_RELOAD_DELAY_MS later (see the effect below) to pick up the hash.
   */
  const computeModelHash = async (modelId: string, modelName: string) => {
    setComputingHashes(prev => withItem(prev, modelId));
    setError(null);
    setSuccess(null);

    try {
      const result = await apiClient.computeModelHash(modelId);
      setSuccess(`Hash computation started for "${modelName}". Request ID: ${result.request_id}`);
      setHashReloadRequests(count => count + 1);
    } catch (err) {
      setError(messageOf(err, "Failed to compute model hash"));
    } finally {
      setComputingHashes(prev => withoutItem(prev, modelId));
    }
  };

  /** Copies `text` to the clipboard, reporting failure in the error banner instead of rejecting. */
  const copyToClipboard = async (text: string) => {
    try {
      // navigator.clipboard is undefined outside secure contexts, and writeText rejects
      // when permission is denied; both previously escaped as unhandled errors.
      if (!navigator.clipboard) {
        throw new Error("Clipboard API is not available in this context");
      }
      await navigator.clipboard.writeText(text);
      setSuccess(COPIED_MESSAGE);
    } catch (err) {
      setError(messageOf(err, "Failed to copy to clipboard"));
    }
  };

  // Load models on component mount. loadModels reports its own errors, so the promise
  // is deliberately not awaited.
  useEffect(() => {
    void loadModels();
  }, []);

  // Delayed reload after a hash computation request. Running the timer in an effect lets
  // React cancel it on unmount (or when another request restarts the delay).
  useEffect(() => {
    if (hashReloadRequests === 0) {
      return undefined;
    }
    const timer = setTimeout(() => {
      void loadModels();
    }, HASH_RELOAD_DELAY_MS);
    return () => clearTimeout(timer);
  }, [hashReloadRequests]);

  // Auto-dismiss the clipboard confirmation. Keyed on `success`, so a different message
  // set in the meantime cancels the timer instead of being wiped out by it.
  useEffect(() => {
    if (success !== COPIED_MESSAGE) {
      return undefined;
    }
    const timer = setTimeout(() => setSuccess(null), COPIED_MESSAGE_TTL_MS);
    return () => clearTimeout(timer);
  }, [success]);

  // Filter models based on active tab, type filter and search query (case-insensitive).
  const normalizedQuery = searchQuery.trim().toLowerCase();
  const filteredModels = models.filter(model => {
    const matchesTab =
      activeTab === "all" ||
      (activeTab === "loaded" && model.loaded) ||
      (activeTab === "unloaded" && !model.loaded);

    const matchesType = modelTypeFilter === "all" || model.type === modelTypeFilter;

    const matchesSearch =
      normalizedQuery === "" ||
      model.name.toLowerCase().includes(normalizedQuery) ||
      (model.sha256_short?.toLowerCase().includes(normalizedQuery) ?? false);

    return matchesTab && matchesType && matchesSearch;
  });

  // Unique model types for the filter dropdown, sorted for a stable order.
  const modelTypes = Array.from(new Set(models.map(model => model.type))).sort();

  const loadedCount = models.filter(model => model.loaded).length;
  const unloadedCount = models.length - loadedCount;

  return (
    <AppLayout>
      <Header
        title="Model Management"
        description="Load, unload, and manage your AI models"
      />
      <div className="container mx-auto p-6 space-y-6">
        {error && (
          <Alert variant="destructive">
            <AlertCircle className="h-4 w-4" />
            <AlertDescription>{error}</AlertDescription>
          </Alert>
        )}

        {success && (
          <Alert>
            <CheckCircle className="h-4 w-4" />
            <AlertDescription>{success}</AlertDescription>
          </Alert>
        )}

        <Card>
          <CardHeader>
            <div className="flex items-center justify-between">
              <div>
                <CardTitle>Models</CardTitle>
                <CardDescription>
                  Manage your AI models - use hashes for reliable identification
                </CardDescription>
              </div>
              <Button onClick={() => void refreshModels()} disabled={loading}>
                <RefreshCw className={`h-4 w-4 mr-2 ${loading ? 'animate-spin' : ''}`} />
                Refresh
              </Button>
            </div>
          </CardHeader>
          <CardContent>
            <div className="flex gap-4 mb-4">
              <div className="flex-1 max-w-md">
                <Label htmlFor="search" className="sr-only">Search models</Label>
                <Input
                  id="search"
                  placeholder="Search models by name or hash..."
                  value={searchQuery}
                  onChange={(e) => setSearchQuery(e.target.value)}
                />
              </div>
              <div className="w-64">
                <Label htmlFor="type-filter" className="sr-only">Filter by type</Label>
                <Select value={modelTypeFilter} onValueChange={setModelTypeFilter}>
                  <SelectTrigger id="type-filter">
                    <SelectValue placeholder="Filter by type" />
                  </SelectTrigger>
                  <SelectContent>
                    <SelectItem value="all">All Types</SelectItem>
                    {modelTypes.map(type => (
                      <SelectItem key={type} value={type}>{type}</SelectItem>
                    ))}
                  </SelectContent>
                </Select>
              </div>
            </div>

            <Tabs value={activeTab} onValueChange={setActiveTab}>
              <TabsList>
                <TabsTrigger value="all">All Models ({models.length})</TabsTrigger>
                <TabsTrigger value="loaded">Loaded ({loadedCount})</TabsTrigger>
                <TabsTrigger value="unloaded">Unloaded ({unloadedCount})</TabsTrigger>
              </TabsList>

              <TabsContent value={activeTab} className="mt-4">
                {loading ? (
                  <div className="flex justify-center py-8">
                    <Loader2 className="h-8 w-8 animate-spin" />
                  </div>
                ) : filteredModels.length === 0 ? (
                  <div className="text-center py-8 text-muted-foreground">
                    No models found matching your criteria
                  </div>
                ) : (
                  <div className="grid gap-4">
                    {filteredModels.map((model) => {
                      // Computed once per card instead of once per attribute.
                      const modelId = getModelIdentifier(model);
                      const shortHash = model.sha256_short;
                      const isBusy = loadingModels.has(modelId);
                      const isHashing = computingHashes.has(modelId);

                      return (
                        <Card key={modelId} className="p-4">
                          <div className="flex items-start justify-between">
                            <div className="flex-1 space-y-2">
                              <div className="flex items-center gap-2">
                                <h3 className="font-semibold">{model.name}</h3>
                                <Badge variant={model.loaded ? "default" : "secondary"}>
                                  {model.loaded ? "Loaded" : "Unloaded"}
                                </Badge>
                                <Badge variant="outline">{model.type}</Badge>
                              </div>

                              <div className="space-y-1 text-sm text-muted-foreground">
                                <div className="flex items-center gap-2">
                                  <Hash className="h-3 w-3" />
                                  <span className="font-mono">
                                    {shortHash || "No hash"}
                                  </span>
                                  {shortHash && (
                                    <Button
                                      variant="ghost"
                                      size="sm"
                                      aria-label="Copy hash to clipboard"
                                      onClick={() => void copyToClipboard(shortHash)}
                                    >
                                      <Copy className="h-3 w-3" />
                                    </Button>
                                  )}
                                </div>

                                {/* Null check, not truthiness: a 0 MB size would otherwise render a stray "0". */}
                                {model.file_size_mb != null && (
                                  <div>Size: {model.file_size_mb.toFixed(2)} MB</div>
                                )}

                                {model.architecture && (
                                  <div>Architecture: {model.architecture}</div>
                                )}
                              </div>

                              {!shortHash && (
                                <Alert>
                                  <AlertCircle className="h-4 w-4" />
                                  <AlertDescription>
                                    This model doesn't have a computed hash. Compute the hash to use it with the hash-only loading system.
                                  </AlertDescription>
                                </Alert>
                              )}
                            </div>

                            <div className="flex gap-2 ml-4">
                              {!shortHash && (
                                <Button
                                  variant="outline"
                                  size="sm"
                                  onClick={() => void computeModelHash(modelId, model.name)}
                                  disabled={isHashing}
                                >
                                  {isHashing ? (
                                    <Loader2 className="h-4 w-4 mr-1 animate-spin" />
                                  ) : (
                                    <Package className="h-4 w-4 mr-1" />
                                  )}
                                  Compute Hash
                                </Button>
                              )}

                              {model.loaded ? (
                                <Button
                                  variant="destructive"
                                  size="sm"
                                  onClick={() => void unloadModel(modelId, model.name)}
                                  disabled={isBusy}
                                >
                                  {isBusy ? (
                                    <Loader2 className="h-4 w-4 mr-1 animate-spin" />
                                  ) : (
                                    <XCircle className="h-4 w-4 mr-1" />
                                  )}
                                  Unload
                                </Button>
                              ) : (
                                <Button
                                  variant="default"
                                  size="sm"
                                  onClick={() => void loadModel(modelId, model.name)}
                                  // Loading is hash-only: models without a short hash stay disabled.
                                  disabled={isBusy || !shortHash}
                                >
                                  {isBusy ? (
                                    <Loader2 className="h-4 w-4 mr-1 animate-spin" />
                                  ) : (
                                    <Download className="h-4 w-4 mr-1" />
                                  )}
                                  Load
                                </Button>
                              )}
                            </div>
                          </div>
                        </Card>
                      );
                    })}
                  </div>
                )}
              </TabsContent>
            </Tabs>
          </CardContent>
        </Card>
      </div>
    </AppLayout>
  );
}
|