| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344 |
- 'use client';
- import React, { createContext, useContext, useReducer, useEffect, useMemo, ReactNode } from 'react';
- import { ModelInfo, AutoSelectionState } from '@/lib/api';
- import { AutoModelSelector } from '@/lib/services/auto-model-selector';
/**
 * Shape of the model-selection store.
 *
 * Every `Record` map below is keyed by model type and holds a model hash.
 */
interface ModelSelectionState {
  /** Identifier (hash, short hash, or id) of the chosen checkpoint; null when none. */
  selectedCheckpoint: string | null;

  /** Current selection per model type, as shown to the user. */
  selectedModels: Record<string, string | undefined>;
  /** Selections produced by the most recent auto-selection run. */
  autoSelectedModels: Record<string, string>;
  /** Manual selections made by the user; these take precedence over auto-selection. */
  userOverrides: Record<string, string>;
  /** Full result of the most recent auto-selection run, if any. */
  autoSelectionState: AutoSelectionState | null;

  /** Every model currently known to the store. */
  availableModels: ModelInfo[];

  /** True while an auto-selection request is in flight. */
  isLoading: boolean;
  /** Most recent error message, or null. */
  error: string | null;
  /** Accumulated, user-visible warnings. */
  warnings: string[];
  /** True while auto-selection is running. */
  isAutoSelecting: boolean;
}
/**
 * Actions understood by `modelSelectionReducer`.
 * Per-type payloads carry the model type in `type` and the model hash in `hash`.
 */
type ModelSelectionAction =
  // Catalogue and checkpoint
  | { type: 'SET_MODELS'; payload: ModelInfo[] }
  | { type: 'SET_SELECTED_CHECKPOINT'; payload: string | null }
  // Per-type selections (CLEAR_USER_OVERRIDE's payload is the model type)
  | { type: 'SET_SELECTED_MODEL'; payload: { type: string; hash: string | undefined } }
  | { type: 'SET_USER_OVERRIDE'; payload: { type: string; hash: string } }
  | { type: 'CLEAR_USER_OVERRIDE'; payload: string }
  // Auto-selection result and progress flags
  | { type: 'SET_AUTO_SELECTION_STATE'; payload: AutoSelectionState }
  | { type: 'SET_AUTO_SELECTING'; payload: boolean }
  | { type: 'SET_LOADING'; payload: boolean }
  // Diagnostics
  | { type: 'SET_ERROR'; payload: string | null }
  | { type: 'ADD_WARNING'; payload: string }
  | { type: 'CLEAR_WARNINGS' }
  // Housekeeping
  | { type: 'RESET_SELECTION' };
/**
 * Empty store: nothing selected, no models loaded, nothing in progress.
 * The reducer's RESET_SELECTION returns to this state but keeps the loaded model list.
 */
const initialState: ModelSelectionState = {
  // Selection
  selectedCheckpoint: null,
  selectedModels: {},
  autoSelectedModels: {},
  userOverrides: {},
  autoSelectionState: null,
  // Catalogue
  availableModels: [],
  // Status and diagnostics
  isLoading: false,
  error: null,
  warnings: [],
  isAutoSelecting: false,
};
/**
 * Pure reducer for the model-selection store.
 *
 * Selection bookkeeping rules:
 * - User overrides always win: whenever `selectedModels` is rebuilt from an
 *   auto-selection result, `userOverrides` are layered on top.
 * - Clearing an override falls back to the auto-selected model for that type,
 *   or removes the selection entirely if there is none.
 *
 * @param state  current store state (never mutated)
 * @param action action to apply
 * @returns the next state
 */
