import type {
  ModelInfo,
  RequiredModelInfo,
  RecommendedModelInfo,
  AutoSelectionState,
} from './api';

/**
 * Chooses the auxiliary models (VAE, text encoders, vision encoders, …) that a
 * checkpoint needs, based on the checkpoint's `architecture` and the list of
 * models this selector currently knows about.
 *
 * Selection results are cached per checkpoint (keyed by `id`, or by `name`
 * when `id` is empty) until `updateModels` or `clearCache` is called.
 */
export class AutoModelSelector {
  private models: ModelInfo[] = [];
  /** Memoised selection results; see `cacheKeyFor` for the key. */
  private readonly cache = new Map<string, AutoSelectionState>();

  constructor(models: ModelInfo[] = []) {
    this.models = models;
  }

  /**
   * Replaces the known model list. The selection cache is cleared because
   * cached results may refer to models that are no longer available.
   */
  updateModels(models: ModelInfo[]): void {
    this.models = models;
    this.cache.clear();
  }

  /**
   * Lists the auxiliary models required (or optionally used) by the
   * checkpoint's architecture. The architecture is matched case-insensitively.
   *
   * A new array is built on every call, so callers may sort or mutate it.
   *
   * @returns `[]` when the checkpoint has no architecture or it is not recognised.
   */
  getRequiredModels(checkpointModel: ModelInfo): RequiredModelInfo[] {
    if (!checkpointModel.architecture) {
      return [];
    }

    const architecture = checkpointModel.architecture.toLowerCase();

    switch (architecture) {
      case 'sd3':
      case 'sd3.5':
        return [
          { type: 'vae', description: 'VAE for SD3', optional: true, priority: 1 },
          { type: 'clip-l', description: 'CLIP-L for SD3', optional: false, priority: 2 },
          { type: 'clip-g', description: 'CLIP-G for SD3', optional: false, priority: 3 },
          { type: 't5xxl', description: 'T5XXL for SD3', optional: false, priority: 4 },
        ];
      case 'sdxl':
        return [
          { type: 'vae', description: 'VAE for SDXL', optional: true, priority: 1 },
        ];
      case 'sd1.x':
      case 'sd2.x':
        return [
          { type: 'vae', description: 'VAE for SD1.x/2.x', optional: true, priority: 1 },
        ];
      case 'flux':
        return [
          { type: 'vae', description: 'VAE for FLUX', optional: true, priority: 1 },
          { type: 'clip-l', description: 'CLIP-L for FLUX', optional: false, priority: 2 },
          { type: 't5xxl', description: 'T5XXL for FLUX', optional: false, priority: 3 },
        ];
      case 'kontext':
        return [
          { type: 'vae', description: 'VAE for Kontext', optional: true, priority: 1 },
          { type: 'clip-l', description: 'CLIP-L for Kontext', optional: false, priority: 2 },
          { type: 't5xxl', description: 'T5XXL for Kontext', optional: false, priority: 3 },
        ];
      case 'chroma':
        return [
          { type: 'vae', description: 'VAE for Chroma', optional: true, priority: 1 },
          { type: 't5xxl', description: 'T5XXL for Chroma', optional: false, priority: 2 },
        ];
      case 'wan':
        return [
          { type: 'vae', description: 'VAE for Wan', optional: true, priority: 1 },
          { type: 't5xxl', description: 'T5XXL for Wan', optional: false, priority: 2 },
          { type: 'clip-vision', description: 'CLIP-Vision for Wan', optional: false, priority: 3 },
        ];
      case 'qwen':
        return [
          { type: 'vae', description: 'VAE for Qwen', optional: true, priority: 1 },
          { type: 'qwen2vl', description: 'Qwen2VL for Qwen', optional: false, priority: 2 },
        ];
      default:
        return [];
    }
  }

  /** Returns every known model whose `type` equals `type`, ignoring case. */
  findModelsByType(type: string): ModelInfo[] {
    // Lower-case the query once instead of once per model.
    const wanted = type.toLowerCase();
    return this.models.filter((model) => model.type.toLowerCase() === wanted);
  }

  /** Returns every known model whose `name` contains `pattern`, ignoring case. */
  findModelsByName(pattern: string): ModelInfo[] {
    const lowerPattern = pattern.toLowerCase();
    return this.models.filter((model) => model.name.toLowerCase().includes(lowerPattern));
  }

  /**
   * Picks the best available model of a given type.
   *
   * Preference order:
   * 1. the first model whose name contains `preferredName` (case-insensitive);
   * 2. the first model with `loaded` set;
   * 3. the first model of that type.
   *
   * @returns `null` when no model of that type is known.
   */
  getBestModelForType(type: string, preferredName?: string): ModelInfo | null {
    const modelsOfType = this.findModelsByType(type);
    if (modelsOfType.length === 0) {
      return null;
    }

    if (preferredName) {
      const needle = preferredName.toLowerCase();
      const preferred = modelsOfType.find((model) => model.name.toLowerCase().includes(needle));
      if (preferred) {
        return preferred;
      }
    }

    // `?? null` only satisfies the checker; the list is non-empty here.
    return modelsOfType.find((model) => model.loaded) ?? modelsOfType[0] ?? null;
  }

