'use client';

import React, { createContext, useContext, useReducer, useEffect, useMemo, ReactNode } from 'react';
import { ModelInfo, AutoSelectionState } from '@/lib/api';
import { AutoModelSelector } from '@/lib/services/auto-model-selector';

// Types for the context

interface ModelSelectionState {
  selectedCheckpoint: string | null;
  selectedModels: Record<string, string>; // modelType -> modelHash
  autoSelectedModels: Record<string, string>; // modelType -> modelHash
  userOverrides: Record<string, string>; // modelType -> modelHash (user manual selections)
  autoSelectionState: AutoSelectionState | null;
  availableModels: ModelInfo[];
  isLoading: boolean;
  error: string | null;
  warnings: string[];
  isAutoSelecting: boolean;
}

type ModelSelectionAction =
  | { type: 'SET_MODELS'; payload: ModelInfo[] }
  | { type: 'SET_SELECTED_CHECKPOINT'; payload: string | null }
  | { type: 'SET_SELECTED_MODEL'; payload: { type: string; hash: string | undefined } }
  | { type: 'SET_USER_OVERRIDE'; payload: { type: string; hash: string } }
  | { type: 'CLEAR_USER_OVERRIDE'; payload: string }
  | { type: 'SET_AUTO_SELECTION_STATE'; payload: AutoSelectionState }
  | { type: 'SET_LOADING'; payload: boolean }
  | { type: 'SET_ERROR'; payload: string | null }
  | { type: 'ADD_WARNING'; payload: string }
  | { type: 'CLEAR_WARNINGS' }
  | { type: 'SET_AUTO_SELECTING'; payload: boolean }
  | { type: 'RESET_SELECTION' };

// Initial state
const initialState: ModelSelectionState = {
  selectedCheckpoint: null,
  selectedModels: {},
  autoSelectedModels: {},
  userOverrides: {},
  autoSelectionState: null,
  availableModels: [],
  isLoading: false,
  error: null,
  warnings: [],
  isAutoSelecting: false,
};

/**
 * Pure reducer for the model-selection state machine.
 *
 * Every branch returns a new state object; the previous state is never mutated.
 */
function modelSelectionReducer(
  state: ModelSelectionState,
  action: ModelSelectionAction
): ModelSelectionState {
  switch (action.type) {
    case 'SET_MODELS':
      return {
        ...state,
        availableModels: action.payload,
        error: null,
      };

    case 'SET_SELECTED_CHECKPOINT':
      return {
        ...state,
        selectedCheckpoint: action.payload,
        // Clear auto-selection when checkpoint changes
        autoSelectedModels: {},
        autoSelectionState: null,
        warnings: [],
      };

    case 'SET_SELECTED_MODEL': {
      const { type, hash } = action.payload;
      const selectedModels = { ...state.selectedModels };
      // `undefined` clears the selection for this type. Delete the key instead of
      // storing an explicit `undefined` (same approach as CLEAR_USER_OVERRIDE).
      if (hash === undefined) {
        delete selectedModels[type];
      } else {
        selectedModels[type] = hash;
      }
      return { ...state, selectedModels };
    }

    case 'SET_USER_OVERRIDE': {
      const { type, hash } = action.payload;
      return {
        ...state,
        userOverrides: { ...state.userOverrides, [type]: hash },
        selectedModels: { ...state.selectedModels, [type]: hash },
      };
    }

    case 'CLEAR_USER_OVERRIDE': {
      // Braces give these declarations their own scope (no leaking across cases).
      const type = action.payload;
      const userOverrides = { ...state.userOverrides };
      delete userOverrides[type];

      // If we had an auto-selected model for this type, restore it
      const selectedModels = { ...state.selectedModels };
      const restoredModel = state.autoSelectedModels[type];
      if (restoredModel) {
        selectedModels[type] = restoredModel;
      } else {
        delete selectedModels[type];
      }
      return { ...state, userOverrides, selectedModels };
    }

    case 'SET_AUTO_SELECTION_STATE': {
      const { autoSelectedModels, warnings, errors } = action.payload;
      return {
        ...state,
        autoSelectionState: action.payload,
        // Merge auto-selected models with current selections, but don't override user selections
        selectedModels: {
          ...autoSelectedModels,
          ...state.userOverrides, // User overrides take precedence
        },
        autoSelectedModels,
        warnings: [...state.warnings, ...warnings],
        // Surface the first auto-selection error, if any; otherwise keep the current one.
        error: errors[0] ?? state.error,
      };
    }

    case 'SET_LOADING':
      return { ...state, isLoading: action.payload };

    case 'SET_ERROR':
      return { ...state, error: action.payload };

    case 'ADD_WARNING':
      return { ...state, warnings: [...state.warnings, action.payload] };

    case 'CLEAR_WARNINGS':
      return { ...state, warnings: [] };

    case 'SET_AUTO_SELECTING':
      return { ...state, isAutoSelecting: action.payload };

    case 'RESET_SELECTION':
      return {
        ...initialState,
        availableModels: state.availableModels,
      };

    default: {
      // Compile-time exhaustiveness check; at runtime an unknown action is ignored.
      const unhandled: never = action;
      void unhandled;
      return state;
    }
  }
}

