HASHING_IMPLEMENTATION_GUIDE.md 9.3 KB

Model Hashing - Remaining Implementation Guide

What's Already Done ✅

  • Data structures in generation_queue.h
  • All hash methods in model_manager.cpp (SHA256, JSON storage, lookup)
  • OpenSSL includes added to model_manager.cpp

What Needs To Be Added

1. CMakeLists.txt - Add OpenSSL Dependency

Find the find_package section and add:

# Only libcrypto (SHA-256) is used by the server, so request just that component
# of OpenSSL; it provides the OpenSSL::Crypto imported target linked below.
find_package(OpenSSL REQUIRED COMPONENTS Crypto)

Find the target_link_libraries for stable-diffusion-rest-server and add:

# Link libcrypto so the SHA-256 code in model_manager.cpp resolves.
# "..." stands for the libraries the target already links; keep them.
target_link_libraries(stable-diffusion-rest-server
    PRIVATE
    ...
    OpenSSL::Crypto
)

2. generation_queue.cpp - Add Hash Job Support

At the end of the file, add:

// Queue a model-hashing job and return a future that receives its HashResult.
// The job travels through the ordinary generation queue as a placeholder
// GenerationRequest; its HashRequest and promise are stored by request id so the
// worker can pick them up.
std::future<HashResult> GenerationQueue::enqueueHashRequest(const HashRequest& request) {
    auto promise = std::make_shared<std::promise<HashResult>>();
    auto future = promise->get_future();

    // Build the placeholder before taking the lock; it needs no shared state.
    GenerationRequest hashJobPlaceholder;
    hashJobPlaceholder.id = request.id;
    hashJobPlaceholder.prompt = "HASH_JOB"; // Special marker
    hashJobPlaceholder.modelName = request.modelNames.empty() ? "ALL_MODELS" : request.modelNames[0];

    {
        // Register the job and enqueue its placeholder in one critical section so the
        // worker can never dequeue the placeholder without its request/promise.
        std::lock_guard<std::mutex> lock(pImpl->queueMutex);
        pImpl->hashPromises[request.id] = std::move(promise);
        pImpl->hashRequests[request.id] = request;
        pImpl->requestQueue.push(std::move(hashJobPlaceholder));
    }
    // Notify after releasing the lock so the woken worker does not immediately block on it.
    pImpl->queueCondition.notify_one();

    std::cout << "Enqueued hash request: " << request.id << std::endl;

    return future;
}

In the worker thread function, add a check that detects hash jobs:

// Inside processRequests() or worker thread, after `request` has been dequeued:
// NOTE(review): assumes queueMutex has already been released at this point (the
// usual pop-then-unlock pattern). If the worker still holds it here, remove the
// inner lock_guard below, or it will deadlock.
if (request.prompt == "HASH_JOB") {
    // Claim the job's request and promise under queueMutex. enqueueHashRequest()
    // writes both maps while holding that mutex, so touching them here without it
    // would race with concurrent enqueues.
    bool isHashJob = false;
    HashRequest hashRequest;
    std::shared_ptr<std::promise<HashResult>> hashPromise;
    {
        std::lock_guard<std::mutex> lock(pImpl->queueMutex);
        auto hashIt = pImpl->hashRequests.find(request.id);
        if (hashIt != pImpl->hashRequests.end()) {
            isHashJob = true;
            hashRequest = std::move(hashIt->second);
            pImpl->hashRequests.erase(hashIt);

            auto promiseIt = pImpl->hashPromises.find(request.id);
            if (promiseIt != pImpl->hashPromises.end()) {
                hashPromise = std::move(promiseIt->second);
                pImpl->hashPromises.erase(promiseIt);
            }
        }
    }

    // A "HASH_JOB" prompt with no registered hash request is an ordinary user
    // generation that happens to use that text. Let it fall through to normal
    // processing instead of silently dropping it.
    if (isHashJob) {
        // Run the job after the inner lock is released so other threads can keep enqueuing.
        try {
            HashResult result = performHashJob(hashRequest);
            if (hashPromise) {
                hashPromise->set_value(std::move(result));
            }
        } catch (...) {
            // Hand the failure to the waiting future instead of letting it escape the
            // worker thread and leaving the promise unsatisfied.
            if (hashPromise) {
                hashPromise->set_exception(std::current_exception());
            }
        }
        continue;
    }
}

Add the performHashJob method to the Impl class:

