|
@@ -1,13 +1,14 @@
|
|
|
#include "model_detector.h"
|
|
#include "model_detector.h"
|
|
|
-#include <nlohmann/json.hpp>
|
|
|
|
|
-#include <fstream>
|
|
|
|
|
#include <algorithm>
|
|
#include <algorithm>
|
|
|
#include <cstring>
|
|
#include <cstring>
|
|
|
-
|
|
|
|
|
|
|
+#include <fstream>
|
|
|
|
|
+#include <iostream>
|
|
|
|
|
+#include <nlohmann/json.hpp>
|
|
|
|
|
|
|
|
// Helper function for C++17 compatibility (ends_with is C++20)
|
|
// Helper function for C++17 compatibility (ends_with is C++20)
|
|
|
static bool endsWith(const std::string& str, const std::string& suffix) {
|
|
static bool endsWith(const std::string& str, const std::string& suffix) {
|
|
|
- if (suffix.size() > str.size()) return false;
|
|
|
|
|
|
|
+ if (suffix.size() > str.size())
|
|
|
|
|
+ return false;
|
|
|
return str.compare(str.size() - suffix.size(), suffix.size(), suffix) == 0;
|
|
return str.compare(str.size() - suffix.size(), suffix.size(), suffix) == 0;
|
|
|
}
|
|
}
|
|
|
|
|
|
|
@@ -26,12 +27,12 @@ ModelDetectionResult ModelDetector::detectModel(const std::string& modelPath) {
|
|
|
// For now, we cannot detect their architecture without loading the model
|
|
// For now, we cannot detect their architecture without loading the model
|
|
|
// Return unknown architecture with a note in metadata
|
|
// Return unknown architecture with a note in metadata
|
|
|
result.metadata["format"] = "pytorch_pickle";
|
|
result.metadata["format"] = "pytorch_pickle";
|
|
|
- result.metadata["note"] = "Architecture detection not supported for .ckpt/.pt files";
|
|
|
|
|
|
|
+ result.metadata["note"] = "Architecture detection not supported for .ckpt/.pt files";
|
|
|
return result;
|
|
return result;
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
if (!parsed) {
|
|
if (!parsed) {
|
|
|
- return result; // Unknown if we can't parse
|
|
|
|
|
|
|
+ return result; // Unknown if we can't parse
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
// Store tensor names for reference
|
|
// Store tensor names for reference
|
|
@@ -40,84 +41,84 @@ ModelDetectionResult ModelDetector::detectModel(const std::string& modelPath) {
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
// Analyze architecture (pass filename for special detection)
|
|
// Analyze architecture (pass filename for special detection)
|
|
|
- std::string filename = modelPath.substr(modelPath.find_last_of("/\\") + 1);
|
|
|
|
|
- result.architecture = analyzeArchitecture(tensorInfo, result.metadata, filename);
|
|
|
|
|
|
|
+ std::string filename = modelPath.substr(modelPath.find_last_of("/\\") + 1);
|
|
|
|
|
+ result.architecture = analyzeArchitecture(tensorInfo, result.metadata, filename);
|
|
|
result.architectureName = getArchitectureName(result.architecture);
|
|
result.architectureName = getArchitectureName(result.architecture);
|
|
|
|
|
|
|
|
// Set architecture-specific properties and required models
|
|
// Set architecture-specific properties and required models
|
|
|
switch (result.architecture) {
|
|
switch (result.architecture) {
|
|
|
case ModelArchitecture::SD_1_5:
|
|
case ModelArchitecture::SD_1_5:
|
|
|
- result.textEncoderDim = 768;
|
|
|
|
|
- result.unetChannels = 1280;
|
|
|
|
|
- result.needsVAE = true;
|
|
|
|
|
- result.recommendedVAE = "vae-ft-mse-840000-ema-pruned.safetensors";
|
|
|
|
|
- result.needsTAESD = true;
|
|
|
|
|
|
|
+ result.textEncoderDim = 768;
|
|
|
|
|
+ result.unetChannels = 1280;
|
|
|
|
|
+ result.needsVAE = true;
|
|
|
|
|
+ result.recommendedVAE = "vae-ft-mse-840000-ema-pruned.safetensors";
|
|
|
|
|
+ result.needsTAESD = true;
|
|
|
result.suggestedParams["vae_flag"] = "--vae";
|
|
result.suggestedParams["vae_flag"] = "--vae";
|
|
|
break;
|
|
break;
|
|
|
|
|
|
|
|
case ModelArchitecture::SD_2_1:
|
|
case ModelArchitecture::SD_2_1:
|
|
|
- result.textEncoderDim = 1024;
|
|
|
|
|
- result.unetChannels = 1280;
|
|
|
|
|
- result.needsVAE = true;
|
|
|
|
|
- result.recommendedVAE = "vae-ft-ema-560000.safetensors";
|
|
|
|
|
- result.needsTAESD = true;
|
|
|
|
|
|
|
+ result.textEncoderDim = 1024;
|
|
|
|
|
+ result.unetChannels = 1280;
|
|
|
|
|
+ result.needsVAE = true;
|
|
|
|
|
+ result.recommendedVAE = "vae-ft-ema-560000.safetensors";
|
|
|
|
|
+ result.needsTAESD = true;
|
|
|
result.suggestedParams["vae_flag"] = "--vae";
|
|
result.suggestedParams["vae_flag"] = "--vae";
|
|
|
break;
|
|
break;
|
|
|
|
|
|
|
|
case ModelArchitecture::SDXL_BASE:
|
|
case ModelArchitecture::SDXL_BASE:
|
|
|
case ModelArchitecture::SDXL_REFINER:
|
|
case ModelArchitecture::SDXL_REFINER:
|
|
|
- result.textEncoderDim = 1280;
|
|
|
|
|
- result.unetChannels = 2560;
|
|
|
|
|
- result.hasConditioner = true;
|
|
|
|
|
- result.needsVAE = true;
|
|
|
|
|
- result.recommendedVAE = "sdxl_vae.safetensors";
|
|
|
|
|
- result.needsTAESD = true;
|
|
|
|
|
|
|
+ result.textEncoderDim = 1280;
|
|
|
|
|
+ result.unetChannels = 2560;
|
|
|
|
|
+ result.hasConditioner = true;
|
|
|
|
|
+ result.needsVAE = true;
|
|
|
|
|
+ result.recommendedVAE = "sdxl_vae.safetensors";
|
|
|
|
|
+ result.needsTAESD = true;
|
|
|
result.suggestedParams["vae_flag"] = "--vae";
|
|
result.suggestedParams["vae_flag"] = "--vae";
|
|
|
break;
|
|
break;
|
|
|
|
|
|
|
|
case ModelArchitecture::FLUX_SCHNELL:
|
|
case ModelArchitecture::FLUX_SCHNELL:
|
|
|
case ModelArchitecture::FLUX_DEV:
|
|
case ModelArchitecture::FLUX_DEV:
|
|
|
result.textEncoderDim = 4096;
|
|
result.textEncoderDim = 4096;
|
|
|
- result.needsVAE = true;
|
|
|
|
|
|
|
+ result.needsVAE = true;
|
|
|
result.recommendedVAE = "ae.safetensors";
|
|
result.recommendedVAE = "ae.safetensors";
|
|
|
// Flux requires CLIP-L and T5XXL
|
|
// Flux requires CLIP-L and T5XXL
|
|
|
- result.suggestedParams["vae_flag"] = "--vae";
|
|
|
|
|
|
|
+ result.suggestedParams["vae_flag"] = "--vae";
|
|
|
result.suggestedParams["clip_l_required"] = "clip_l.safetensors";
|
|
result.suggestedParams["clip_l_required"] = "clip_l.safetensors";
|
|
|
- result.suggestedParams["t5xxl_required"] = "t5xxl_fp16.safetensors";
|
|
|
|
|
- result.suggestedParams["clip_l_flag"] = "--clip-l";
|
|
|
|
|
- result.suggestedParams["t5xxl_flag"] = "--t5xxl";
|
|
|
|
|
|
|
+ result.suggestedParams["t5xxl_required"] = "t5xxl_fp16.safetensors";
|
|
|
|
|
+ result.suggestedParams["clip_l_flag"] = "--clip-l";
|
|
|
|
|
+ result.suggestedParams["t5xxl_flag"] = "--t5xxl";
|
|
|
break;
|
|
break;
|
|
|
|
|
|
|
|
case ModelArchitecture::FLUX_CHROMA:
|
|
case ModelArchitecture::FLUX_CHROMA:
|
|
|
result.textEncoderDim = 4096;
|
|
result.textEncoderDim = 4096;
|
|
|
- result.needsVAE = true;
|
|
|
|
|
|
|
+ result.needsVAE = true;
|
|
|
result.recommendedVAE = "ae.safetensors";
|
|
result.recommendedVAE = "ae.safetensors";
|
|
|
// Chroma (Flux Unlocked) requires VAE and T5XXL
|
|
// Chroma (Flux Unlocked) requires VAE and T5XXL
|
|
|
- result.suggestedParams["vae_flag"] = "--vae";
|
|
|
|
|
|
|
+ result.suggestedParams["vae_flag"] = "--vae";
|
|
|
result.suggestedParams["t5xxl_required"] = "t5xxl_fp16.safetensors";
|
|
result.suggestedParams["t5xxl_required"] = "t5xxl_fp16.safetensors";
|
|
|
- result.suggestedParams["t5xxl_flag"] = "--t5xxl";
|
|
|
|
|
|
|
+ result.suggestedParams["t5xxl_flag"] = "--t5xxl";
|
|
|
break;
|
|
break;
|
|
|
|
|
|
|
|
case ModelArchitecture::SD_3:
|
|
case ModelArchitecture::SD_3:
|
|
|
result.textEncoderDim = 4096;
|
|
result.textEncoderDim = 4096;
|
|
|
- result.needsVAE = true;
|
|
|
|
|
|
|
+ result.needsVAE = true;
|
|
|
result.recommendedVAE = "sd3_vae.safetensors";
|
|
result.recommendedVAE = "sd3_vae.safetensors";
|
|
|
// SD3 requires CLIP-L, CLIP-G, and T5XXL
|
|
// SD3 requires CLIP-L, CLIP-G, and T5XXL
|
|
|
- result.suggestedParams["vae_flag"] = "--vae";
|
|
|
|
|
|
|
+ result.suggestedParams["vae_flag"] = "--vae";
|
|
|
result.suggestedParams["clip_l_required"] = "clip_l.safetensors";
|
|
result.suggestedParams["clip_l_required"] = "clip_l.safetensors";
|
|
|
result.suggestedParams["clip_g_required"] = "clip_g.safetensors";
|
|
result.suggestedParams["clip_g_required"] = "clip_g.safetensors";
|
|
|
- result.suggestedParams["t5xxl_required"] = "t5xxl_fp16.safetensors";
|
|
|
|
|
- result.suggestedParams["clip_l_flag"] = "--clip-l";
|
|
|
|
|
- result.suggestedParams["clip_g_flag"] = "--clip-g";
|
|
|
|
|
- result.suggestedParams["t5xxl_flag"] = "--t5xxl";
|
|
|
|
|
|
|
+ result.suggestedParams["t5xxl_required"] = "t5xxl_fp16.safetensors";
|
|
|
|
|
+ result.suggestedParams["clip_l_flag"] = "--clip-l";
|
|
|
|
|
+ result.suggestedParams["clip_g_flag"] = "--clip-g";
|
|
|
|
|
+ result.suggestedParams["t5xxl_flag"] = "--t5xxl";
|
|
|
break;
|
|
break;
|
|
|
|
|
|
|
|
case ModelArchitecture::QWEN2VL:
|
|
case ModelArchitecture::QWEN2VL:
|
|
|
// Qwen2-VL requires vision and language model components
|
|
// Qwen2-VL requires vision and language model components
|
|
|
- result.suggestedParams["qwen2vl_required"] = "qwen2vl.safetensors";
|
|
|
|
|
|
|
+ result.suggestedParams["qwen2vl_required"] = "qwen2vl.safetensors";
|
|
|
result.suggestedParams["qwen2vl_vision_required"] = "qwen2vl_vision.safetensors";
|
|
result.suggestedParams["qwen2vl_vision_required"] = "qwen2vl_vision.safetensors";
|
|
|
- result.suggestedParams["qwen2vl_flag"] = "--qwen2vl";
|
|
|
|
|
- result.suggestedParams["qwen2vl_vision_flag"] = "--qwen2vl-vision";
|
|
|
|
|
|
|
+ result.suggestedParams["qwen2vl_flag"] = "--qwen2vl";
|
|
|
|
|
+ result.suggestedParams["qwen2vl_vision_flag"] = "--qwen2vl-vision";
|
|
|
break;
|
|
break;
|
|
|
|
|
|
|
|
default:
|
|
default:
|
|
@@ -139,8 +140,7 @@ ModelDetectionResult ModelDetector::detectModel(const std::string& modelPath) {
|
|
|
bool ModelDetector::parseSafetensorsHeader(
|
|
bool ModelDetector::parseSafetensorsHeader(
|
|
|
const std::string& filePath,
|
|
const std::string& filePath,
|
|
|
std::map<std::string, std::string>& metadata,
|
|
std::map<std::string, std::string>& metadata,
|
|
|
- std::map<std::string, std::vector<int64_t>>& tensorInfo
|
|
|
|
|
-) {
|
|
|
|
|
|
|
+ std::map<std::string, std::vector<int64_t>>& tensorInfo) {
|
|
|
std::ifstream file(filePath, std::ios::binary);
|
|
std::ifstream file(filePath, std::ios::binary);
|
|
|
if (!file.is_open()) {
|
|
if (!file.is_open()) {
|
|
|
return false;
|
|
return false;
|
|
@@ -179,7 +179,8 @@ bool ModelDetector::parseSafetensorsHeader(
|
|
|
|
|
|
|
|
// Extract tensor information
|
|
// Extract tensor information
|
|
|
for (auto it = headerJson.begin(); it != headerJson.end(); ++it) {
|
|
for (auto it = headerJson.begin(); it != headerJson.end(); ++it) {
|
|
|
- if (it.key() == "__metadata__") continue;
|
|
|
|
|
|
|
+ if (it.key() == "__metadata__")
|
|
|
|
|
+ continue;
|
|
|
|
|
|
|
|
if (it.value().contains("shape")) {
|
|
if (it.value().contains("shape")) {
|
|
|
std::vector<int64_t> shape;
|
|
std::vector<int64_t> shape;
|
|
@@ -199,8 +200,7 @@ bool ModelDetector::parseSafetensorsHeader(
|
|
|
ModelArchitecture ModelDetector::analyzeArchitecture(
|
|
ModelArchitecture ModelDetector::analyzeArchitecture(
|
|
|
const std::map<std::string, std::vector<int64_t>>& tensorInfo,
|
|
const std::map<std::string, std::vector<int64_t>>& tensorInfo,
|
|
|
const std::map<std::string, std::string>& metadata,
|
|
const std::map<std::string, std::string>& metadata,
|
|
|
- const std::string& filename
|
|
|
|
|
-) {
|
|
|
|
|
|
|
+ const std::string& filename) {
|
|
|
// Check metadata first for explicit architecture hints
|
|
// Check metadata first for explicit architecture hints
|
|
|
auto modelTypeIt = metadata.find("modelspec.architecture");
|
|
auto modelTypeIt = metadata.find("modelspec.architecture");
|
|
|
if (modelTypeIt != metadata.end()) {
|
|
if (modelTypeIt != metadata.end()) {
|
|
@@ -219,11 +219,11 @@ ModelArchitecture ModelDetector::analyzeArchitecture(
|
|
|
std::transform(lowerFilename.begin(), lowerFilename.end(), lowerFilename.begin(), ::tolower);
|
|
std::transform(lowerFilename.begin(), lowerFilename.end(), lowerFilename.begin(), ::tolower);
|
|
|
|
|
|
|
|
// Analyze tensor structure for architecture detection
|
|
// Analyze tensor structure for architecture detection
|
|
|
- bool hasConditioner = false;
|
|
|
|
|
- bool hasTextEncoder2 = false;
|
|
|
|
|
- bool hasFluxStructure = false;
|
|
|
|
|
- bool hasSD3Structure = false;
|
|
|
|
|
- int maxUNetChannels = 0;
|
|
|
|
|
|
|
+ bool hasConditioner = false;
|
|
|
|
|
+ bool hasTextEncoder2 = false;
|
|
|
|
|
+ bool hasFluxStructure = false;
|
|
|
|
|
+ bool hasSD3Structure = false;
|
|
|
|
|
+ int maxUNetChannels = 0;
|
|
|
int textEncoderOutputDim = 0;
|
|
int textEncoderOutputDim = 0;
|
|
|
|
|
|
|
|
for (const auto& [name, shape] : tensorInfo) {
|
|
for (const auto& [name, shape] : tensorInfo) {
|
|
@@ -313,16 +313,26 @@ ModelArchitecture ModelDetector::analyzeArchitecture(
|
|
|
|
|
|
|
|
std::string ModelDetector::getArchitectureName(ModelArchitecture arch) {
|
|
std::string ModelDetector::getArchitectureName(ModelArchitecture arch) {
|
|
|
switch (arch) {
|
|
switch (arch) {
|
|
|
- case ModelArchitecture::SD_1_5: return "Stable Diffusion 1.5";
|
|
|
|
|
- case ModelArchitecture::SD_2_1: return "Stable Diffusion 2.1";
|
|
|
|
|
- case ModelArchitecture::SDXL_BASE: return "Stable Diffusion XL Base";
|
|
|
|
|
- case ModelArchitecture::SDXL_REFINER: return "Stable Diffusion XL Refiner";
|
|
|
|
|
- case ModelArchitecture::FLUX_SCHNELL: return "Flux Schnell";
|
|
|
|
|
- case ModelArchitecture::FLUX_DEV: return "Flux Dev";
|
|
|
|
|
- case ModelArchitecture::FLUX_CHROMA: return "Flux Chroma (Unlocked)";
|
|
|
|
|
- case ModelArchitecture::SD_3: return "Stable Diffusion 3";
|
|
|
|
|
- case ModelArchitecture::QWEN2VL: return "Qwen2-VL";
|
|
|
|
|
- default: return "Unknown";
|
|
|
|
|
|
|
+ case ModelArchitecture::SD_1_5:
|
|
|
|
|
+ return "Stable Diffusion 1.5";
|
|
|
|
|
+ case ModelArchitecture::SD_2_1:
|
|
|
|
|
+ return "Stable Diffusion 2.1";
|
|
|
|
|
+ case ModelArchitecture::SDXL_BASE:
|
|
|
|
|
+ return "Stable Diffusion XL Base";
|
|
|
|
|
+ case ModelArchitecture::SDXL_REFINER:
|
|
|
|
|
+ return "Stable Diffusion XL Refiner";
|
|
|
|
|
+ case ModelArchitecture::FLUX_SCHNELL:
|
|
|
|
|
+ return "Flux Schnell";
|
|
|
|
|
+ case ModelArchitecture::FLUX_DEV:
|
|
|
|
|
+ return "Flux Dev";
|
|
|
|
|
+ case ModelArchitecture::FLUX_CHROMA:
|
|
|
|
|
+ return "Flux Chroma (Unlocked)";
|
|
|
|
|
+ case ModelArchitecture::SD_3:
|
|
|
|
|
+ return "Stable Diffusion 3";
|
|
|
|
|
+ case ModelArchitecture::QWEN2VL:
|
|
|
|
|
+ return "Qwen2-VL";
|
|
|
|
|
+ default:
|
|
|
|
|
+ return "Unknown";
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
|
|
|
|
@@ -331,60 +341,60 @@ std::map<std::string, std::string> ModelDetector::getRecommendedParams(ModelArch
|
|
|
|
|
|
|
|
switch (arch) {
|
|
switch (arch) {
|
|
|
case ModelArchitecture::SD_1_5:
|
|
case ModelArchitecture::SD_1_5:
|
|
|
- params["width"] = "512";
|
|
|
|
|
- params["height"] = "512";
|
|
|
|
|
|
|
+ params["width"] = "512";
|
|
|
|
|
+ params["height"] = "512";
|
|
|
params["cfg_scale"] = "7.5";
|
|
params["cfg_scale"] = "7.5";
|
|
|
- params["steps"] = "20";
|
|
|
|
|
- params["sampler"] = "euler_a";
|
|
|
|
|
|
|
+ params["steps"] = "20";
|
|
|
|
|
+ params["sampler"] = "euler_a";
|
|
|
break;
|
|
break;
|
|
|
|
|
|
|
|
case ModelArchitecture::SD_2_1:
|
|
case ModelArchitecture::SD_2_1:
|
|
|
- params["width"] = "768";
|
|
|
|
|
- params["height"] = "768";
|
|
|
|
|
|
|
+ params["width"] = "768";
|
|
|
|
|
+ params["height"] = "768";
|
|
|
params["cfg_scale"] = "7.0";
|
|
params["cfg_scale"] = "7.0";
|
|
|
- params["steps"] = "25";
|
|
|
|
|
- params["sampler"] = "euler_a";
|
|
|
|
|
|
|
+ params["steps"] = "25";
|
|
|
|
|
+ params["sampler"] = "euler_a";
|
|
|
break;
|
|
break;
|
|
|
|
|
|
|
|
case ModelArchitecture::SDXL_BASE:
|
|
case ModelArchitecture::SDXL_BASE:
|
|
|
case ModelArchitecture::SDXL_REFINER:
|
|
case ModelArchitecture::SDXL_REFINER:
|
|
|
- params["width"] = "1024";
|
|
|
|
|
- params["height"] = "1024";
|
|
|
|
|
|
|
+ params["width"] = "1024";
|
|
|
|
|
+ params["height"] = "1024";
|
|
|
params["cfg_scale"] = "7.0";
|
|
params["cfg_scale"] = "7.0";
|
|
|
- params["steps"] = "30";
|
|
|
|
|
- params["sampler"] = "dpm++2m";
|
|
|
|
|
|
|
+ params["steps"] = "30";
|
|
|
|
|
+ params["sampler"] = "dpm++2m";
|
|
|
break;
|
|
break;
|
|
|
|
|
|
|
|
case ModelArchitecture::FLUX_SCHNELL:
|
|
case ModelArchitecture::FLUX_SCHNELL:
|
|
|
- params["width"] = "1024";
|
|
|
|
|
- params["height"] = "1024";
|
|
|
|
|
|
|
+ params["width"] = "1024";
|
|
|
|
|
+ params["height"] = "1024";
|
|
|
params["cfg_scale"] = "1.0";
|
|
params["cfg_scale"] = "1.0";
|
|
|
- params["steps"] = "4";
|
|
|
|
|
- params["sampler"] = "euler";
|
|
|
|
|
|
|
+ params["steps"] = "4";
|
|
|
|
|
+ params["sampler"] = "euler";
|
|
|
break;
|
|
break;
|
|
|
|
|
|
|
|
case ModelArchitecture::FLUX_DEV:
|
|
case ModelArchitecture::FLUX_DEV:
|
|
|
- params["width"] = "1024";
|
|
|
|
|
- params["height"] = "1024";
|
|
|
|
|
|
|
+ params["width"] = "1024";
|
|
|
|
|
+ params["height"] = "1024";
|
|
|
params["cfg_scale"] = "1.0";
|
|
params["cfg_scale"] = "1.0";
|
|
|
- params["steps"] = "20";
|
|
|
|
|
- params["sampler"] = "euler";
|
|
|
|
|
|
|
+ params["steps"] = "20";
|
|
|
|
|
+ params["sampler"] = "euler";
|
|
|
break;
|
|
break;
|
|
|
|
|
|
|
|
case ModelArchitecture::FLUX_CHROMA:
|
|
case ModelArchitecture::FLUX_CHROMA:
|
|
|
- params["width"] = "1024";
|
|
|
|
|
- params["height"] = "1024";
|
|
|
|
|
|
|
+ params["width"] = "1024";
|
|
|
|
|
+ params["height"] = "1024";
|
|
|
params["cfg_scale"] = "1.0";
|
|
params["cfg_scale"] = "1.0";
|
|
|
- params["steps"] = "20";
|
|
|
|
|
- params["sampler"] = "euler";
|
|
|
|
|
|
|
+ params["steps"] = "20";
|
|
|
|
|
+ params["sampler"] = "euler";
|
|
|
break;
|
|
break;
|
|
|
|
|
|
|
|
case ModelArchitecture::SD_3:
|
|
case ModelArchitecture::SD_3:
|
|
|
- params["width"] = "1024";
|
|
|
|
|
- params["height"] = "1024";
|
|
|
|
|
|
|
+ params["width"] = "1024";
|
|
|
|
|
+ params["height"] = "1024";
|
|
|
params["cfg_scale"] = "5.0";
|
|
params["cfg_scale"] = "5.0";
|
|
|
- params["steps"] = "28";
|
|
|
|
|
- params["sampler"] = "dpm++2m";
|
|
|
|
|
|
|
+ params["steps"] = "28";
|
|
|
|
|
+ params["sampler"] = "dpm++2m";
|
|
|
break;
|
|
break;
|
|
|
|
|
|
|
|
default:
|
|
default:
|
|
@@ -397,8 +407,7 @@ std::map<std::string, std::string> ModelDetector::getRecommendedParams(ModelArch
|
|
|
bool ModelDetector::parseGGUFHeader(
|
|
bool ModelDetector::parseGGUFHeader(
|
|
|
const std::string& filePath,
|
|
const std::string& filePath,
|
|
|
std::map<std::string, std::string>& metadata,
|
|
std::map<std::string, std::string>& metadata,
|
|
|
- std::map<std::string, std::vector<int64_t>>& tensorInfo
|
|
|
|
|
-) {
|
|
|
|
|
|
|
+ std::map<std::string, std::vector<int64_t>>& tensorInfo) {
|
|
|
std::ifstream file(filePath, std::ios::binary);
|
|
std::ifstream file(filePath, std::ios::binary);
|
|
|
if (!file.is_open()) {
|
|
if (!file.is_open()) {
|
|
|
return false;
|
|
return false;
|
|
@@ -450,44 +459,46 @@ bool ModelDetector::parseGGUFHeader(
|
|
|
// Read metadata key-value pairs
|
|
// Read metadata key-value pairs
|
|
|
for (uint64_t i = 0; i < metadataCount && file.good(); ++i) {
|
|
for (uint64_t i = 0; i < metadataCount && file.good(); ++i) {
|
|
|
std::string key = readString();
|
|
std::string key = readString();
|
|
|
- if (key.empty()) break;
|
|
|
|
|
|
|
+ if (key.empty())
|
|
|
|
|
+ break;
|
|
|
|
|
|
|
|
// Read value type (uint32)
|
|
// Read value type (uint32)
|
|
|
uint32_t valueType;
|
|
uint32_t valueType;
|
|
|
file.read(reinterpret_cast<char*>(&valueType), 4);
|
|
file.read(reinterpret_cast<char*>(&valueType), 4);
|
|
|
- if (file.gcount() != 4) break;
|
|
|
|
|
|
|
+ if (file.gcount() != 4)
|
|
|
|
|
+ break;
|
|
|
|
|
|
|
|
// Parse value based on type
|
|
// Parse value based on type
|
|
|
std::string value;
|
|
std::string value;
|
|
|
switch (valueType) {
|
|
switch (valueType) {
|
|
|
- case 8: // String
|
|
|
|
|
|
|
+ case 8: // String
|
|
|
value = readString();
|
|
value = readString();
|
|
|
break;
|
|
break;
|
|
|
- case 4: { // Uint32
|
|
|
|
|
|
|
+ case 4: { // Uint32
|
|
|
uint32_t val;
|
|
uint32_t val;
|
|
|
file.read(reinterpret_cast<char*>(&val), 4);
|
|
file.read(reinterpret_cast<char*>(&val), 4);
|
|
|
value = std::to_string(val);
|
|
value = std::to_string(val);
|
|
|
break;
|
|
break;
|
|
|
}
|
|
}
|
|
|
- case 5: { // Int32
|
|
|
|
|
|
|
+ case 5: { // Int32
|
|
|
int32_t val;
|
|
int32_t val;
|
|
|
file.read(reinterpret_cast<char*>(&val), 4);
|
|
file.read(reinterpret_cast<char*>(&val), 4);
|
|
|
value = std::to_string(val);
|
|
value = std::to_string(val);
|
|
|
break;
|
|
break;
|
|
|
}
|
|
}
|
|
|
- case 6: { // Float32
|
|
|
|
|
|
|
+ case 6: { // Float32
|
|
|
float val;
|
|
float val;
|
|
|
file.read(reinterpret_cast<char*>(&val), 4);
|
|
file.read(reinterpret_cast<char*>(&val), 4);
|
|
|
value = std::to_string(val);
|
|
value = std::to_string(val);
|
|
|
break;
|
|
break;
|
|
|
}
|
|
}
|
|
|
- case 0: { // Uint8
|
|
|
|
|
|
|
+ case 0: { // Uint8
|
|
|
uint8_t val;
|
|
uint8_t val;
|
|
|
file.read(reinterpret_cast<char*>(&val), 1);
|
|
file.read(reinterpret_cast<char*>(&val), 1);
|
|
|
value = std::to_string(val);
|
|
value = std::to_string(val);
|
|
|
break;
|
|
break;
|
|
|
}
|
|
}
|
|
|
- case 1: { // Int8
|
|
|
|
|
|
|
+ case 1: { // Int8
|
|
|
int8_t val;
|
|
int8_t val;
|
|
|
file.read(reinterpret_cast<char*>(&val), 1);
|
|
file.read(reinterpret_cast<char*>(&val), 1);
|
|
|
value = std::to_string(val);
|
|
value = std::to_string(val);
|
|
@@ -507,19 +518,22 @@ bool ModelDetector::parseGGUFHeader(
|
|
|
// Read tensor information
|
|
// Read tensor information
|
|
|
for (uint64_t i = 0; i < tensorCount && file.good(); ++i) {
|
|
for (uint64_t i = 0; i < tensorCount && file.good(); ++i) {
|
|
|
std::string tensorName = readString();
|
|
std::string tensorName = readString();
|
|
|
- if (tensorName.empty()) break;
|
|
|
|
|
|
|
+ if (tensorName.empty())
|
|
|
|
|
+ break;
|
|
|
|
|
|
|
|
// Read number of dimensions (uint32)
|
|
// Read number of dimensions (uint32)
|
|
|
uint32_t nDims;
|
|
uint32_t nDims;
|
|
|
file.read(reinterpret_cast<char*>(&nDims), 4);
|
|
file.read(reinterpret_cast<char*>(&nDims), 4);
|
|
|
- if (file.gcount() != 4 || nDims > 10) break;
|
|
|
|
|
|
|
+ if (file.gcount() != 4 || nDims > 10)
|
|
|
|
|
+ break;
|
|
|
|
|
|
|
|
// Read dimensions (uint64 array)
|
|
// Read dimensions (uint64 array)
|
|
|
std::vector<int64_t> shape(nDims);
|
|
std::vector<int64_t> shape(nDims);
|
|
|
for (uint32_t d = 0; d < nDims; ++d) {
|
|
for (uint32_t d = 0; d < nDims; ++d) {
|
|
|
uint64_t dim;
|
|
uint64_t dim;
|
|
|
file.read(reinterpret_cast<char*>(&dim), 8);
|
|
file.read(reinterpret_cast<char*>(&dim), 8);
|
|
|
- if (file.gcount() != 8) break;
|
|
|
|
|
|
|
+ if (file.gcount() != 8)
|
|
|
|
|
+ break;
|
|
|
shape[d] = static_cast<int64_t>(dim);
|
|
shape[d] = static_cast<int64_t>(dim);
|
|
|
}
|
|
}
|
|
|
|
|
|