function modelSelectionReducer(
  state: ModelSelectionState,
  action: ModelSelectionAction
): ModelSelectionState {
  switch (action.type) {
    case 'SET_MODELS':
      // A freshly loaded model list supersedes any previous error.
      return {
        ...state,
        availableModels: action.payload,
        error: null,
      };

    case 'SET_SELECTED_CHECKPOINT':
      return {
        ...state,
        selectedCheckpoint: action.payload,
        // Auto-selection results belong to the previous checkpoint; drop them.
        autoSelectedModels: {},
        autoSelectionState: null,
        warnings: [],
      };

    case 'SET_SELECTED_MODEL':
      return {
        ...state,
        selectedModels: {
          ...state.selectedModels,
          [action.payload.type]: action.payload.hash,
        },
      };

    case 'SET_USER_OVERRIDE': {
      const { type, hash } = action.payload;
      return {
        ...state,
        userOverrides: { ...state.userOverrides, [type]: hash },
        selectedModels: { ...state.selectedModels, [type]: hash },
      };
    }

    case 'CLEAR_USER_OVERRIDE': {
      // Braced so these bindings are scoped to this case only.
      const modelType = action.payload;

      const userOverrides = { ...state.userOverrides };
      delete userOverrides[modelType];

      // Fall back to the auto-selected model for this type, if there is one.
      const selectedModels = { ...state.selectedModels };
      const restoredModel = state.autoSelectedModels[modelType];
      if (restoredModel) {
        selectedModels[modelType] = restoredModel;
      } else {
        delete selectedModels[modelType];
      }

      return { ...state, userOverrides, selectedModels };
    }

    case 'SET_AUTO_SELECTION_STATE': {
      const { autoSelectedModels, warnings, errors } = action.payload;
      return {
        ...state,
        autoSelectionState: action.payload,
        // Rebuild the selection from the new auto-selection result. Previous
        // non-override selections are intentionally discarded; user overrides
        // are spread last so they take precedence.
        selectedModels: {
          ...autoSelectedModels,
          ...state.userOverrides,
        },
        autoSelectedModels,
        // Auto-selection can run repeatedly for the same checkpoint (e.g. when
        // the model list is refreshed); de-duplicate so identical warnings
        // don't pile up.
        warnings: Array.from(new Set([...state.warnings, ...warnings])),
        // Surface the first reported error; otherwise keep the existing one.
        error: errors[0] ?? state.error,
      };
    }

    case 'SET_LOADING':
      return { ...state, isLoading: action.payload };

    case 'SET_ERROR':
      return { ...state, error: action.payload };

    case 'ADD_WARNING':
      return { ...state, warnings: [...state.warnings, action.payload] };

    case 'CLEAR_WARNINGS':
      return { ...state, warnings: [] };

    case 'SET_AUTO_SELECTING':
      return { ...state, isAutoSelecting: action.payload };

    case 'RESET_SELECTION':
      // Back to the initial state, but keep the already-loaded model list.
      return {
        ...initialState,
        availableModels: state.availableModels,
      };

    default: {
      // Compile-time exhaustiveness check: an unhandled action type becomes a
      // type error here. At runtime, unknown actions leave the state unchanged.
      const unhandled: never = action;
      void unhandled;
      return state;
    }
  }
}
/** Value exposed through the model-selection React context. */
interface ModelSelectionContextType {
  /** Current store snapshot. */
  state: ModelSelectionState;
  /** Operations that update or query the store. */
  actions: {
    /** Replace the list of available models. */
    setModels: (models: ModelInfo[]) => void;
    /** Select a checkpoint (null clears it); previous auto-selection results are discarded. */
    setSelectedCheckpoint: (checkpointHash: string | null) => void;
    /** Set the selection for one model type without recording a user override. */
    setSelectedModel: (type: string, hash: string | undefined) => void;
    /** Record a manual choice for a model type; it wins over auto-selection. */
    setUserOverride: (type: string, hash: string) => void;
    /** Drop the manual choice for a model type, falling back to its auto-selected model. */
    clearUserOverride: (type: string) => void;
    /** Run auto-selection for the given checkpoint model. */
    performAutoSelection: (checkpointModel: ModelInfo) => Promise<void>;
    /** Remove all accumulated warnings. */
    clearWarnings: () => void;
    /** Reset selections to the initial state, keeping the loaded model list. */
    resetSelection: () => void;
    /** Validate the current `selectedModels` against the given checkpoint. */
    validateSelection: (checkpointModel: ModelInfo) => {
      isValid: boolean;
      missingRequired: string[];
      warnings: string[];
    };
  };
}
/**
 * React context carrying the model-selection store.
 * The default is `null` so that use outside a provider can be detected.
 */
const ModelSelectionContext = createContext<ModelSelectionContextType | null>(null);

