| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247 |
- import { ModelInfo, RequiredModelInfo, RecommendedModelInfo, AutoSelectionState } from './api';
/**
 * Chooses companion models (VAE, CLIP-L/CLIP-G, T5XXL, CLIP-Vision, Qwen2VL, …)
 * for a checkpoint based on the checkpoint's `architecture`, and caches the
 * outcome per checkpoint.
 *
 * Cached selections are keyed by checkpoint `id` (falling back to `name`) and
 * are discarded whenever the model list is replaced via `updateModels`.
 */
export class AutoModelSelector {
  private models: ModelInfo[] = [];
  private cache: Map<string, AutoSelectionState> = new Map();

  constructor(models: ModelInfo[] = []) {
    this.models = models;
  }

  /** Replace the list of available models and drop every cached selection. */
  updateModels(models: ModelInfo[]): void {
    this.models = models;
    // Cached selections were computed against the old list and may now be stale.
    this.cache.clear();
  }

  /**
   * Companion-model requirements for a checkpoint's architecture.
   *
   * The architecture is matched case-insensitively; a missing or unrecognised
   * architecture yields an empty list. A new array is built on every call.
   *
   * @param checkpointModel - checkpoint whose `architecture` is inspected
   * @returns requirement descriptors (type, description, optional flag, priority)
   */
  getRequiredModels(checkpointModel: ModelInfo): RequiredModelInfo[] {
    if (!checkpointModel.architecture) {
      return [];
    }
    const architecture = checkpointModel.architecture.toLowerCase();

    switch (architecture) {
      case 'sd3':
      case 'sd3.5':
        return [
          { type: 'vae', description: 'VAE for SD3', optional: true, priority: 1 },
          { type: 'clip-l', description: 'CLIP-L for SD3', optional: false, priority: 2 },
          { type: 'clip-g', description: 'CLIP-G for SD3', optional: false, priority: 3 },
          { type: 't5xxl', description: 'T5XXL for SD3', optional: false, priority: 4 }
        ];

      case 'sdxl':
        return [
          { type: 'vae', description: 'VAE for SDXL', optional: true, priority: 1 }
        ];

      case 'sd1.x':
      case 'sd2.x':
        return [
          { type: 'vae', description: 'VAE for SD1.x/2.x', optional: true, priority: 1 }
        ];

      case 'flux':
        return [
          { type: 'vae', description: 'VAE for FLUX', optional: true, priority: 1 },
          { type: 'clip-l', description: 'CLIP-L for FLUX', optional: false, priority: 2 },
          { type: 't5xxl', description: 'T5XXL for FLUX', optional: false, priority: 3 }
        ];

      case 'kontext':
        return [
          { type: 'vae', description: 'VAE for Kontext', optional: true, priority: 1 },
          { type: 'clip-l', description: 'CLIP-L for Kontext', optional: false, priority: 2 },
          { type: 't5xxl', description: 'T5XXL for Kontext', optional: false, priority: 3 }
        ];

      case 'chroma':
        return [
          { type: 'vae', description: 'VAE for Chroma', optional: true, priority: 1 },
          { type: 't5xxl', description: 'T5XXL for Chroma', optional: false, priority: 2 }
        ];

      case 'wan':
        return [
          { type: 'vae', description: 'VAE for Wan', optional: true, priority: 1 },
          { type: 't5xxl', description: 'T5XXL for Wan', optional: false, priority: 2 },
          { type: 'clip-vision', description: 'CLIP-Vision for Wan', optional: false, priority: 3 }
        ];

      case 'qwen':
        return [
          { type: 'vae', description: 'VAE for Qwen', optional: true, priority: 1 },
          { type: 'qwen2vl', description: 'Qwen2VL for Qwen', optional: false, priority: 2 }
        ];

      default:
        return [];
    }
  }

  /** All available models whose `type` equals `type`, compared case-insensitively. */
  findModelsByType(type: string): ModelInfo[] {
    const wanted = type.toLowerCase();
    return this.models.filter(model => model.type.toLowerCase() === wanted);
  }

  /** All available models whose `name` contains `pattern`, compared case-insensitively. */
  findModelsByName(pattern: string): ModelInfo[] {
    const lowerPattern = pattern.toLowerCase();
    return this.models.filter(model => model.name.toLowerCase().includes(lowerPattern));
  }

  /**
   * Best available model of `type`.
   *
   * Preference order: a model whose name contains `preferredName`
   * (case-insensitive), then the first loaded model, then the first model of
   * that type. Note the fallback: a result is returned even when
   * `preferredName` matches nothing.
   *
   * @returns the chosen model, or `null` if no model of `type` exists
   */
  getBestModelForType(type: string, preferredName?: string): ModelInfo | null {
    const modelsOfType = this.findModelsByType(type);
    if (modelsOfType.length === 0) {
      return null;
    }

    if (preferredName) {
      const preferred = AutoModelSelector.matchByName(modelsOfType, preferredName);
      if (preferred) {
        return preferred;
      }
    }

    return modelsOfType.find(model => model.loaded) ?? modelsOfType[0] ?? null;
  }

