model-selection-context.tsx 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344
  1. 'use client';
  2. import React, { createContext, useContext, useReducer, useEffect, useMemo, ReactNode } from 'react';
  3. import { ModelInfo, AutoSelectionState } from '@/lib/api';
  4. import { AutoModelSelector } from '@/lib/services/auto-model-selector';
  5. // Types for the context
  6. interface ModelSelectionState {
  7. selectedCheckpoint: string | null;
  8. selectedModels: Record<string, string | undefined>; // modelType -> modelHash
  9. autoSelectedModels: Record<string, string>; // modelType -> modelHash
  10. userOverrides: Record<string, string>; // modelType -> modelHash (user manual selections)
  11. autoSelectionState: AutoSelectionState | null;
  12. availableModels: ModelInfo[];
  13. isLoading: boolean;
  14. error: string | null;
  15. warnings: string[];
  16. isAutoSelecting: boolean;
  17. }
  18. type ModelSelectionAction =
  19. | { type: 'SET_MODELS'; payload: ModelInfo[] }
  20. | { type: 'SET_SELECTED_CHECKPOINT'; payload: string | null }
  21. | { type: 'SET_SELECTED_MODEL'; payload: { type: string; hash: string | undefined } }
  22. | { type: 'SET_USER_OVERRIDE'; payload: { type: string; hash: string } }
  23. | { type: 'CLEAR_USER_OVERRIDE'; payload: string }
  24. | { type: 'SET_AUTO_SELECTION_STATE'; payload: AutoSelectionState }
  25. | { type: 'SET_LOADING'; payload: boolean }
  26. | { type: 'SET_ERROR'; payload: string | null }
  27. | { type: 'ADD_WARNING'; payload: string }
  28. | { type: 'CLEAR_WARNINGS' }
  29. | { type: 'SET_AUTO_SELECTING'; payload: boolean }
  30. | { type: 'RESET_SELECTION' };
  31. // Initial state
  32. const initialState: ModelSelectionState = {
  33. selectedCheckpoint: null,
  34. selectedModels: {},
  35. autoSelectedModels: {},
  36. userOverrides: {},
  37. autoSelectionState: null,
  38. availableModels: [],
  39. isLoading: false,
  40. error: null,
  41. warnings: [],
  42. isAutoSelecting: false,
  43. };
  44. // Reducer function
  45. function modelSelectionReducer(
  46. state: ModelSelectionState,
  47. action: ModelSelectionAction
  48. ): ModelSelectionState {
  49. switch (action.type) {
  50. case 'SET_MODELS':
  51. return {
  52. ...state,
  53. availableModels: action.payload,
  54. error: null,
  55. };
  56. case 'SET_SELECTED_CHECKPOINT':
  57. return {
  58. ...state,
  59. selectedCheckpoint: action.payload,
  60. // Clear auto-selection when checkpoint changes
  61. autoSelectedModels: {},
  62. autoSelectionState: null,
  63. warnings: [],
  64. };
  65. case 'SET_SELECTED_MODEL':
  66. return {
  67. ...state,
  68. selectedModels: {
  69. ...state.selectedModels,
  70. [action.payload.type]: action.payload.hash,
  71. },
  72. };
  73. case 'SET_USER_OVERRIDE':
  74. return {
  75. ...state,
  76. userOverrides: {
  77. ...state.userOverrides,
  78. [action.payload.type]: action.payload.hash,
  79. },
  80. selectedModels: {
  81. ...state.selectedModels,
  82. [action.payload.type]: action.payload.hash,
  83. },
  84. };
  85. case 'CLEAR_USER_OVERRIDE':
  86. const newUserOverrides = { ...state.userOverrides };
  87. delete newUserOverrides[action.payload];
  88. // If we had an auto-selected model for this type, restore it
  89. const restoredModel = state.autoSelectedModels[action.payload];
  90. const newSelectedModels = { ...state.selectedModels };
  91. if (restoredModel) {
  92. newSelectedModels[action.payload] = restoredModel;
  93. } else {
  94. delete newSelectedModels[action.payload];
  95. }
  96. return {
  97. ...state,
  98. userOverrides: newUserOverrides,
  99. selectedModels: newSelectedModels,
  100. };
  101. case 'SET_AUTO_SELECTION_STATE':
  102. return {
  103. ...state,
  104. autoSelectionState: action.payload,
  105. // Merge auto-selected models with current selections, but don't override user selections
  106. selectedModels: {
  107. ...action.payload.autoSelectedModels,
  108. ...state.userOverrides, // User overrides take precedence
  109. },
  110. autoSelectedModels: action.payload.autoSelectedModels,
  111. warnings: [...state.warnings, ...action.payload.warnings],
  112. error: action.payload.errors.length > 0 ? action.payload.errors[0] : state.error,
  113. };
  114. case 'SET_LOADING':
  115. return {
  116. ...state,
  117. isLoading: action.payload,
  118. };
  119. case 'SET_ERROR':
  120. return {
  121. ...state,
  122. error: action.payload,
  123. };
  124. case 'ADD_WARNING':
  125. return {
  126. ...state,
  127. warnings: [...state.warnings, action.payload],
  128. };
  129. case 'CLEAR_WARNINGS':
  130. return {
  131. ...state,
  132. warnings: [],
  133. };
  134. case 'SET_AUTO_SELECTING':
  135. return {
  136. ...state,
  137. isAutoSelecting: action.payload,
  138. };
  139. case 'RESET_SELECTION':
  140. return {
  141. ...initialState,
  142. availableModels: state.availableModels,
  143. };
  144. default:
  145. return state;
  146. }
  147. }
  148. // Context type
  149. interface ModelSelectionContextType {
  150. state: ModelSelectionState;
  151. actions: {
  152. setModels: (models: ModelInfo[]) => void;
  153. setSelectedCheckpoint: (checkpointHash: string | null) => void;
  154. setSelectedModel: (type: string, hash: string | undefined) => void;
  155. setUserOverride: (type: string, hash: string) => void;
  156. clearUserOverride: (type: string) => void;
  157. performAutoSelection: (checkpointModel: ModelInfo) => Promise<void>;
  158. clearWarnings: () => void;
  159. resetSelection: () => void;
  160. validateSelection: (checkpointModel: ModelInfo) => {
  161. isValid: boolean;
  162. missingRequired: string[];
  163. warnings: string[];
  164. };
  165. };
  166. }
  167. // Create context
  168. const ModelSelectionContext = createContext<ModelSelectionContextType | null>(null);
  169. // Provider component
  170. interface ModelSelectionProviderProps {
  171. children: ReactNode;
  172. }
  173. export function ModelSelectionProvider({ children }: ModelSelectionProviderProps) {
  174. const [state, dispatch] = useReducer(modelSelectionReducer, initialState);
  175. const autoSelectorRef = React.useRef<AutoModelSelector>(new AutoModelSelector());
  176. // Update auto selector when models change
  177. useEffect(() => {
  178. autoSelectorRef.current.updateModels(state.availableModels);
  179. }, [state.availableModels]);
  180. // Actions
  181. const actions = useMemo(() => ({
  182. setModels: (models: ModelInfo[]) => {
  183. dispatch({ type: 'SET_MODELS', payload: models });
  184. },
  185. setSelectedCheckpoint: (checkpointHash: string | null) => {
  186. dispatch({ type: 'SET_SELECTED_CHECKPOINT', payload: checkpointHash });
  187. },
  188. setSelectedModel: (type: string, hash: string | undefined) => {
  189. dispatch({ type: 'SET_SELECTED_MODEL', payload: { type, hash } });
  190. },
  191. setUserOverride: (type: string, hash: string) => {
  192. dispatch({ type: 'SET_USER_OVERRIDE', payload: { type, hash } });
  193. },
  194. clearUserOverride: (type: string) => {
  195. dispatch({ type: 'CLEAR_USER_OVERRIDE', payload: type });
  196. },
  197. performAutoSelection: async (checkpointModel: ModelInfo) => {
  198. try {
  199. dispatch({ type: 'SET_AUTO_SELECTING', payload: true });
  200. dispatch({ type: 'SET_LOADING', payload: true });
  201. dispatch({ type: 'SET_ERROR', payload: null });
  202. const autoSelectionState = await autoSelectorRef.current.selectModels(checkpointModel);
  203. dispatch({ type: 'SET_AUTO_SELECTION_STATE', payload: autoSelectionState });
  204. } catch (error) {
  205. const errorMessage = error instanceof Error ? error.message : 'Auto-selection failed';
  206. dispatch({ type: 'SET_ERROR', payload: errorMessage });
  207. } finally {
  208. dispatch({ type: 'SET_AUTO_SELECTING', payload: false });
  209. dispatch({ type: 'SET_LOADING', payload: false });
  210. }
  211. },
  212. clearWarnings: () => {
  213. dispatch({ type: 'CLEAR_WARNINGS' });
  214. },
  215. resetSelection: () => {
  216. dispatch({ type: 'RESET_SELECTION' });
  217. },
  218. validateSelection: (checkpointModel: ModelInfo) => {
  219. return autoSelectorRef.current.validateSelection(checkpointModel, state.selectedModels);
  220. },
  221. }), [state.selectedModels]);
  222. // Auto-select when checkpoint changes
  223. useEffect(() => {
  224. if (state.selectedCheckpoint && state.availableModels.length > 0) {
  225. const checkpointModel = state.availableModels.find(
  226. model => (model.sha256_short === state.selectedCheckpoint || model.sha256 === state.selectedCheckpoint || model.id === state.selectedCheckpoint) &&
  227. (model.type === 'checkpoint' || model.type === 'stable-diffusion')
  228. );
  229. if (checkpointModel) {
  230. actions.performAutoSelection(checkpointModel);
  231. }
  232. }
  233. }, [state.selectedCheckpoint, state.availableModels, actions]);
  234. const value = {
  235. state,
  236. actions,
  237. };
  238. return (
  239. <ModelSelectionContext.Provider value={value}>
  240. {children}
  241. </ModelSelectionContext.Provider>
  242. );
  243. }
  244. // Hook to use the context
  245. export function useModelSelection() {
  246. const context = useContext(ModelSelectionContext);
  247. if (!context) {
  248. throw new Error('useModelSelection must be used within a ModelSelectionProvider');
  249. }
  250. return context;
  251. }
  252. // Helper hook for checkpoint selection
  253. export function useCheckpointSelection() {
  254. const { state, actions } = useModelSelection();
  255. const checkpointModels = state.availableModels.filter(model =>
  256. model.type === 'checkpoint' || model.type === 'stable-diffusion'
  257. );
  258. const selectedCheckpointModel = state.selectedCheckpoint
  259. ? state.availableModels.find(model =>
  260. model.sha256_short === state.selectedCheckpoint ||
  261. model.sha256 === state.selectedCheckpoint ||
  262. model.id === state.selectedCheckpoint ||
  263. model.name === state.selectedCheckpoint
  264. )
  265. : null;
  266. return {
  267. checkpointModels,
  268. selectedCheckpointModel,
  269. selectedCheckpoint: state.selectedCheckpoint,
  270. setSelectedCheckpoint: actions.setSelectedCheckpoint,
  271. autoSelectedModels: state.autoSelectedModels,
  272. userOverrides: state.userOverrides,
  273. isAutoSelecting: state.isAutoSelecting,
  274. warnings: state.warnings,
  275. error: state.error,
  276. };
  277. }
  278. // Helper hook for model type selection
  279. export function useModelTypeSelection(modelType: string) {
  280. const { state, actions } = useModelSelection();
  281. const availableModels = state.availableModels.filter(model =>
  282. model.type.toLowerCase() === modelType.toLowerCase()
  283. );
  284. const selectedModel = state.selectedModels[modelType];
  285. const isUserOverride = !!state.userOverrides[modelType];
  286. const isAutoSelected = !!state.autoSelectedModels[modelType] && !isUserOverride;
  287. return {
  288. availableModels,
  289. selectedModel,
  290. isUserOverride,
  291. isAutoSelected,
  292. setSelectedModel: (hash: string | undefined) => actions.setSelectedModel(modelType, hash),
  293. setUserOverride: (hash: string) => actions.setUserOverride(modelType, hash),
  294. clearUserOverride: () => actions.clearUserOverride(modelType),
  295. };
  296. }