/** True when the model is one of the checkpoint model types. */
function isCheckpointModel(model: ModelInfo): boolean {
  return model.type === 'checkpoint' || model.type === 'stable-diffusion';
}

/**
 * Finds the checkpoint model identified by `key`.
 *
 * `key` may be the short hash, the full sha256, the id, or the name. The same
 * lookup drives both auto-selection and `useCheckpointSelection`, so the
 * checkpoint shown as selected is the one auto-selection runs for.
 */
function findCheckpointModel(models: ModelInfo[], key: string): ModelInfo | undefined {
  return models.find(
    model =>
      isCheckpointModel(model) &&
      (model.sha256_short === key ||
        model.sha256 === key ||
        model.id === key ||
        model.name === key)
  );
}

// Context type
interface ModelSelectionContextType {
  state: ModelSelectionState;
  actions: {
    setModels: (models: ModelInfo[]) => void;
    setSelectedCheckpoint: (checkpointHash: string | null) => void;
    setSelectedModel: (type: string, hash: string | undefined) => void;
    setUserOverride: (type: string, hash: string) => void;
    clearUserOverride: (type: string) => void;
    performAutoSelection: (checkpointModel: ModelInfo) => Promise<void>;
    clearWarnings: () => void;
    resetSelection: () => void;
    validateSelection: (checkpointModel: ModelInfo) => {
      isValid: boolean;
      missingRequired: string[];
      warnings: string[];
    };
  };
}

// Create context
const ModelSelectionContext = createContext<ModelSelectionContextType | null>(null);

// Provider component
interface ModelSelectionProviderProps {
  children: ReactNode;
}

/**
 * Provides model-selection state and actions to its subtree.
 *
 * Whenever the selected checkpoint or the list of available models changes,
 * it runs auto-selection for the matching checkpoint model (if one is found).
 */