  /**
   * Automatically selects auxiliary models for a checkpoint.
   *
   * For each requirement (in ascending `priority`), the best model of that type is
   * selected:
   * - A required model that is missing is recorded in `missingModels` and `errors`.
   * - An optional model that is missing only produces a warning.
   * - A required model that was picked but is not loaded also produces a warning.
   *
   * If the checkpoint names a `recommended_vae` that is available and differs from
   * the VAE already chosen, it replaces that choice.
   *
   * The result is cached and the same object is returned on later calls for the
   * same checkpoint. Exceptions are caught and reported in `errors`.
   */
  async selectModels(checkpointModel: ModelInfo): Promise<AutoSelectionState> {
    const cacheKey = AutoModelSelector.cacheKeyFor(checkpointModel);

    const cached = this.cache.get(cacheKey);
    if (cached) {
      return cached;
    }

    const state: AutoSelectionState = {
      selectedModels: {},
      autoSelectedModels: {},
      missingModels: [],
      warnings: [],
      errors: [],
      isAutoSelecting: false,
    };

    try {
      state.isAutoSelecting = true;

      // getRequiredModels returns a fresh array, so sorting it in place is safe.
      const requiredModels = this.getRequiredModels(checkpointModel).sort(
        (a, b) => (a.priority ?? 0) - (b.priority ?? 0),
      );

      for (const required of requiredModels) {
        const bestModel = this.getBestModelForType(required.type);

        if (bestModel) {
          state.autoSelectedModels[required.type] = bestModel.name;
          state.selectedModels[required.type] = bestModel.name;

          if (!bestModel.loaded && !required.optional) {
            state.warnings.push(
              `Selected ${required.type} model "${bestModel.name}" is not loaded. Consider loading it for better performance.`,
            );
          }
        } else if (!required.optional) {
          state.missingModels.push(required.type);
          state.errors.push(
            `Required ${required.type} model not found: ${required.description || required.type}`,
          );
        } else {
          state.warnings.push(
            `Optional ${required.type} model not found: ${required.description || required.type}`,
          );
        }
      }

      // Checkpoint-specific VAE recommendation overrides the generic pick.
      if (checkpointModel.recommended_vae) {
        const vae = this.getBestModelForType('vae', checkpointModel.recommended_vae.name);
        // Without a name match, getBestModelForType falls back to the same model
        // chosen above, so the inequality holds only when the recommendation matched.
        if (vae && vae.name !== state.selectedModels['vae']) {
          state.autoSelectedModels['vae'] = vae.name;
          state.selectedModels['vae'] = vae.name;
          state.warnings.push(
            `Using recommended VAE: ${vae.name} (${checkpointModel.recommended_vae.reason})`,
          );
        }
      }
    } catch (error) {
      state.errors.push(
        `Auto-selection failed: ${error instanceof Error ? error.message : 'Unknown error'}`,
      );
    } finally {
      state.isAutoSelecting = false;
    }

    this.cache.set(cacheKey, state);
    return state;
  }

  /**
   * Runs `selectModels` for each checkpoint, in order.
   *
   * @returns A record keyed like the cache (checkpoint `id`, or `name` when `id`
   *   is empty). Later checkpoints with the same key overwrite earlier ones.
   */
  async selectModelsForCheckpoints(
    checkpoints: ModelInfo[],
  ): Promise<Record<string, AutoSelectionState>> {
    const results: Record<string, AutoSelectionState> = {};

    for (const checkpoint of checkpoints) {
      results[AutoModelSelector.cacheKeyFor(checkpoint)] = await this.selectModels(checkpoint);
    }

    return results;
  }

  /** Discards all cached selection results. */
  clearCache(): void {
    this.cache.clear();
  }

  /**
   * Returns the cached selection for a checkpoint, or `null` if none is cached.
   *
   * @param checkpointId The cache key: the checkpoint's `id`, or its `name` when `id` was empty.
   */
  getCachedState(checkpointId: string): AutoSelectionState | null {
    return this.cache.get(checkpointId) ?? null;
  }

  /**
   * Checks that every non-optional requirement of the checkpoint's architecture
   * has a truthy entry in `selectedModels`.
   *
   * `warnings` is currently never populated; it is always returned empty.
   */
  validateSelection(
    checkpointModel: ModelInfo,
    selectedModels: Record<string, string | undefined>,
  ): {
    isValid: boolean;
    missingRequired: string[];
    warnings: string[];
  } {
    const missingRequired = this.getRequiredModels(checkpointModel)
      .filter((required) => !required.optional && !selectedModels[required.type])
      .map((required) => required.type);
    const warnings: string[] = [];

    return {
      isValid: missingRequired.length === 0,
      missingRequired,
      warnings,
    };
  }

  /** Single source of truth for the cache / result key of a checkpoint. */
  private static cacheKeyFor(checkpoint: ModelInfo): string {
    return checkpoint.id || checkpoint.name;
  }
}