  /**
   * Automatically select companion models for `checkpointModel`.
   *
   * Requirements are processed in ascending `priority`. For each one, the best
   * available model (see `getBestModelForType`) is recorded in both
   * `selectedModels` and `autoSelectedModels`. A missing required model is
   * added to `missingModels` and `errors`; a missing optional model only adds
   * a warning. If the checkpoint names a `recommended_vae` and a VAE whose name
   * matches it exists, that VAE replaces the auto-selected one.
   *
   * Results are cached by `id` (or `name` when `id` is empty), and the cached
   * state object itself is returned on later calls. A run that fails with an
   * unexpected exception is returned but not cached, so the next call retries.
   */
  async selectModels(checkpointModel: ModelInfo): Promise<AutoSelectionState> {
    const cacheKey = checkpointModel.id || checkpointModel.name;

    const cached = this.cache.get(cacheKey);
    if (cached) {
      return cached;
    }

    const state: AutoSelectionState = {
      selectedModels: {},
      autoSelectedModels: {},
      missingModels: [],
      warnings: [],
      errors: [],
      isAutoSelecting: false
    };
    let failed = false;

    try {
      state.isAutoSelecting = true;
      this.selectRequiredModels(checkpointModel, state);
      this.applyRecommendedVae(checkpointModel, state);
    } catch (error) {
      failed = true;
      state.errors.push(`Auto-selection failed: ${error instanceof Error ? error.message : 'Unknown error'}`);
    } finally {
      state.isAutoSelecting = false;
    }

    // Only cache complete runs; caching an unexpected failure would pin it until the cache is cleared.
    if (!failed) {
      this.cache.set(cacheKey, state);
    }

    return state;
  }

  /**
   * Run `selectModels` for each checkpoint, one after another.
   *
   * @returns selection states keyed by checkpoint `id` (or `name` when `id` is empty)
   */
  async selectModelsForCheckpoints(checkpoints: ModelInfo[]): Promise<Record<string, AutoSelectionState>> {
    const results: Record<string, AutoSelectionState> = {};

    for (const checkpoint of checkpoints) {
      const key = checkpoint.id || checkpoint.name;
      results[key] = await this.selectModels(checkpoint);
    }

    return results;
  }

  /** Drop every cached selection state. */
  clearCache(): void {
    this.cache.clear();
  }

  /**
   * Cached selection state for a checkpoint, or `null` if none is cached.
   *
   * @param checkpointId - the cache key used by `selectModels`: the checkpoint's `id`, or its `name` when `id` is empty
   */
  getCachedState(checkpointId: string): AutoSelectionState | null {
    return this.cache.get(checkpointId) ?? null;
  }

  /**
   * Check that a manual selection covers every non-optional requirement of
   * the checkpoint's architecture. A type counts as missing when its entry in
   * `selectedModels` is absent or empty.
   *
   * @returns `isValid` (no required type missing), the missing required types,
   *          and `warnings` — currently always empty
   */
  validateSelection(checkpointModel: ModelInfo, selectedModels: Record<string, string>): {
    isValid: boolean;
    missingRequired: string[];
    warnings: string[];
  } {
    const missingRequired = this.getRequiredModels(checkpointModel)
      .filter(required => !required.optional && !selectedModels[required.type])
      .map(required => required.type);

    return {
      isValid: missingRequired.length === 0,
      missingRequired,
      // No validation rule emits warnings yet; the field is kept for the result shape callers expect.
      warnings: []
    };
  }

  /** First model in `candidates` whose name contains `namePart` (case-insensitive), or `null`. */
  private static matchByName(candidates: ModelInfo[], namePart: string): ModelInfo | null {
    const needle = namePart.toLowerCase();
    return candidates.find(model => model.name.toLowerCase().includes(needle)) ?? null;
  }

  /** Record the best available model (or a missing/optional diagnostic) for each architecture requirement. */
  private selectRequiredModels(checkpointModel: ModelInfo, state: AutoSelectionState): void {
    // Sort a copy so the array handed back by the public getRequiredModels is never mutated here.
    const requiredModels = [...this.getRequiredModels(checkpointModel)]
      .sort((a, b) => (a.priority || 0) - (b.priority || 0));

    for (const required of requiredModels) {
      const bestModel = this.getBestModelForType(required.type);

      if (bestModel) {
        state.autoSelectedModels[required.type] = bestModel.name;
        state.selectedModels[required.type] = bestModel.name;

        if (!bestModel.loaded && !required.optional) {
          state.warnings.push(
            `Selected ${required.type} model "${bestModel.name}" is not loaded. Consider loading it for better performance.`
          );
        }
      } else if (!required.optional) {
        state.missingModels.push(required.type);
        state.errors.push(
          `Required ${required.type} model not found: ${required.description || required.type}`
        );
      } else {
        state.warnings.push(
          `Optional ${required.type} model not found: ${required.description || required.type}`
        );
      }
    }
  }

  /** Switch the VAE to the checkpoint's recommended VAE when one with a matching name is available. */
  private applyRecommendedVae(checkpointModel: ModelInfo, state: AutoSelectionState): void {
    const recommended = checkpointModel.recommended_vae;
    if (!recommended || !recommended.name) {
      return;
    }

    // Match strictly by name: getBestModelForType would fall back to an arbitrary VAE,
    // which would then be selected and reported as the "recommended" one.
    const vae = AutoModelSelector.matchByName(this.findModelsByType('vae'), recommended.name);
    if (!vae) {
      state.warnings.push(`Recommended VAE not found: ${recommended.name}`);
      return;
    }

    if (vae.name !== state.selectedModels['vae']) {
      state.autoSelectedModels['vae'] = vae.name;
      state.selectedModels['vae'] = vae.name;
      const reason = recommended.reason ? ` (${recommended.reason})` : '';
      state.warnings.push(`Using recommended VAE: ${vae.name}${reason}`);
    }
  }
}
|