Browse Source

Make SD1x and SDXL VAE models optional instead of required

- Modified model_detector.cpp to set needsVAE=false for SD1.5, SD2.1, SDXL_BASE, and SDXL_REFINER architectures
- Updated model_manager.cpp to not add VAE to requiredModels list for these architectures
- Modified stable_diffusion_wrapper.cpp to handle missing VAE files gracefully
- VAE models are now recommended but not required for SD1x and SDXL models
- The API will continue to function properly without VAE models
Fszontagh 3 tháng trước cách đây
mục cha
commit
a58f73db87

+ 3 - 3
src/model_detector.cpp

@@ -50,7 +50,7 @@ ModelDetectionResult ModelDetector::detectModel(const std::string& modelPath) {
         case ModelArchitecture::SD_1_5:
             result.textEncoderDim              = 768;
             result.unetChannels                = 1280;
-            result.needsVAE                    = true;
+            result.needsVAE                    = false;
             result.recommendedVAE              = "vae-ft-mse-840000-ema-pruned.safetensors";
             result.needsTAESD                  = true;
             result.suggestedParams["vae_flag"] = "--vae";
@@ -59,7 +59,7 @@ ModelDetectionResult ModelDetector::detectModel(const std::string& modelPath) {
         case ModelArchitecture::SD_2_1:
             result.textEncoderDim              = 1024;
             result.unetChannels                = 1280;
-            result.needsVAE                    = true;
+            result.needsVAE                    = false;
             result.recommendedVAE              = "vae-ft-ema-560000.safetensors";
             result.needsTAESD                  = true;
             result.suggestedParams["vae_flag"] = "--vae";
@@ -70,7 +70,7 @@ ModelDetectionResult ModelDetector::detectModel(const std::string& modelPath) {
             result.textEncoderDim              = 1280;
             result.unetChannels                = 2560;
             result.hasConditioner              = true;
-            result.needsVAE                    = true;
+            result.needsVAE                    = false;
             result.recommendedVAE              = "sdxl_vae.safetensors";
             result.needsTAESD                  = true;
             result.suggestedParams["vae_flag"] = "--vae";

+ 6 - 7
src/model_manager.cpp

@@ -484,9 +484,8 @@ public:
                                         }
 
                                         // Build list of required models based on detection
-                                        if (detection.needsVAE && !detection.recommendedVAE.empty()) {
-                                            info.requiredModels.push_back("VAE: " + detection.recommendedVAE);
-                                        }
+                                        // Note: VAE is now optional for SD1x and SDXL models, so we don't add it to requiredModels
+                                        // The VAE will still be recommended but not required
 
                                         // Add CLIP-L if required
                                         if (detection.suggestedParams.count("clip_l_required")) {
@@ -661,7 +660,8 @@ bool ModelManager::loadModel(const std::string& name, const std::string& path, M
                 }
 
                 // Set additional model paths based on detection
-                if (detection.needsVAE && !detection.recommendedVAE.empty()) {
+                // VAE is now optional for SD1x and SDXL models, but we still set it if available
+                if (!detection.recommendedVAE.empty()) {
                     loadParams.vaePath = detection.recommendedVAE;
                 }
 
@@ -1329,9 +1329,8 @@ void ModelManager::ModelDetectionCache::cacheDetectionResult(
     }
     
     // Build list of required models
-    if (detection.needsVAE && !detection.recommendedVAE.empty()) {
-        entry.requiredModels.push_back("VAE: " + detection.recommendedVAE);
-    }
+    // Note: VAE is now optional for SD1x and SDXL models, so we don't add it to requiredModels
+    // The VAE will still be recommended but not required
     
     if (detection.suggestedParams.count("clip_l_required")) {
         entry.requiredModels.push_back("CLIP-L: " + detection.suggestedParams.at("clip_l_required"));

+ 15 - 3
src/stable_diffusion_wrapper.cpp

@@ -201,8 +201,15 @@ public:
             std::cout << "Using CLIP-G path: " << std::filesystem::absolute(persistentClipGPath) << std::endl;
         }
         if (!persistentVaePath.empty()) {
-            ctxParams.vae_path = persistentVaePath.c_str();
-            std::cout << "Using VAE path: " << std::filesystem::absolute(persistentVaePath) << std::endl;
+            // Check if VAE file exists before setting it
+            if (std::filesystem::exists(persistentVaePath)) {
+                ctxParams.vae_path = persistentVaePath.c_str();
+                std::cout << "Using VAE path: " << std::filesystem::absolute(persistentVaePath) << std::endl;
+            } else {
+                std::cout << "VAE file not found: " << std::filesystem::absolute(persistentVaePath)
+                         << " - continuing without VAE" << std::endl;
+                ctxParams.vae_path = nullptr;
+            }
         }
         if (!persistentTaesdPath.empty()) {
             ctxParams.taesd_path = persistentTaesdPath.c_str();
@@ -259,7 +266,12 @@ public:
                     ctxParams.clip_g_path = persistentClipGPath.c_str();
                 }
                 if (!persistentVaePath.empty()) {
-                    ctxParams.vae_path = persistentVaePath.c_str();
+                    // Check if VAE file exists before setting it
+                    if (std::filesystem::exists(persistentVaePath)) {
+                        ctxParams.vae_path = persistentVaePath.c_str();
+                    } else {
+                        ctxParams.vae_path = nullptr;
+                    }
                 }
                 if (!persistentTaesdPath.empty()) {
                     ctxParams.taesd_path = persistentTaesdPath.c_str();

+ 2 - 103
webui/app/img2img/page.tsx

@@ -14,9 +14,6 @@ import { apiClient, type JobInfo } from '@/lib/api';
 import { Loader2, Download, X } from 'lucide-react';
 import { downloadImage, downloadAuthenticatedImage, fileToBase64 } from '@/lib/utils';
 import { useLocalStorage } from '@/lib/hooks';
-import { ModelSelectionProvider, useModelSelection, useCheckpointSelection, useModelTypeSelection } from '@/contexts/model-selection-context';
-import { EnhancedModelSelect, EnhancedModelSelectGroup } from '@/components/enhanced-model-select';
-import { ModelSelectionWarning, AutoSelectionStatus } from '@/components/model-selection-indicator';
 
 type Img2ImgFormData = {
   prompt: string;
@@ -45,26 +42,6 @@ const defaultFormData: Img2ImgFormData = {
 };
 
 function Img2ImgForm() {
-  const { state, actions } = useModelSelection();
-  const {
-    checkpointModels,
-    selectedCheckpointModel,
-    selectedCheckpoint,
-    setSelectedCheckpoint,
-    isAutoSelecting,
-    warnings,
-    error: checkpointError
-  } = useCheckpointSelection();
-  
-  const {
-    availableModels: vaeModels,
-    selectedModel: selectedVae,
-    isUserOverride: isVaeUserOverride,
-    isAutoSelected: isVaeAutoSelected,
-    setSelectedModel: setSelectedVae,
-    setUserOverride: setVaeUserOverride,
-    clearUserOverride: clearVaeUserOverride,
-  } = useModelTypeSelection('vae');
 
   const [formData, setFormData] = useLocalStorage<Img2ImgFormData>(
     'img2img-form-data',
@@ -86,12 +63,10 @@ function Img2ImgForm() {
   useEffect(() => {
     const loadModels = async () => {
       try {
-        const [modelsData, loras, embeds] = await Promise.all([
-          apiClient.getModels(), // Get all models with enhanced info
+        const [loras, embeds] = await Promise.all([
           apiClient.getModels('lora'),
           apiClient.getModels('embedding'),
         ]);
-        actions.setModels(modelsData.models);
         setLoraModels(loras.models.map(m => m.name));
         setEmbeddings(embeds.models.map(m => m.name));
       } catch (err) {
@@ -99,17 +74,8 @@ function Img2ImgForm() {
       }
     };
     loadModels();
-  }, [actions]);
+  }, []);
 
-  // Update form data when checkpoint changes
-  useEffect(() => {
-    if (selectedCheckpoint) {
-      setFormData(prev => ({
-        ...prev,
-        model: selectedCheckpoint,
-      }));
-    }
-  }, [selectedCheckpoint, setFormData]);
 
   const handleInputChange = (
     e: React.ChangeEvent<HTMLInputElement | HTMLTextAreaElement | HTMLSelectElement>
@@ -276,20 +242,8 @@ function Img2ImgForm() {
     setJobInfo(null);
 
     try {
-      // Validate model selection
-      if (selectedCheckpointModel) {
-        const validation = actions.validateSelection(selectedCheckpointModel);
-        if (!validation.isValid) {
-          setError(`Missing required models: ${validation.missingRequired.join(', ')}`);
-          setLoading(false);
-          return;
-        }
-      }
-
       const requestData = {
         ...formData,
-        model: selectedCheckpoint || undefined,
-        vae: selectedVae || undefined,
       };
 
       const job = await apiClient.img2img(requestData);
@@ -485,61 +439,6 @@ function Img2ImgForm() {
                   </select>
                 </div>
 
-                {/* Model Selection Section */}
-                <EnhancedModelSelectGroup
-                  title="Model Selection"
-                  description="Select the checkpoint and additional models for generation"
-                >
-                  {/* Checkpoint Selection */}
-                  <div className="space-y-2">
-                    <Label htmlFor="checkpoint">Checkpoint Model *</Label>
-                    <select
-                      id="checkpoint"
-                      value={selectedCheckpoint || ''}
-                      onChange={(e) => setSelectedCheckpoint(e.target.value || null)}
-                      className="flex h-10 w-full rounded-md border border-input bg-background px-3 py-2 text-sm"
-                      disabled={isAutoSelecting}
-                    >
-                      <option value="">Select a checkpoint model...</option>
-                      {checkpointModels.map((model) => (
-                        <option key={model.id} value={model.name}>
-                          {model.name} {model.loaded ? '(Loaded)' : ''}
-                        </option>
-                      ))}
-                    </select>
-                  </div>
-
-                  {/* VAE Selection */}
-                  <EnhancedModelSelect
-                    modelType="vae"
-                    label="VAE Model"
-                    description="Optional VAE model for improved image quality"
-                    value={selectedVae}
-                    availableModels={vaeModels}
-                    isAutoSelected={isVaeAutoSelected}
-                    isUserOverride={isVaeUserOverride}
-                    isLoading={isAutoSelecting}
-                    onValueChange={setSelectedVae}
-                    onSetUserOverride={setVaeUserOverride}
-                    onClearOverride={clearVaeUserOverride}
-                    placeholder="Use default VAE"
-                  />
-
-                  {/* Auto-selection Status */}
-                  <div className="pt-2">
-                    <AutoSelectionStatus
-                      isAutoSelecting={isAutoSelecting}
-                      hasAutoSelection={Object.keys(state.autoSelectedModels).length > 0}
-                    />
-                  </div>
-
-                  {/* Warnings and Errors */}
-                  <ModelSelectionWarning
-                    warnings={warnings}
-                    errors={error ? [error] : []}
-                    onClearWarnings={actions.clearWarnings}
-                  />
-                </EnhancedModelSelectGroup>
 
                 <div className="flex gap-2">
                   <Button type="submit" disabled={loading || !formData.image || (imageValidation && !imageValidation.isValid)} className="flex-1">

+ 2 - 106
webui/app/text2img/page.tsx

@@ -13,9 +13,6 @@ import { apiClient, type GenerationRequest, type JobInfo, type ModelInfo } from
 import { Loader2, Download, X, Trash2, RotateCcw, Power } from 'lucide-react';
 import { downloadImage, downloadAuthenticatedImage } from '@/lib/utils';
 import { useLocalStorage } from '@/lib/hooks';
-import { ModelSelectionProvider, useModelSelection, useCheckpointSelection, useModelTypeSelection } from '@/contexts/model-selection-context';
-import { EnhancedModelSelect, EnhancedModelSelectGroup } from '@/components/enhanced-model-select';
-import { ModelSelectionWarning, AutoSelectionStatus } from '@/components/model-selection-indicator';
 
 const defaultFormData: GenerationRequest = {
   prompt: '',
@@ -31,26 +28,6 @@ const defaultFormData: GenerationRequest = {
 };
 
 function Text2ImgForm() {
-  const { state, actions } = useModelSelection();
-  const {
-    checkpointModels,
-    selectedCheckpointModel,
-    selectedCheckpoint,
-    setSelectedCheckpoint,
-    isAutoSelecting,
-    warnings,
-    error: checkpointError
-  } = useCheckpointSelection();
-  
-  const {
-    availableModels: vaeModels,
-    selectedModel: selectedVae,
-    isUserOverride: isVaeUserOverride,
-    isAutoSelected: isVaeAutoSelected,
-    setSelectedModel: setSelectedVae,
-    setUserOverride: setVaeUserOverride,
-    clearUserOverride: clearVaeUserOverride,
-  } = useModelTypeSelection('vae');
 
   const [formData, setFormData] = useLocalStorage<GenerationRequest>(
     'text2img-form-data',
@@ -69,16 +46,14 @@ function Text2ImgForm() {
   useEffect(() => {
     const loadOptions = async () => {
       try {
-        const [samplersData, schedulersData, modelsData, loras, embeds] = await Promise.all([
+        const [samplersData, schedulersData, loras, embeds] = await Promise.all([
           apiClient.getSamplers(),
           apiClient.getSchedulers(),
-          apiClient.getModels(), // Get all models with enhanced info
           apiClient.getModels('lora'),
           apiClient.getModels('embedding'),
         ]);
         setSamplers(samplersData);
         setSchedulers(schedulersData);
-        actions.setModels(modelsData.models);
         setLoraModels(loras.models.map(m => m.name));
         setEmbeddings(embeds.models.map(m => m.name));
       } catch (err) {
@@ -86,17 +61,8 @@ function Text2ImgForm() {
       }
     };
     loadOptions();
-  }, [actions]);
+  }, []);
 
-  // Update form data when checkpoint changes
-  useEffect(() => {
-    if (selectedCheckpoint) {
-      setFormData(prev => ({
-        ...prev,
-        model: selectedCheckpoint,
-      }));
-    }
-  }, [selectedCheckpoint, setFormData]);
 
   const handleInputChange = (
     e: React.ChangeEvent<HTMLInputElement | HTMLTextAreaElement | HTMLSelectElement>
@@ -179,20 +145,8 @@ function Text2ImgForm() {
     setJobInfo(null);
 
     try {
-      // Validate model selection
-      if (selectedCheckpointModel) {
-        const validation = actions.validateSelection(selectedCheckpointModel);
-        if (!validation.isValid) {
-          setError(`Missing required models: ${validation.missingRequired.join(', ')}`);
-          setLoading(false);
-          return;
-        }
-      }
-
       const requestData = {
         ...formData,
-        model: selectedCheckpoint || undefined,
-        vae: selectedVae || undefined,
       };
 
       const job = await apiClient.text2img(requestData);
@@ -229,9 +183,6 @@ function Text2ImgForm() {
 
   const handleResetToDefaults = () => {
     setFormData(defaultFormData);
-    setSelectedCheckpoint(null);
-    setSelectedVae('');
-    actions.resetSelection();
   };
 
   const handleServerRestart = async () => {
@@ -434,61 +385,6 @@ function Text2ImgForm() {
                   </select>
                 </div>
 
-                {/* Model Selection Section */}
-                <EnhancedModelSelectGroup
-                  title="Model Selection"
-                  description="Select the checkpoint and additional models for generation"
-                >
-                  {/* Checkpoint Selection */}
-                  <div className="space-y-2">
-                    <Label htmlFor="checkpoint">Checkpoint Model *</Label>
-                    <select
-                      id="checkpoint"
-                      value={selectedCheckpoint || ''}
-                      onChange={(e) => setSelectedCheckpoint(e.target.value || null)}
-                      className="flex h-10 w-full rounded-md border border-input bg-background px-3 py-2 text-sm"
-                      disabled={isAutoSelecting}
-                    >
-                      <option value="">Select a checkpoint model...</option>
-                      {checkpointModels.map((model) => (
-                        <option key={model.id} value={model.name}>
-                          {model.name} {model.loaded ? '(Loaded)' : ''}
-                        </option>
-                      ))}
-                    </select>
-                  </div>
-
-                  {/* VAE Selection */}
-                  <EnhancedModelSelect
-                    modelType="vae"
-                    label="VAE Model"
-                    description="Optional VAE model for improved image quality"
-                    value={selectedVae}
-                    availableModels={vaeModels}
-                    isAutoSelected={isVaeAutoSelected}
-                    isUserOverride={isVaeUserOverride}
-                    isLoading={isAutoSelecting}
-                    onValueChange={setSelectedVae}
-                    onSetUserOverride={setVaeUserOverride}
-                    onClearOverride={clearVaeUserOverride}
-                    placeholder="Use default VAE"
-                  />
-
-                  {/* Auto-selection Status */}
-                  <div className="pt-2">
-                    <AutoSelectionStatus
-                      isAutoSelecting={isAutoSelecting}
-                      hasAutoSelection={Object.keys(state.autoSelectedModels).length > 0}
-                    />
-                  </div>
-
-                  {/* Warnings and Errors */}
-                  <ModelSelectionWarning
-                    warnings={warnings}
-                    errors={error ? [error] : []}
-                    onClearWarnings={actions.clearWarnings}
-                  />
-                </EnhancedModelSelectGroup>
 
                 <div className="space-y-2">
                   <Label htmlFor="batch_count">Batch Count</Label>