// Impl member: compute SHA-256 hashes for the models selected by `request` and
// package the outcome as a HashResult.
//
// Model selection:
//   - when request.modelNames is non-empty, exactly those names;
//   - otherwise every model from getAllModels() whose stored sha256 is empty,
//     or every model when request.forceRehash is set.
//
// The result is successful when at least one model was hashed, or when nothing
// needed hashing. A no-op run (for example, every model already hashed) is not
// an error. hashingTime is the elapsed steady_clock time in milliseconds.
HashResult performHashJob(const HashRequest& request) {
    HashResult result;
    result.requestId = request.id;
    result.success = false;
    result.modelsHashed = 0;

    const auto startTime = std::chrono::steady_clock::now();

    if (!modelManager) {
        result.errorMessage = "Model manager not available";
        result.status = GenerationStatus::FAILED;
        return result;
    }

    // Get list of models to hash
    std::vector<std::string> modelsToHash;
    if (request.modelNames.empty()) {
        // Hash all models without hashes (or all of them when forcing a rehash)
        const auto allModels = modelManager->getAllModels();
        for (const auto& [name, info] : allModels) {
            if (info.sha256.empty() || request.forceRehash) {
                modelsToHash.push_back(name);
            }
        }
    } else {
        modelsToHash = request.modelNames;
    }

    std::cout << "Hashing " << modelsToHash.size() << " model(s)..." << std::endl;

    // Hash each model. A failure on one model is logged and skipped so that it
    // does not abort the remainder of the batch.
    for (const auto& modelName : modelsToHash) {
        std::string hash;
        try {
            hash = modelManager->ensureModelHash(modelName, request.forceRehash);
        } catch (const std::exception& e) {
            std::cerr << "Exception while hashing model " << modelName << ": " << e.what() << std::endl;
        }

        if (!hash.empty()) {
            result.modelHashes[modelName] = hash;
            result.modelsHashed++;
        } else {
            std::cerr << "Failed to hash model: " << modelName << std::endl;
        }
    }

    const auto endTime = std::chrono::steady_clock::now();
    result.hashingTime = std::chrono::duration_cast<std::chrono::milliseconds>(
        endTime - startTime).count();

    // An empty selection means there was nothing to do, which is a success.
    result.success = modelsToHash.empty() || result.modelsHashed > 0;
    result.status = result.success ? GenerationStatus::COMPLETED : GenerationStatus::FAILED;

    if (!result.success) {
        result.errorMessage = "Failed to hash any models";
    }

    return result;
}

Add these private members to the Impl class:

// Pending hash jobs, keyed by request id (the id of the placeholder pushed onto
// requestQueue). enqueueHashRequest() writes both maps while holding queueMutex,
// so every other read or erase should hold queueMutex as well.
std::map<std::string, std::shared_ptr<std::promise<HashResult>>> hashPromises;
std::map<std::string, HashRequest> hashRequests;

3. server.cpp - Add Hash Endpoint

In registerEndpoints(), add:

// POST /api/models/hash: forward to handleHashModels, which queues a hashing job.
m_httpServer->Post("/api/models/hash",
                   [this](const httplib::Request& request, httplib::Response& response) {
                       handleHashModels(request, response);
                   });

Implement the handler:

// POST /api/models/hash
// Optional JSON object body:
//   "models":       array of model names to hash (absent/empty = all unhashed models)
//   "force_rehash": bool, recompute even when a hash is already stored (default false)
// Replies 202 with the request id once the job is queued. The job runs later on
// the generation queue.
// Replies 400 on malformed JSON or wrongly typed fields, and 500 on missing
// services or unexpected errors.
void Server::handleHashModels(const httplib::Request& req, httplib::Response& res) {
    std::string requestId = generateRequestId();

    try {
        if (!m_generationQueue || !m_modelManager) {
            sendErrorResponse(res, "Services not available", 500, "SERVICE_UNAVAILABLE", requestId);
            return;
        }

        // An empty body means "hash all unhashed models". Start from an empty object
        // rather than a default (null) json: json::value() throws type_error 306 on
        // null, which would turn the plain no-body request into a 500.
        json requestJson = json::object();
        if (!req.body.empty()) {
            requestJson = json::parse(req.body);
        }
        if (!requestJson.is_object()) {
            sendErrorResponse(res, "Request body must be a JSON object", 400, "INVALID_REQUEST", requestId);
            return;
        }

        HashRequest hashReq;
        hashReq.id = requestId;
        hashReq.forceRehash = requestJson.value("force_rehash", false);

        if (requestJson.contains("models") && requestJson["models"].is_array()) {
            for (const auto& model : requestJson["models"]) {
                hashReq.modelNames.push_back(model.get<std::string>());
            }
        }

        // Fire-and-forget: the client is answered immediately with 202. A future
        // obtained from a std::promise does not block in its destructor.
        [[maybe_unused]] auto future = m_generationQueue->enqueueHashRequest(hashReq);

        json response = {
            {"request_id", requestId},
            {"status", "queued"},
            {"message", "Hash job queued successfully"},
            {"models_to_hash", hashReq.modelNames.empty() ? "all_unhashed" : std::to_string(hashReq.modelNames.size())}
        };

        sendJsonResponse(res, response, 202);
    } catch (const json::parse_error& e) {
        sendErrorResponse(res, std::string("Invalid JSON: ") + e.what(), 400, "JSON_PARSE_ERROR", requestId);
    } catch (const json::type_error& e) {
        // e.g. "force_rehash" not a bool, or a non-string entry in "models"
        sendErrorResponse(res, std::string("Invalid request field: ") + e.what(), 400, "INVALID_REQUEST", requestId);
    } catch (const std::exception& e) {
        sendErrorResponse(res, std::string("Hash request failed: ") + e.what(), 500, "INTERNAL_ERROR", requestId);
    }
}