/** Props accepted by the model-selection provider component. */
interface ModelSelectionProviderProps {
  /** Subtree that gains access to the model-selection context. */
  children: ReactNode;
}
/**
 * Provides model-selection state and actions to its subtree.
 *
 * - Owns one `AutoModelSelector` for the provider's lifetime and keeps its
 *   model catalogue in sync with `state.availableModels`.
 * - Runs auto-selection whenever the selected checkpoint or the model list
 *   changes. Only the result of the most recent request is applied.
 */
export function ModelSelectionProvider({ children }: ModelSelectionProviderProps) {
  const [state, dispatch] = useReducer(modelSelectionReducer, initialState);

  // Lazy initializer: the selector is constructed once. The previous
  // `useRef(new AutoModelSelector())` built (and discarded) a new instance on
  // every render, because the argument is evaluated each time.
  const [autoSelector] = React.useState(() => new AutoModelSelector());

  // Id of the most recent auto-selection request. Results of requests that were
  // superseded while awaiting (e.g. rapid checkpoint switches) are ignored.
  const latestRequestIdRef = React.useRef(0);

  // Keep the selector's model catalogue current.
  useEffect(() => {
    autoSelector.updateModels(state.availableModels);
  }, [autoSelector, state.availableModels]);

  // Stable for the provider's lifetime (its only dependency never changes), so
  // the auto-select effect below is not re-triggered by callback identity.
  const performAutoSelection = React.useCallback(
    async (checkpointModel: ModelInfo): Promise<void> => {
      const requestId = ++latestRequestIdRef.current;
      const isLatest = () => requestId === latestRequestIdRef.current;
      try {
        dispatch({ type: 'SET_AUTO_SELECTING', payload: true });
        dispatch({ type: 'SET_LOADING', payload: true });
        dispatch({ type: 'SET_ERROR', payload: null });
        const autoSelectionState = await autoSelector.selectModels(checkpointModel);
        if (isLatest()) {
          dispatch({ type: 'SET_AUTO_SELECTION_STATE', payload: autoSelectionState });
        }
      } catch (error: unknown) {
        if (isLatest()) {
          const errorMessage = error instanceof Error ? error.message : 'Auto-selection failed';
          dispatch({ type: 'SET_ERROR', payload: errorMessage });
        }
      } finally {
        // A superseded request must not clear the flags of the one still running.
        if (isLatest()) {
          dispatch({ type: 'SET_AUTO_SELECTING', payload: false });
          dispatch({ type: 'SET_LOADING', payload: false });
        }
      }
    },
    [autoSelector]
  );

  const actions = useMemo(
    () => ({
      setModels: (models: ModelInfo[]) => {
        dispatch({ type: 'SET_MODELS', payload: models });
      },
      setSelectedCheckpoint: (checkpointHash: string | null) => {
        dispatch({ type: 'SET_SELECTED_CHECKPOINT', payload: checkpointHash });
      },
      setSelectedModel: (type: string, hash: string | undefined) => {
        dispatch({ type: 'SET_SELECTED_MODEL', payload: { type, hash } });
      },
      setUserOverride: (type: string, hash: string) => {
        dispatch({ type: 'SET_USER_OVERRIDE', payload: { type, hash } });
      },
      clearUserOverride: (type: string) => {
        dispatch({ type: 'CLEAR_USER_OVERRIDE', payload: type });
      },
      performAutoSelection,
      clearWarnings: () => {
        dispatch({ type: 'CLEAR_WARNINGS' });
      },
      resetSelection: () => {
        dispatch({ type: 'RESET_SELECTION' });
      },
      // Validates the selection captured in this render, hence the
      // state.selectedModels dependency below.
      validateSelection: (checkpointModel: ModelInfo) =>
        autoSelector.validateSelection(checkpointModel, state.selectedModels),
    }),
    [autoSelector, performAutoSelection, state.selectedModels]
  );

  // Auto-select when the checkpoint or the model list changes. This depends on
  // the stable `performAutoSelection`, not on `actions`. `actions` changes
  // whenever selectedModels changes, and every auto-selection replaces
  // selectedModels, so depending on `actions` re-ran this effect endlessly.
  useEffect(() => {
    if (!state.selectedCheckpoint || state.availableModels.length === 0) {
      return;
    }
    const checkpointModel = state.availableModels.find(
      (model) =>
        (model.sha256_short === state.selectedCheckpoint ||
          model.sha256 === state.selectedCheckpoint ||
          model.id === state.selectedCheckpoint) &&
        (model.type === 'checkpoint' || model.type === 'stable-diffusion')
    );
    if (checkpointModel) {
      // performAutoSelection handles its own errors, so it never rejects.
      void performAutoSelection(checkpointModel);
    }
  }, [state.selectedCheckpoint, state.availableModels, performAutoSelection]);

  // Memoized so consumers re-render only when state or actions actually change.
  const value = useMemo(() => ({ state, actions }), [state, actions]);

  return (
    <ModelSelectionContext.Provider value={value}>
      {children}
    </ModelSelectionContext.Provider>
  );
}
/**
 * Access the model-selection store.
 *
 * @returns the `{ state, actions }` value supplied by the nearest provider
 * @throws Error when called outside a `ModelSelectionProvider`
 */
