auto-model-selector.ts 7.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247
  1. import { ModelInfo, RequiredModelInfo, RecommendedModelInfo, AutoSelectionState } from './api';
  /**
   * Chooses the auxiliary models (VAE, text encoders, ...) that go with a
   * checkpoint, based on the checkpoint's `architecture` and on the list of
   * models this selector currently knows about.
   *
   * Results of {@link AutoModelSelector.selectModels} are memoised per checkpoint
   * (keyed by `id`, falling back to `name`) and discarded whenever the model list
   * is replaced through {@link AutoModelSelector.updateModels}.
   */
  export class AutoModelSelector {
    private models: ModelInfo[] = [];
    private readonly cache: Map<string, AutoSelectionState> = new Map();

    /**
     * @param models Initial list of known models; defaults to an empty list.
     */
    constructor(models: ModelInfo[] = []) {
      this.models = models;
    }

    /**
     * Replaces the list of known models and drops every cached selection,
     * since cached results were computed against the previous list.
     */
    updateModels(models: ModelInfo[]): void {
      this.models = models;
      this.cache.clear();
    }

    /**
     * Returns the auxiliary model requirements for a checkpoint's architecture.
     *
     * The architecture is compared case-insensitively. A checkpoint without an
     * architecture, or with one that is not listed below, yields an empty array.
     * A fresh array is returned on every call, so callers may sort or mutate it.
     */
    getRequiredModels(checkpointModel: ModelInfo): RequiredModelInfo[] {
      if (!checkpointModel.architecture) {
        return [];
      }

      const architecture = checkpointModel.architecture.toLowerCase();
      switch (architecture) {
        case 'sd3':
        case 'sd3.5':
          return [
            { type: 'vae', description: 'VAE for SD3', optional: true, priority: 1 },
            { type: 'clip-l', description: 'CLIP-L for SD3', optional: false, priority: 2 },
            { type: 'clip-g', description: 'CLIP-G for SD3', optional: false, priority: 3 },
            { type: 't5xxl', description: 'T5XXL for SD3', optional: false, priority: 4 }
          ];
        case 'sdxl':
          return [
            { type: 'vae', description: 'VAE for SDXL', optional: true, priority: 1 }
          ];
        case 'sd1.x':
        case 'sd2.x':
          return [
            { type: 'vae', description: 'VAE for SD1.x/2.x', optional: true, priority: 1 }
          ];
        case 'flux':
          return [
            { type: 'vae', description: 'VAE for FLUX', optional: true, priority: 1 },
            { type: 'clip-l', description: 'CLIP-L for FLUX', optional: false, priority: 2 },
            { type: 't5xxl', description: 'T5XXL for FLUX', optional: false, priority: 3 }
          ];
        case 'kontext':
          return [
            { type: 'vae', description: 'VAE for Kontext', optional: true, priority: 1 },
            { type: 'clip-l', description: 'CLIP-L for Kontext', optional: false, priority: 2 },
            { type: 't5xxl', description: 'T5XXL for Kontext', optional: false, priority: 3 }
          ];
        case 'chroma':
          return [
            { type: 'vae', description: 'VAE for Chroma', optional: true, priority: 1 },
            { type: 't5xxl', description: 'T5XXL for Chroma', optional: false, priority: 2 }
          ];
        case 'wan':
          return [
            { type: 'vae', description: 'VAE for Wan', optional: true, priority: 1 },
            { type: 't5xxl', description: 'T5XXL for Wan', optional: false, priority: 2 },
            { type: 'clip-vision', description: 'CLIP-Vision for Wan', optional: false, priority: 3 }
          ];
        case 'qwen':
          return [
            { type: 'vae', description: 'VAE for Qwen', optional: true, priority: 1 },
            { type: 'qwen2vl', description: 'Qwen2VL for Qwen', optional: false, priority: 2 }
          ];
        default:
          return [];
      }
    }

    /**
     * Returns every known model whose `type` equals `type`, ignoring case.
     */
    findModelsByType(type: string): ModelInfo[] {
      const wanted = type.toLowerCase();
      return this.models.filter(model => model.type.toLowerCase() === wanted);
    }

    /**
     * Returns every known model whose name contains `pattern`, ignoring case.
     */
    findModelsByName(pattern: string): ModelInfo[] {
      const lowerPattern = pattern.toLowerCase();
      return this.models.filter(model =>
        model.name.toLowerCase().includes(lowerPattern)
      );
    }

    /**
     * Picks one model of the given type.
     *
     * Order of preference:
     *  1. the first model whose name contains `preferredName` (case-insensitive),
     *     when a non-empty `preferredName` is given;
     *  2. the first model flagged as `loaded`;
     *  3. the first model of that type.
     *
     * Note that a `preferredName` that matches nothing does NOT make this return
     * `null`; it silently falls through to steps 2 and 3.
     *
     * @returns The chosen model, or `null` when no model of that type is known.
     */
    getBestModelForType(type: string, preferredName?: string): ModelInfo | null {
      const modelsOfType = this.findModelsByType(type);
      if (modelsOfType.length === 0) {
        return null;
      }

      if (preferredName) {
        const preferred = this.findPreferredModel(modelsOfType, preferredName);
        if (preferred) {
          return preferred;
        }
      }

      const firstLoaded = modelsOfType.find(model => model.loaded);
      return firstLoaded || modelsOfType[0];
    }

    /**
     * Computes (or returns the cached) automatic selection for a checkpoint.
     *
     * For every requirement of the checkpoint's architecture, in ascending
     * priority order, the best available model of that type is selected.
     * A missing non-optional model is recorded in `missingModels` and `errors`;
     * a missing optional one only produces a warning. A selected non-optional
     * model that is not loaded also produces a warning.
     *
     * If the checkpoint names a recommended VAE and a VAE whose name matches it
     * is available, that VAE replaces the generic choice. When no matching VAE is
     * available the earlier choice (if any) is kept and a warning is added; an
     * unrelated VAE is never presented as the recommended one.
     *
     * Any exception raised during selection is converted into an entry in
     * `errors`. The resulting state object is cached and the same object is
     * returned on later calls, so callers should treat it as read-only.
     */
    async selectModels(checkpointModel: ModelInfo): Promise<AutoSelectionState> {
      // `||` on purpose: an empty id must fall back to the name as well.
      const cacheKey = checkpointModel.id || checkpointModel.name;

      const cached = this.cache.get(cacheKey);
      if (cached) {
        return cached;
      }

      const state: AutoSelectionState = {
        selectedModels: {},
        autoSelectedModels: {},
        missingModels: [],
        warnings: [],
        errors: [],
        isAutoSelecting: false
      };

      try {
        state.isAutoSelecting = true;

        // getRequiredModels hands out a fresh array, so sorting in place is safe.
        const requiredModels = this.getRequiredModels(checkpointModel);
        requiredModels.sort((a, b) => (a.priority || 0) - (b.priority || 0));

        for (const required of requiredModels) {
          const bestModel = this.getBestModelForType(required.type);

          if (bestModel) {
            state.autoSelectedModels[required.type] = bestModel.name;
            state.selectedModels[required.type] = bestModel.name;

            if (!bestModel.loaded && !required.optional) {
              state.warnings.push(
                `Selected ${required.type} model "${bestModel.name}" is not loaded. Consider loading it for better performance.`
              );
            }
          } else if (!required.optional) {
            state.missingModels.push(required.type);
            state.errors.push(
              `Required ${required.type} model not found: ${required.description || required.type}`
            );
          } else {
            state.warnings.push(
              `Optional ${required.type} model not found: ${required.description || required.type}`
            );
          }
        }

        const recommended = checkpointModel.recommended_vae;
        if (recommended && recommended.name) {
          // Match by name only. Going through getBestModelForType here would fall
          // back to an arbitrary VAE when the recommended one is not installed,
          // and that VAE would then be reported as "recommended".
          const vae = this.findPreferredModel(this.findModelsByType('vae'), recommended.name);

          if (!vae) {
            state.warnings.push(
              `Recommended VAE "${recommended.name}" is not available`
            );
          } else if (vae.name !== state.selectedModels['vae']) {
            state.autoSelectedModels['vae'] = vae.name;
            state.selectedModels['vae'] = vae.name;
            state.warnings.push(
              `Using recommended VAE: ${vae.name} (${recommended.reason})`
            );
          }
        }
      } catch (error) {
        state.errors.push(`Auto-selection failed: ${error instanceof Error ? error.message : 'Unknown error'}`);
      } finally {
        state.isAutoSelecting = false;
      }

      this.cache.set(cacheKey, state);
      return state;
    }

    /**
     * Runs {@link AutoModelSelector.selectModels} for each checkpoint in order.
     *
     * @returns A record keyed by each checkpoint's `id` (or `name` when the id is
     * empty). Checkpoints sharing a key overwrite one another in the result.
     */
    async selectModelsForCheckpoints(checkpoints: ModelInfo[]): Promise<Record<string, AutoSelectionState>> {
      const results: Record<string, AutoSelectionState> = {};
      for (const checkpoint of checkpoints) {
        const key = checkpoint.id || checkpoint.name;
        results[key] = await this.selectModels(checkpoint);
      }
      return results;
    }

    /**
     * Drops every cached selection.
     */
    clearCache(): void {
      this.cache.clear();
    }

    /**
     * Returns the cached selection stored under `checkpointId`, or `null` when
     * nothing is cached under that key. The key is whatever `selectModels` used:
     * the checkpoint's id, or its name when the id is empty.
     */
    getCachedState(checkpointId: string): AutoSelectionState | null {
      return this.cache.get(checkpointId) || null;
    }

    /**
     * Checks a manual selection against the checkpoint's architecture.
     *
     * A requirement counts as satisfied when `selectedModels` has a truthy entry
     * for its type; optional requirements are never reported as missing.
     *
     * @returns `isValid` is true when no non-optional type is missing;
     * `missingRequired` lists the missing types. `warnings` is part of the
     * result shape but is currently always empty.
     */
    validateSelection(checkpointModel: ModelInfo, selectedModels: Record<string, string>): {
      isValid: boolean;
      missingRequired: string[];
      warnings: string[];
    } {
      const warnings: string[] = [];
      const missingRequired = this.getRequiredModels(checkpointModel)
        .filter(required => !required.optional && !selectedModels[required.type])
        .map(required => required.type);

      return {
        isValid: missingRequired.length === 0,
        missingRequired,
        warnings
      };
    }

    /**
     * Returns the first candidate whose name contains `preferredName`
     * (case-insensitive), or `null` when none does.
     */
    private findPreferredModel(candidates: ModelInfo[], preferredName: string): ModelInfo | null {
      const needle = preferredName.toLowerCase();
      const match = candidates.find(model => model.name.toLowerCase().includes(needle));
      return match || null;
    }
  }