export function ModelSelectionProvider({ children }: ModelSelectionProviderProps) {
  const [state, dispatch] = useReducer(modelSelectionReducer, initialState);

  // Lazy initializer: `useRef(new AutoModelSelector())` would construct (and then
  // discard) a new selector on every render.
  const [autoSelector] = React.useState(() => new AutoModelSelector());

  // Id of the most recent auto-selection request. Results of superseded
  // requests are dropped so they cannot overwrite newer ones.
  const latestRequestIdRef = React.useRef(0);

  // Update auto selector when models change. Declared before the auto-select
  // effect so it runs first in the same commit.
  useEffect(() => {
    autoSelector.updateModels(state.availableModels);
  }, [autoSelector, state.availableModels]);

  // Actions that do not read state. Their identities are stable for the
  // provider's lifetime.
  const stableActions = useMemo(() => {
    const performAutoSelection = async (checkpointModel: ModelInfo): Promise<void> => {
      const requestId = ++latestRequestIdRef.current;
      const isCurrent = () => requestId === latestRequestIdRef.current;
      try {
        dispatch({ type: 'SET_AUTO_SELECTING', payload: true });
        dispatch({ type: 'SET_LOADING', payload: true });
        dispatch({ type: 'SET_ERROR', payload: null });

        const autoSelectionState = await autoSelector.selectModels(checkpointModel);
        if (isCurrent()) {
          dispatch({ type: 'SET_AUTO_SELECTION_STATE', payload: autoSelectionState });
        }
      } catch (error) {
        if (isCurrent()) {
          const errorMessage = error instanceof Error ? error.message : 'Auto-selection failed';
          dispatch({ type: 'SET_ERROR', payload: errorMessage });
        }
      } finally {
        // Only the latest request owns the loading flags; a superseded request
        // must not clear them while a newer one is still running.
        if (isCurrent()) {
          dispatch({ type: 'SET_AUTO_SELECTING', payload: false });
          dispatch({ type: 'SET_LOADING', payload: false });
        }
      }
    };

    return {
      setModels: (models: ModelInfo[]) => {
        dispatch({ type: 'SET_MODELS', payload: models });
      },
      setSelectedCheckpoint: (checkpointHash: string | null) => {
        dispatch({ type: 'SET_SELECTED_CHECKPOINT', payload: checkpointHash });
      },
      setSelectedModel: (type: string, hash: string | undefined) => {
        dispatch({ type: 'SET_SELECTED_MODEL', payload: { type, hash } });
      },
      setUserOverride: (type: string, hash: string) => {
        dispatch({ type: 'SET_USER_OVERRIDE', payload: { type, hash } });
      },
      clearUserOverride: (type: string) => {
        dispatch({ type: 'CLEAR_USER_OVERRIDE', payload: type });
      },
      performAutoSelection,
      clearWarnings: () => {
        dispatch({ type: 'CLEAR_WARNINGS' });
      },
      resetSelection: () => {
        dispatch({ type: 'RESET_SELECTION' });
      },
    };
  }, [autoSelector]);

  // The full actions object. `validateSelection` closes over the current
  // selection, so this object changes whenever selectedModels does.
  const actions = useMemo<ModelSelectionContextType['actions']>(
    () => ({
      ...stableActions,
      validateSelection: (checkpointModel: ModelInfo) =>
        autoSelector.validateSelection(checkpointModel, state.selectedModels),
    }),
    [stableActions, autoSelector, state.selectedModels]
  );

  // Auto-select when checkpoint changes. This effect deliberately depends on
  // the stable `performAutoSelection`, not on `actions`. Auto-selection updates
  // selectedModels, which recreates `actions`; depending on `actions` would
  // re-trigger auto-selection in an endless loop.
  const { performAutoSelection } = stableActions;
  useEffect(() => {
    if (!state.selectedCheckpoint) {
      return;
    }
    const checkpointModel = findCheckpointModel(state.availableModels, state.selectedCheckpoint);
    if (checkpointModel) {
      // Errors are handled inside performAutoSelection.
      void performAutoSelection(checkpointModel);
    }
  }, [state.selectedCheckpoint, state.availableModels, performAutoSelection]);

  // Memoized so consumers only re-render when state or actions actually change.
  const value = useMemo<ModelSelectionContextType>(() => ({ state, actions }), [state, actions]);

  return React.createElement(ModelSelectionContext.Provider, { value }, children);
}

/**
 * Returns the model-selection context.
 *
 * @throws Error when called outside a `ModelSelectionProvider`.
 */
export function useModelSelection() {
  const context = useContext(ModelSelectionContext);
  if (!context) {
    throw new Error('useModelSelection must be used within a ModelSelectionProvider');
  }
  return context;
}

/**
 * Helper hook for checkpoint selection.
 *
 * `selectedCheckpointModel` is `null` when no checkpoint is selected, and
 * `undefined` when the selected key matches no available checkpoint model.
 */
export function useCheckpointSelection() {
  const { state, actions } = useModelSelection();

  const checkpointModels = state.availableModels.filter(isCheckpointModel);

  const selectedCheckpointModel = state.selectedCheckpoint
    ? findCheckpointModel(state.availableModels, state.selectedCheckpoint)
    : null;

  return {
    checkpointModels,
    selectedCheckpointModel,
    selectedCheckpoint: state.selectedCheckpoint,
    setSelectedCheckpoint: actions.setSelectedCheckpoint,
    autoSelectedModels: state.autoSelectedModels,
    userOverrides: state.userOverrides,
    isAutoSelecting: state.isAutoSelecting,
    warnings: state.warnings,
    error: state.error,
  };
}

/**
 * Helper hook for a single model type.
 *
 * It lists the available models of that type (matched case-insensitively) and
 * exposes the current selection with setters bound to `modelType`.
 */
export function useModelTypeSelection(modelType: string) {
  const { state, actions } = useModelSelection();

  const wantedType = modelType.toLowerCase();
  const availableModels = state.availableModels.filter(
    model => model.type.toLowerCase() === wantedType
  );

  const selectedModel = state.selectedModels[modelType];
  const isUserOverride = !!state.userOverrides[modelType];
  const isAutoSelected = !!state.autoSelectedModels[modelType] && !isUserOverride;

  return {
    availableModels,
    selectedModel,
    isUserOverride,
    isAutoSelected,
    setSelectedModel: (hash: string | undefined) => actions.setSelectedModel(modelType, hash),
    setUserOverride: (hash: string) => actions.setUserOverride(modelType, hash),
    clearUserOverride: () => actions.clearUserOverride(modelType),
  };
}