export function useModelSelection(): ModelSelectionContextType {
  const ctx = useContext(ModelSelectionContext);
  if (ctx == null) {
    throw new Error('useModelSelection must be used within a ModelSelectionProvider');
  }
  return ctx;
}
/**
 * Convenience hook for checkpoint pickers.
 *
 * @returns
 * - the checkpoint-type models;
 * - the selected checkpoint, both as the raw identifier and resolved to its
 *   `ModelInfo` (null when nothing is selected, undefined when the identifier
 *   matches no checkpoint);
 * - the checkpoint setter;
 * - the auto-selection results and status fields from the store.
 */
export function useCheckpointSelection() {
  const { state, actions } = useModelSelection();
  const { selectedCheckpoint, availableModels } = state;

  const checkpointModels = availableModels.filter(
    (model) => model.type === 'checkpoint' || model.type === 'stable-diffusion'
  );

  // Resolve only among checkpoint-type models, so that a non-checkpoint model
  // sharing the identifier (e.g. by name) is never reported as the selected
  // checkpoint. This matches the `checkpointModels` list above.
  const selectedCheckpointModel = selectedCheckpoint
    ? checkpointModels.find(
        (model) =>
          model.sha256_short === selectedCheckpoint ||
          model.sha256 === selectedCheckpoint ||
          model.id === selectedCheckpoint ||
          model.name === selectedCheckpoint
      )
    : null;

  return {
    checkpointModels,
    selectedCheckpointModel,
    selectedCheckpoint,
    setSelectedCheckpoint: actions.setSelectedCheckpoint,
    autoSelectedModels: state.autoSelectedModels,
    userOverrides: state.userOverrides,
    isAutoSelecting: state.isAutoSelecting,
    warnings: state.warnings,
    error: state.error,
  };
}
/**
 * Per-model-type view of the selection store.
 *
 * @param modelType model type key. The model list is filtered
 *   case-insensitively, while selection and override lookups use `modelType`
 *   exactly as given.
 * @returns the matching models, the current selection for this type, whether
 *   it came from a user override or from auto-selection, and setters bound to
 *   `modelType`.
 */
export function useModelTypeSelection(modelType: string) {
  const { state, actions } = useModelSelection();
  const { availableModels: allModels, selectedModels, userOverrides, autoSelectedModels } = state;

  const wantedType = modelType.toLowerCase();
  const availableModels = allModels.filter((m) => m.type.toLowerCase() === wantedType);

  const isUserOverride = Boolean(userOverrides[modelType]);
  // Auto-selected only counts when no user override is shadowing it.
  const isAutoSelected = !isUserOverride && Boolean(autoSelectedModels[modelType]);

  return {
    availableModels,
    selectedModel: selectedModels[modelType],
    isUserOverride,
    isAutoSelected,
    setSelectedModel: (hash: string | undefined) => actions.setSelectedModel(modelType, hash),
    setUserOverride: (hash: string) => actions.setUserOverride(modelType, hash),
    clearUserOverride: () => actions.clearUserOverride(modelType),
  };
}
|