Ver Fonte

implemenet progress tracking #26

AI Agent 001 há 3 meses atrás
pai
commit
88342216ab
2 ficheiros alterados com 59 adições e 7 exclusões
  1. 5 0
      include/generation_queue.h
  2. 54 7
      src/generation_queue.cpp

+ 5 - 0
include/generation_queue.h

@@ -229,6 +229,11 @@ struct JobInfo {
     std::vector<std::string> outputFiles; ///< Paths to generated output files
     std::string errorMessage;          ///< Error message if job failed
     float progress = 0.0f;             ///< Generation progress (0.0 to 1.0)
+    int currentStep = 0;               ///< Current step in generation
+    int totalSteps = 0;                ///< Total steps in generation
+    int64_t timeElapsed = 0;           ///< Time elapsed in milliseconds
+    int64_t timeRemaining = 0;         ///< Estimated time remaining in milliseconds
+    float speed = 0.0f;                ///< Generation speed in steps per second
 };
 
 /**

+ 54 - 7
src/generation_queue.cpp

@@ -129,6 +129,12 @@ public:
             if (activeJobs.find(request.id) != activeJobs.end()) {
                 activeJobs[request.id].status = GenerationStatus::PROCESSING;
                 activeJobs[request.id].startTime = startTime;
+                activeJobs[request.id].progress = 0.0f;
+                activeJobs[request.id].currentStep = 0;
+                activeJobs[request.id].totalSteps = 0;
+                activeJobs[request.id].timeElapsed = 0;
+                activeJobs[request.id].timeRemaining = 0;
+                activeJobs[request.id].speed = 0.0f;
                 saveJobToFile(activeJobs[request.id]);
             }
         }
@@ -139,8 +145,8 @@ public:
                   << " (prompt: " << request.prompt.substr(0, 50)
                   << (request.prompt.length() > 50 ? "..." : "") << ")" << std::endl;
 
-        // Real generation logic using stable-diffusion.cpp
-        GenerationResult result = performActualGeneration(request);
+        // Real generation logic using stable-diffusion.cpp with progress tracking
+        GenerationResult result = performActualGeneration(request, request.id);
 
         auto endTime = std::chrono::steady_clock::now();
         auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(endTime - startTime);
@@ -152,6 +158,14 @@ public:
             if (activeJobs.find(request.id) != activeJobs.end()) {
                 activeJobs[request.id].status = result.success ? GenerationStatus::COMPLETED : GenerationStatus::FAILED;
                 activeJobs[request.id].endTime = endTime;
+                
+                // Set final progress to 100% if successful
+                if (result.success) {
+                    activeJobs[request.id].progress = 1.0f;
+                    if (activeJobs[request.id].totalSteps > 0) {
+                        activeJobs[request.id].currentStep = activeJobs[request.id].totalSteps;
+                    }
+                }
 
                 // Store output files and error message
                 activeJobs[request.id].outputFiles = result.imagePaths;
@@ -181,7 +195,32 @@ public:
         std::cout << std::endl;
     }
 
-    GenerationResult performActualGeneration(const GenerationRequest& request) {
+    // Progress callback that updates the job info
+    void updateJobProgress(const std::string& jobId, int step, int totalSteps, float progress, uint64_t timeElapsed) {
+        std::lock_guard<std::mutex> lock(jobsMutex);
+        auto it = activeJobs.find(jobId);
+        if (it != activeJobs.end()) {
+            it->second.progress = progress;
+            it->second.currentStep = step;
+            it->second.totalSteps = totalSteps;
+            it->second.timeElapsed = static_cast<int64_t>(timeElapsed);
+            
+            // Calculate time remaining and speed
+            if (step > 0 && timeElapsed > 0) {
+                double avgStepTime = static_cast<double>(timeElapsed) / step;
+                int remainingSteps = totalSteps - step;
+                it->second.timeRemaining = static_cast<int64_t>(avgStepTime * remainingSteps);
+                it->second.speed = 1000.0 / avgStepTime; // steps per second
+            }
+            
+            // Save progress to file periodically (every 10 steps or on significant progress changes)
+            if (step % 10 == 0 || progress >= 0.99f) {
+                saveJobToFile(it->second);
+            }
+        }
+    }
+
+    GenerationResult performActualGeneration(const GenerationRequest& request, const std::string& jobId) {
         GenerationResult result;
         result.requestId = request.id;
         result.success = false;
@@ -252,13 +291,19 @@ public:
         }
         result.actualSeed = params.seed;
 
-        // Generate images based on request type
+        // Generate images based on request type with progress tracking
         try {
             std::vector<StableDiffusionWrapper::GeneratedImage> generatedImages;
 
+            // Create progress callback that updates job info
+            auto progressCallback = [this, jobId](int step, int totalSteps, float progress, uint64_t timeElapsed) -> bool {
+                updateJobProgress(jobId, step, totalSteps, progress, timeElapsed);
+                return true; // Continue generation
+            };
+
             switch (request.requestType) {
                 case GenerationRequest::RequestType::TEXT2IMG:
-                    generatedImages = modelWrapper->generateImage(params);
+                    generatedImages = modelWrapper->generateImage(params, progressCallback);
                     break;
 
                 case GenerationRequest::RequestType::IMG2IMG:
@@ -270,7 +315,8 @@ public:
                         params,
                         request.initImageData,
                         request.initImageWidth,
-                        request.initImageHeight
+                        request.initImageHeight,
+                        progressCallback
                     );
                     break;
 
@@ -283,7 +329,8 @@ public:
                         params,
                         request.controlImageData,
                         request.controlImageWidth,
-                        request.controlImageHeight
+                        request.controlImageHeight,
+                        progressCallback
                     );
                     break;