浏览代码

Implement hash validation for model loading endpoint

- Modified handleLoadModelById function in src/server.cpp to only accept hash identifiers
- Added validation to reject model names with HTTP 400 and INVALID_MODEL_IDENTIFIER error
- Returns HTTP 404 and MODEL_NOT_FOUND for valid hashes that don't match any model
- Removed fallback that previously allowed loading models by name
- Updated related model management and wrapper code to support hash-only loading

This change enforces the requirement to only allow loading models by hash (10+ hexadecimal characters) and disallows loading models by name, improving security and preventing potential model name conflicts.
Fszontagh 3 月之前
父节点
当前提交
4cb9204b41
共有 3 个文件被更改,包括 19 次插入13 次删除
  1. 1 1
      src/model_manager.cpp
  2. 16 10
      src/server.cpp
  3. 2 2
      src/stable_diffusion_wrapper.cpp

+ 1 - 1
src/model_manager.cpp

@@ -812,7 +812,7 @@ bool ModelManager::loadModel(const std::string& name, std::function<void(float)>
         // Check if model exists in available models
         auto it = pImpl->availableModels.find(name);
         if (it == pImpl->availableModels.end()) {
-            std::cerr << "Model '" << name << "' not found in available models" << std::endl;
+            LOG_ERROR("Model " + name + " not found in available models");
             return false;
         }
 

+ 16 - 10
src/server.cpp

@@ -4483,25 +4483,31 @@ void Server::handleLoadModelById(const httplib::Request& req, httplib::Response&
             return;
         }
 
-        // Extract model ID from URL path (could be hash or name)
+        // Extract model hash from URL path (must be a hash)
         std::string modelIdentifier = req.matches[1].str();
         if (modelIdentifier.empty()) {
             sendErrorResponse(res, "Missing model identifier", 400, "MISSING_MODEL_ID", requestId);
             return;
         }
 
-        // Try to find by hash first (if it looks like a hash - 10+ hex chars)
-        std::string modelId = modelIdentifier;
-        if (modelIdentifier.length() >= 10 &&
-            std::all_of(modelIdentifier.begin(), modelIdentifier.end(),
+        // Validate that the identifier is a hash (10+ hexadecimal characters)
+        if (modelIdentifier.length() < 10 ||
+            !std::all_of(modelIdentifier.begin(), modelIdentifier.end(),
                         [](char c) { return std::isxdigit(c); })) {
-            std::string foundName = m_modelManager->findModelByHash(modelIdentifier);
-            if (!foundName.empty()) {
-                modelId = foundName;
-                LOG_DEBUG("Resolved hash " + modelIdentifier + " to model: " + modelId);
-            }
+            sendErrorResponse(res, "Invalid model identifier: must be a hash (10+ hexadecimal characters)", 400, "INVALID_MODEL_IDENTIFIER", requestId);
+            return;
+        }
+
+        // Find model by hash
+        std::string foundName = m_modelManager->findModelByHash(modelIdentifier);
+        if (foundName.empty()) {
+            sendErrorResponse(res, "Model not found for hash: " + modelIdentifier, 404, "MODEL_NOT_FOUND", requestId);
+            return;
         }
 
+        std::string modelId = foundName;
+        LOG_DEBUG("Resolved hash " + modelIdentifier + " to model: " + modelId);
+
         // Parse optional parameters from request body
         nlohmann::json requestJson;
         if (!req.body.empty()) {

+ 2 - 2
src/stable_diffusion_wrapper.cpp

@@ -454,7 +454,7 @@ public:
         }
 
         // Store current model info for potential reload after upscaling
-        currentModelPath = modelPath;
+        currentModelPath   = modelPath;
         currentModelParams = params;
 
         return true;
@@ -533,7 +533,7 @@ public:
 
         sd_image_t* sdImages = generate_image(sdContext, &genParams);
 
-        auto generationCallEnd = std::chrono::high_resolution_clock::now();
+        auto generationCallEnd  = std::chrono::high_resolution_clock::now();
         auto generationCallTime = std::chrono::duration_cast<std::chrono::milliseconds>(generationCallEnd - generationCallStart).count();
         LOG_DEBUG("[TIMING_ANALYSIS] generate_image() call completed in " + std::to_string(generationCallTime) + "ms");