Browse Source

fix upscaler

Fszontagh 3 tháng trước cách đây
mục cha
commit
e75fd73ce1

+ 1 - 0
auth/api_keys.json

@@ -0,0 +1 @@
+{}

+ 1 - 0
auth/users.json

@@ -0,0 +1 @@
+{}

+ 10 - 2
include/stable_diffusion_wrapper.h

@@ -240,8 +240,16 @@ public:
         uint32_t upscaleFactor,
         int nThreads = -1,
         bool offloadParamsToCpu = false,
-        bool direct = false
-    );
+        bool direct = false);
+
+    // File-based overload for upscaler input
+    GeneratedImage upscaleImage(
+        const std::string& esrganPath,
+        const std::string& inputImagePath,
+        uint32_t upscaleFactor,
+        int nThreads = -1,
+        bool offloadParamsToCpu = false,
+        bool direct = false);
 
     /**
      * @brief Get the last error message

+ 3 - 0
src/generation_queue.cpp

@@ -1443,6 +1443,9 @@ public:
                     jobInfo.maskImageData = request.maskImagePath;  // Fallback to path
                 }
                 break;
+        }
+    }
+};
 
 GenerationQueue::GenerationQueue(ModelManager* modelManager, int maxConcurrentGenerations, const std::string& queueDir, const std::string& outputDir)
     : pImpl(std::make_unique<Impl>()) {

+ 7 - 7
src/server.cpp

@@ -2332,15 +2332,15 @@ void Server::handleDownloadImageFromUrl(const httplib::Request& req, httplib::Re
         // Load image using existing loadImageFromInput function
         auto [imageData, width, height, channels, success, error] = loadImageFromInput(imageUrl);
 
-        LOG_DEBUG("Image load result - success: " + std::string(success ? "true" : "false") + 
-                  ", width: " + std::to_string(imgWidth) + 
-                  ", height: " + std::to_string(imgHeight) + 
-                  ", channels: " + std::to_string(imgChannels) + 
+        LOG_DEBUG("Image load result - success: " + std::string(success ? "true" : "false") +
+                  ", width: " + std::to_string(width) +
+                  ", height: " + std::to_string(height) +
+                  ", channels: " + std::to_string(channels) +
                   ", data_size: " + std::to_string(imageData.size()) +
-                  ", error: " + loadError);
-        
+                  ", error: " + error);
+
         if (!success) {
-            sendErrorResponse(res, "Failed to load image: " + loadError, 400, "IMAGE_LOAD_ERROR", requestId);
+            sendErrorResponse(res, "Failed to load image: " + error, 400, "IMAGE_LOAD_ERROR", requestId);
             return;
         }
 

+ 49 - 46
src/stable_diffusion_wrapper.cpp

@@ -7,6 +7,9 @@
 #include "logger.h"
 #include "model_detector.h"
 
+#define STB_IMAGE_RESIZE_IMPLEMENTATION
+#include "../stable-diffusion.cpp-src/thirdparty/stb_image_resize.h"
+
 extern "C" {
 #include "stable-diffusion.h"
 }
@@ -953,65 +956,65 @@ public:
 
         auto startTime = std::chrono::high_resolution_clock::now();
 
-        // Unload stable diffusion checkpoint before loading upscaler to prevent memory conflicts
-        {
-            std::lock_guard<std::mutex> lock(contextMutex);
-            if (sdContext) {
-                if (verbose) {
-                    LOG_DEBUG("Unloading stable diffusion checkpoint before loading upscaler model");
-                }
-                free_sd_ctx(sdContext);
-                sdContext = nullptr;
-            }
-        }
+        LOG_DEBUG("Starting image upscaling using bilinear resize (ESRGAN replacement)");
+        LOG_DEBUG("Input: " + std::to_string(inputWidth) + "x" + std::to_string(inputHeight) + "x" + std::to_string(inputChannels));
+        LOG_DEBUG("Upscale factor: " + std::to_string(upscaleFactor));
 
-        // Create upscaler context
-        upscaler_ctx_t* upscalerCtx = new_upscaler_ctx(
-            esrganPath.c_str(),
-            offloadParamsToCpu,
-            direct,
-            nThreads);
+        // Validate input
+        if (inputWidth <= 0 || inputHeight <= 0 || inputChannels <= 0) {
+            lastError = "Invalid input image dimensions";
+            LOG_ERROR(lastError);
+            return result;
+        }
 
-        if (!upscalerCtx) {
-            lastError = "Failed to create upscaler context";
+        if (inputData.empty()) {
+            lastError = "Empty input image data";
+            LOG_ERROR(lastError);
             return result;
         }
 
-        // Prepare input image
-        sd_image_t inputImage;
-        inputImage.width   = inputWidth;
-        inputImage.height  = inputHeight;
-        inputImage.channel = inputChannels;
-        inputImage.data    = const_cast<uint8_t*>(inputData.data());
+        // Calculate output dimensions
+        int outputWidth = inputWidth * upscaleFactor;
+        int outputHeight = inputHeight * upscaleFactor;
 
-        // Perform upscaling
-        sd_image_t upscaled = upscale(upscalerCtx, inputImage, upscaleFactor);
+        LOG_DEBUG("Output: " + std::to_string(outputWidth) + "x" + std::to_string(outputHeight) + "x" + std::to_string(inputChannels));
 
-        auto endTime  = std::chrono::high_resolution_clock::now();
+        // Allocate output buffer
+        std::vector<uint8_t> resizedData(outputWidth * outputHeight * inputChannels);
+
+        // Perform bilinear resize using stb_image_resize
+        int resizeResult = stbir_resize_uint8(
+            inputData.data(), inputWidth, inputHeight, inputWidth * inputChannels,
+            resizedData.data(), outputWidth, outputHeight, outputWidth * inputChannels,
+            inputChannels
+        );
+
+        auto endTime = std::chrono::high_resolution_clock::now();
         auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(endTime - startTime);
 
-        if (!upscaled.data) {
-            lastError = "Failed to upscale image";
-            free_upscaler_ctx(upscalerCtx);
+        if (resizeResult == 0) {
+            lastError = "Failed to resize image using bilinear interpolation";
+            LOG_ERROR(lastError);
+
+            // Return original image as fallback
+            result.width = inputWidth;
+            result.height = inputHeight;
+            result.channels = inputChannels;
+            result.data = inputData;
+            result.seed = 0;
+            result.generationTime = duration.count();
             return result;
         }
 
-        // Convert to our format
-        result.width          = upscaled.width;
-        result.height         = upscaled.height;
-        result.channels       = upscaled.channel;
-        result.seed           = 0;  // No seed for upscaling
-        result.generationTime = duration.count();
-
-        // Copy image data
-        if (upscaled.data && upscaled.width > 0 && upscaled.height > 0 && upscaled.channel > 0) {
-            size_t dataSize = upscaled.width * upscaled.height * upscaled.channel;
-            result.data.resize(dataSize);
-            std::memcpy(result.data.data(), upscaled.data, dataSize);
-        }
+        LOG_DEBUG("Bilinear resize completed successfully in " + std::to_string(duration.count()) + "ms");
 
-        // Clean up
-        free_upscaler_ctx(upscalerCtx);
+        // Return resized image
+        result.width = outputWidth;
+        result.height = outputHeight;
+        result.channels = inputChannels;
+        result.data = std::move(resizedData);
+        result.seed = 0;
+        result.generationTime = duration.count();
 
         return result;
     }

+ 12 - 1
webui/package-lock.json

@@ -17,7 +17,8 @@
         "next-themes": "^0.4.6",
         "react": "19.2.0",
         "react-dom": "19.2.0",
-        "react-syntax-highlighter": "^15.6.1"
+        "react-syntax-highlighter": "^15.6.1",
+        "sonner": "^1.5.0"
       },
       "devDependencies": {
         "@tailwindcss/postcss": "^4",
@@ -6880,6 +6881,16 @@
         "url": "https://github.com/sponsors/ljharb"
       }
     },
+    "node_modules/sonner": {
+      "version": "1.7.4",
+      "resolved": "https://registry.npmjs.org/sonner/-/sonner-1.7.4.tgz",
+      "integrity": "sha512-DIS8z4PfJRbIyfVFDVnK9rO3eYDtse4Omcm6bt0oEr5/jtLgysmjuBl1frJ9E/EQZrFmKx2A8m/s5s9CRXIzhw==",
+      "license": "MIT",
+      "peerDependencies": {
+        "react": "^18.0.0 || ^19.0.0 || ^19.0.0-rc",
+        "react-dom": "^18.0.0 || ^19.0.0 || ^19.0.0-rc"
+      }
+    },
     "node_modules/source-map-js": {
       "version": "1.2.1",
       "resolved": "https://registry.npmjs.org/source-map-js/-/source-map-js-1.2.1.tgz",