Browse Source

upscaler progress

Fszontagh 3 months ago
parent
commit
af2d50723f
4 changed files with 43 additions and 22 deletions
  1. 8 2
      include/stable_diffusion_wrapper.h
  2. 11 15
      src/generation_queue.cpp
  3. 1 1
      src/server.cpp
  4. 23 4
      src/stable_diffusion_wrapper.cpp

+ 8 - 2
include/stable_diffusion_wrapper.h

@@ -229,6 +229,8 @@ public:
      * @param nThreads Number of threads (-1 for auto)
      * @param offloadParamsToCpu Offload parameters to CPU
      * @param direct Use direct mode
+     * @param progressCallback Optional progress callback
+     * @param userData User data for progress callback
      * @return GeneratedImage Upscaled image
      */
     GeneratedImage upscaleImage(
@@ -240,7 +242,9 @@ public:
         uint32_t upscaleFactor,
         int nThreads = -1,
         bool offloadParamsToCpu = false,
-        bool direct = false);
+        bool direct = false,
+        ProgressCallback progressCallback = nullptr,
+        void* userData = nullptr);
 
     // File-based overload for upscaler input
     GeneratedImage upscaleImage(
@@ -249,7 +253,9 @@ public:
         uint32_t upscaleFactor,
         int nThreads = -1,
         bool offloadParamsToCpu = false,
-        bool direct = false);
+        bool direct = false,
+        ProgressCallback progressCallback = nullptr,
+        void* userData = nullptr);
 
     /**
      * @brief Get the last error message

+ 11 - 15
src/generation_queue.cpp

@@ -14,11 +14,7 @@
 #include "stable_diffusion_wrapper.h"
 #include "utils.h"
 
-#define STB_IMAGE_WRITE_IMPLEMENTATION
-#include "../stable-diffusion.cpp-src/thirdparty/stb_image_write.h"
-
-#define STB_IMAGE_IMPLEMENTATION
-#include "../stable-diffusion.cpp-src/thirdparty/stb_image.h"
+#include "../build/stable-diffusion.cpp-src/thirdparty/stb_image_write.h"
 
 class GenerationQueue::Impl {
 public:
@@ -503,19 +499,17 @@ public:
                     {
                         // For upscaler, create a progress callback and temporary wrapper instance
                         StableDiffusionWrapper tempWrapper;
-                        
-                        // Create progress callback for upscaler (no model loading phase)
+
+                        // Create progress callback for upscaler
                         auto progressCallback = [this, jobId = request.id](int step, int totalSteps, float stepTime, void* userData) {
-                            // For upscaler, totalSteps is 0 (no generation steps), so we show progress based on time
+                            // For upscaler, we get proper progress from the upscaleImage method
                             auto currentTime     = std::chrono::system_clock::now();
                             uint64_t timeElapsed = std::chrono::duration_cast<std::chrono::milliseconds>(currentTime - *static_cast<std::chrono::system_clock::time_point*>(userData)).count();
-                            
-                            // Estimate progress based on time (upscaling typically takes a few seconds)
-                            float estimatedProgress = std::min(0.95f, static_cast<float>(timeElapsed) / 5000.0f); // Assume 5 seconds max
-                            
-                            updateGenerationProgress(jobId, 0, 0, stepTime, timeElapsed);
+
+                            // Use the progress reported by the upscaleImage method
+                            updateGenerationProgress(jobId, step, totalSteps, stepTime, timeElapsed);
                         };
-                        
+
                         auto upscaledImage = tempWrapper.upscaleImage(
                             request.esrganPath,
                             request.initImageData,
@@ -525,7 +519,9 @@ public:
                             request.upscaleFactor,
                             request.nThreads,
                             request.offloadParamsToCpu,
-                            request.diffusionConvDirect);
+                            request.diffusionConvDirect,
+                            progressCallback,
+                            &generationStartTime);
                         generatedImages.push_back(upscaledImage);
                     }
                     break;

+ 1 - 1
src/server.cpp

@@ -23,7 +23,7 @@
 #include <netinet/in.h>
 #include <sys/socket.h>
 #include <unistd.h>
-#include "../stable-diffusion.cpp-src/thirdparty/stb_image.h"
+#include "../build/stable-diffusion.cpp-src/thirdparty/stb_image.h"
 
 Server::Server(ModelManager* modelManager, GenerationQueue* generationQueue, const std::string& outputDir, const std::string& uiDir, const ServerConfig& config)
     : m_modelManager(modelManager), m_generationQueue(generationQueue), m_isRunning(false), m_startupFailed(false), m_port(config.port), m_outputDir(outputDir), m_uiDir(uiDir), m_userManager(nullptr), m_authMiddleware(nullptr), m_config(config) {

+ 23 - 4
src/stable_diffusion_wrapper.cpp

@@ -8,7 +8,7 @@
 #include "model_detector.h"
 
 #define STB_IMAGE_RESIZE_IMPLEMENTATION
-#include "../stable-diffusion.cpp-src/thirdparty/stb_image_resize.h"
+#include "../build/stable-diffusion.cpp-src/thirdparty/stb_image_resize.h"
 
 extern "C" {
 #include "stable-diffusion.h"
@@ -951,7 +951,9 @@ public:
         uint32_t upscaleFactor,
         int nThreads,
         bool offloadParamsToCpu,
-        bool direct) {
+        bool direct,
+        StableDiffusionWrapper::ProgressCallback progressCallback,
+        void* userData) {
         StableDiffusionWrapper::GeneratedImage result;
 
         auto startTime = std::chrono::high_resolution_clock::now();
@@ -960,6 +962,11 @@ public:
         LOG_DEBUG("Input: " + std::to_string(inputWidth) + "x" + std::to_string(inputHeight) + "x" + std::to_string(inputChannels));
         LOG_DEBUG("Upscale factor: " + std::to_string(upscaleFactor));
 
+        // Report initial progress
+        if (progressCallback) {
+            progressCallback(0, 1, 0.0f, userData);
+        }
+
         // Validate input
         if (inputWidth <= 0 || inputHeight <= 0 || inputChannels <= 0) {
             lastError = "Invalid input image dimensions";
@@ -979,6 +986,11 @@ public:
 
         LOG_DEBUG("Output: " + std::to_string(outputWidth) + "x" + std::to_string(outputHeight) + "x" + std::to_string(inputChannels));
 
+        // Report progress after validation
+        if (progressCallback) {
+            progressCallback(0, 1, 0.1f, userData);
+        }
+
         // Allocate output buffer
         std::vector<uint8_t> resizedData(outputWidth * outputHeight * inputChannels);
 
@@ -1008,6 +1020,11 @@ public:
 
         LOG_DEBUG("Bilinear resize completed successfully in " + std::to_string(duration.count()) + "ms");
 
+        // Report completion progress
+        if (progressCallback) {
+            progressCallback(1, 1, 1.0f, userData);
+        }
+
         // Return resized image
         result.width = outputWidth;
         result.height = outputHeight;
@@ -1195,9 +1212,11 @@ StableDiffusionWrapper::GeneratedImage StableDiffusionWrapper::upscaleImage(
     uint32_t upscaleFactor,
     int nThreads,
     bool offloadParamsToCpu,
-    bool direct) {
+    bool direct,
+    ProgressCallback progressCallback,
+    void* userData) {
     std::lock_guard<std::mutex> lock(wrapperMutex);
-    return pImpl->upscaleImage(esrganPath, inputData, inputWidth, inputHeight, inputChannels, upscaleFactor, nThreads, offloadParamsToCpu, direct);
+    return pImpl->upscaleImage(esrganPath, inputData, inputWidth, inputHeight, inputChannels, upscaleFactor, nThreads, offloadParamsToCpu, direct, progressCallback, userData);
 }
 
 std::string StableDiffusionWrapper::getLastError() const {