4. server.cpp - Update Models List to Show Hashes

In handleModelsList(), extend the per-model JSON to include the hash:

// "sha256" and "sha256_short" (first 10 hex chars) are JSON null until the model
// has been hashed. Both ternary branches are wrapped in json(...). In a bare
// `cond ? nullptr : str`, nullptr is converted to std::string through its
// const char* constructor: undefined behavior before C++23, and ill-formed from
// C++23 on. It would also never produce a JSON null.
json modelJson = {
    {"name", modelInfo.name},
    {"type", ModelManager::modelTypeToString(modelInfo.type)},
    {"file_size", modelInfo.fileSize},
    {"file_size_mb", modelInfo.fileSize / (1024.0 * 1024.0)},
    {"sha256", modelInfo.sha256.empty() ? json(nullptr) : json(modelInfo.sha256)},
    {"sha256_short", modelInfo.sha256.empty() ? json(nullptr) : json(modelInfo.sha256.substr(0, 10))}
};

5. server.cpp - Modify Load Endpoint to Accept Hash

Update handleLoadModelById() to try hash-based lookup first:

// Load a model identified by the first route capture, which may be either a model
// name or a SHA-256 hash (presumably a prefix is accepted too; confirm against
// ModelManager::findModelByHash).
// Resolution: an identifier that is long enough and entirely hexadecimal is first
// looked up by hash. If that finds nothing, it is still used as a plain name.
void Server::handleLoadModelById(const httplib::Request& req, httplib::Response& res) {
    std::string requestId = generateRequestId();

    try {
        if (!m_modelManager) {
            sendErrorResponse(res, "Model manager not available", 500, "MODEL_MANAGER_UNAVAILABLE", requestId);
            return;
        }

        // Extract model ID (could be hash or name)
        std::string modelIdentifier = req.matches[1].str();
        if (modelIdentifier.empty()) {
            sendErrorResponse(res, "Missing model identifier", 400, "MISSING_MODEL_ID", requestId);
            return;
        }

        // Shortest identifier that is treated as a hash candidate.
        constexpr std::size_t kMinHashPrefixLength = 10;

        // The predicate takes unsigned char because std::isxdigit has undefined behavior
        // for negative char values, such as non-ASCII bytes coming from the URL path.
        const bool looksLikeHash =
            modelIdentifier.length() >= kMinHashPrefixLength &&
            std::all_of(modelIdentifier.begin(), modelIdentifier.end(),
                        [](unsigned char c) { return std::isxdigit(c) != 0; });

        std::string modelName = modelIdentifier;
        if (looksLikeHash) {
            std::string foundName = m_modelManager->findModelByHash(modelIdentifier);
            if (!foundName.empty()) {
                modelName = foundName;
                std::cout << "Resolved hash " << modelIdentifier << " to model: " << modelName << std::endl;
            }
        }

        // Rest of load logic using modelName...
        // (keep existing load code)
    } catch (const std::exception& e) {
        sendErrorResponse(res, std::string("Model load failed: ") + e.what(), 500, "MODEL_LOAD_ERROR", requestId);
    }
}

6. server.h - Add Handler Declaration

In the private section, add:

// Handler for POST /api/models/hash: queues a model-hashing job and replies 202 with its request id.
void handleHashModels(const httplib::Request& req, httplib::Response& res);

Testing

  1. Build with OpenSSL:

    # Run from the CMake build directory (the "." passed to --build).
    cmake --build . --target stable-diffusion-rest-server
    
  2. Start server:

    ./stable-diffusion-rest-server \
        --models-dir /data/SD_MODELS \
        --checkpoints checkpoints/SD1x \
        --port 8082
    
  3. Trigger hashing:

    # Hash all unhashed models (no request body)
    curl --request POST http://localhost:8082/api/models/hash
    
    # Hash only the listed models
    curl --request POST http://localhost:8082/api/models/hash \
        --header "Content-Type: application/json" \
        --data '{
    "models": ["checkpoints/SD1x/model1.safetensors"]
    }'
    
    # Recompute hashes even where one is already stored
    curl --request POST http://localhost:8082/api/models/hash \
        --header "Content-Type: application/json" \
        --data '{
    "force_rehash": true
    }'
    
  4. Check models with hashes:

    curl --silent http://localhost:8082/api/models | jq '.models[] | {name, sha256_short}'
    
  5. Load by hash:

    # Identify the model by the first 10 hex characters of its hash
    curl --request POST http://localhost:8082/api/models/a1b2c3d4e5/load
    

Notes

  • Hashes are stored in <model_file_path>.json
  • An identifier is tried as a (partial) hash only if it is at least 10 characters long and entirely hexadecimal; otherwise it is treated as a model name
  • Hashing runs as a job in the generation queue, so it blocks other queued operations until it finishes
  • Progress is logged every 100MB during hashing