// model_detector.cpp — detects diffusion-model architecture (SD 1.5/2.1/XL, SD3, Flux, Qwen2-VL)
// by inspecting safetensors/GGUF headers without loading weights.
#include "model_detector.h"

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <fstream>
#include <iostream>

#include <nlohmann/json.hpp>
  7. // Helper function for C++17 compatibility (ends_with is C++20)
  8. static bool endsWith(const std::string& str, const std::string& suffix) {
  9. if (suffix.size() > str.size())
  10. return false;
  11. return str.compare(str.size() - suffix.size(), suffix.size(), suffix) == 0;
  12. }
  13. ModelDetectionResult ModelDetector::detectModel(const std::string& modelPath) {
  14. ModelDetectionResult result;
  15. std::map<std::string, std::vector<int64_t>> tensorInfo;
  16. // Determine file type and parse accordingly
  17. bool parsed = false;
  18. if (endsWith(modelPath, ".safetensors")) {
  19. parsed = parseSafetensorsHeader(modelPath, result.metadata, tensorInfo);
  20. } else if (endsWith(modelPath, ".gguf")) {
  21. parsed = parseGGUFHeader(modelPath, result.metadata, tensorInfo);
  22. } else if (endsWith(modelPath, ".ckpt") || endsWith(modelPath, ".pt")) {
  23. // PyTorch pickle files - these require the full PyTorch library to parse safely
  24. // For now, we cannot detect their architecture without loading the model
  25. // Return unknown architecture with a note in metadata
  26. result.metadata["format"] = "pytorch_pickle";
  27. result.metadata["note"] = "Architecture detection not supported for .ckpt/.pt files";
  28. return result;
  29. }
  30. if (!parsed) {
  31. return result; // Unknown if we can't parse
  32. }
  33. // Store tensor names for reference
  34. for (const auto& [name, _] : tensorInfo) {
  35. result.tensorNames.push_back(name);
  36. }
  37. // Analyze architecture (pass filename for special detection)
  38. std::string filename = modelPath.substr(modelPath.find_last_of("/\\") + 1);
  39. result.architecture = analyzeArchitecture(tensorInfo, result.metadata, filename);
  40. result.architectureName = getArchitectureName(result.architecture);
  41. // Set architecture-specific properties and required models
  42. switch (result.architecture) {
  43. case ModelArchitecture::SD_1_5:
  44. result.textEncoderDim = 512;
  45. result.unetChannels = 512;
  46. result.needsVAE = false;
  47. result.recommendedVAE = "vae-ft-mse-840000-ema-pruned.safetensors";
  48. result.needsTAESD = true;
  49. result.suggestedParams["vae_flag"] = "--vae";
  50. break;
  51. case ModelArchitecture::SD_2_1:
  52. result.textEncoderDim = 1024;
  53. result.unetChannels = 1024;
  54. result.needsVAE = false;
  55. result.recommendedVAE = "vae-ft-ema-560000.safetensors";
  56. result.needsTAESD = true;
  57. result.suggestedParams["vae_flag"] = "--vae";
  58. break;
  59. case ModelArchitecture::SDXL_BASE:
  60. case ModelArchitecture::SDXL_REFINER:
  61. result.textEncoderDim = 1024;
  62. result.unetChannels = 1024;
  63. result.hasConditioner = true;
  64. result.needsVAE = false;
  65. result.recommendedVAE = "sdxl_vae.safetensors";
  66. result.needsTAESD = true;
  67. result.suggestedParams["vae_flag"] = "--vae";
  68. break;
  69. case ModelArchitecture::FLUX_SCHNELL:
  70. case ModelArchitecture::FLUX_DEV:
  71. result.textEncoderDim = 4096;
  72. result.needsVAE = true;
  73. result.recommendedVAE = "ae.safetensors";
  74. // Flux requires CLIP-L and T5XXL
  75. result.suggestedParams["vae_flag"] = "--vae";
  76. result.suggestedParams["clip_l_required"] = "clip_l.safetensors";
  77. result.suggestedParams["t5xxl_required"] = "t5xxl_fp16.safetensors";
  78. result.suggestedParams["clip_l_flag"] = "--clip-l";
  79. result.suggestedParams["t5xxl_flag"] = "--t5xxl";
  80. break;
  81. case ModelArchitecture::FLUX_CHROMA:
  82. result.textEncoderDim = 4096;
  83. result.needsVAE = true;
  84. result.recommendedVAE = "ae.safetensors";
  85. // Chroma (Flux Unlocked) requires VAE and T5XXL
  86. result.suggestedParams["vae_flag"] = "--vae";
  87. result.suggestedParams["t5xxl_required"] = "t5xxl_fp16.safetensors";
  88. result.suggestedParams["t5xxl_flag"] = "--t5xxl";
  89. break;
  90. case ModelArchitecture::SD_3:
  91. result.textEncoderDim = 4096;
  92. result.needsVAE = true;
  93. result.recommendedVAE = "sd3_vae.safetensors";
  94. // SD3 requires CLIP-L, CLIP-G, and T5XXL
  95. result.suggestedParams["vae_flag"] = "--vae";
  96. result.suggestedParams["clip_l_required"] = "clip_l.safetensors";
  97. result.suggestedParams["clip_g_required"] = "clip_g.safetensors";
  98. result.suggestedParams["t5xxl_required"] = "t5xxl_fp16.safetensors";
  99. result.suggestedParams["clip_l_flag"] = "--clip-l";
  100. result.suggestedParams["clip_g_flag"] = "--clip-g";
  101. result.suggestedParams["t5xxl_flag"] = "--t5xxl";
  102. break;
  103. case ModelArchitecture::QWEN2VL:
  104. // Qwen2-VL requires vision and language model components
  105. result.suggestedParams["qwen2vl_required"] = "qwen2vl.safetensors";
  106. result.suggestedParams["qwen2vl_vision_required"] = "qwen2vl_vision.safetensors";
  107. result.suggestedParams["qwen2vl_flag"] = "--qwen2vl";
  108. result.suggestedParams["qwen2vl_vision_flag"] = "--qwen2vl-vision";
  109. break;
  110. default:
  111. break;
  112. }
  113. // Merge with general recommended parameters (width, height, steps, etc.)
  114. auto generalParams = getRecommendedParams(result.architecture);
  115. for (const auto& [key, value] : generalParams) {
  116. // Only add if not already set (preserve architecture-specific flags)
  117. if (result.suggestedParams.find(key) == result.suggestedParams.end()) {
  118. result.suggestedParams[key] = value;
  119. }
  120. }
  121. return result;
  122. }
  123. bool ModelDetector::parseSafetensorsHeader(
  124. const std::string& filePath,
  125. std::map<std::string, std::string>& metadata,
  126. std::map<std::string, std::vector<int64_t>>& tensorInfo) {
  127. std::ifstream file(filePath, std::ios::binary);
  128. if (!file.is_open()) {
  129. return false;
  130. }
  131. // Read header length (first 8 bytes, little-endian uint64)
  132. uint64_t headerLength = 0;
  133. file.read(reinterpret_cast<char*>(&headerLength), 8);
  134. if (file.gcount() != 8) {
  135. return false;
  136. }
  137. // Sanity check: header should be reasonable size (< 100MB)
  138. if (headerLength == 0 || headerLength > 100 * 1024 * 1024) {
  139. return false;
  140. }
  141. // Read header JSON
  142. std::vector<char> headerBuffer(headerLength);
  143. file.read(headerBuffer.data(), headerLength);
  144. if (file.gcount() != static_cast<std::streamsize>(headerLength)) {
  145. return false;
  146. }
  147. // Parse JSON
  148. try {
  149. nlohmann::json headerJson = nlohmann::json::parse(headerBuffer.begin(), headerBuffer.end());
  150. // Extract metadata if present
  151. if (headerJson.contains("__metadata__")) {
  152. auto metadataJson = headerJson["__metadata__"];
  153. for (auto it = metadataJson.begin(); it != metadataJson.end(); ++it) {
  154. metadata[it.key()] = it.value().get<std::string>();
  155. }
  156. }
  157. // Extract tensor information
  158. for (auto it = headerJson.begin(); it != headerJson.end(); ++it) {
  159. if (it.key() == "__metadata__")
  160. continue;
  161. if (it.value().contains("shape")) {
  162. std::vector<int64_t> shape;
  163. for (const auto& dim : it.value()["shape"]) {
  164. shape.push_back(dim.get<int64_t>());
  165. }
  166. tensorInfo[it.key()] = shape;
  167. }
  168. }
  169. return true;
  170. } catch (const std::exception& e) {
  171. return false;
  172. }
  173. }
  174. ModelArchitecture ModelDetector::analyzeArchitecture(
  175. const std::map<std::string, std::vector<int64_t>>& tensorInfo,
  176. const std::map<std::string, std::string>& metadata,
  177. const std::string& filename) {
  178. // Check metadata first for explicit architecture hints
  179. auto modelTypeIt = metadata.find("modelspec.architecture");
  180. if (modelTypeIt != metadata.end()) {
  181. const std::string& archName = modelTypeIt->second;
  182. if (archName.find("stable-diffusion-xl") != std::string::npos) {
  183. return ModelArchitecture::SDXL_BASE;
  184. } else if (archName.find("stable-diffusion-v2") != std::string::npos) {
  185. return ModelArchitecture::SD_2_1;
  186. } else if (archName.find("stable-diffusion-v1") != std::string::npos) {
  187. return ModelArchitecture::SD_1_5;
  188. }
  189. }
  190. // Check filename for special variants
  191. std::string lowerFilename = filename;
  192. std::transform(lowerFilename.begin(), lowerFilename.end(), lowerFilename.begin(), ::tolower);
  193. // Analyze tensor structure for architecture detection
  194. bool hasConditioner = false;
  195. bool hasTextEncoder2 = false;
  196. bool hasFluxStructure = false;
  197. bool hasSD3Structure = false;
  198. int maxUNetChannels = 0;
  199. int textEncoderOutputDim = 0;
  200. for (const auto& [name, shape] : tensorInfo) {
  201. // Check for SDXL-specific components
  202. if (name.find("conditioner") != std::string::npos) {
  203. hasConditioner = true;
  204. }
  205. if (name.find("text_encoder_2") != std::string::npos ||
  206. name.find("cond_stage_model.1") != std::string::npos) {
  207. hasTextEncoder2 = true;
  208. }
  209. // Check for Flux-specific patterns
  210. if (name.find("double_blocks") != std::string::npos ||
  211. name.find("single_blocks") != std::string::npos) {
  212. hasFluxStructure = true;
  213. }
  214. // Check for SD3-specific patterns
  215. if (name.find("joint_blocks") != std::string::npos) {
  216. hasSD3Structure = true;
  217. }
  218. // Analyze UNet structure
  219. if (name.find("model.diffusion_model") != std::string::npos ||
  220. name.find("unet") != std::string::npos) {
  221. if (shape.size() >= 2) {
  222. maxUNetChannels = std::max(maxUNetChannels, static_cast<int>(shape[0]));
  223. }
  224. }
  225. // Check text encoder dimensions
  226. if (name.find("cond_stage_model") != std::string::npos ||
  227. name.find("text_encoder") != std::string::npos) {
  228. if (name.find("proj") != std::string::npos && shape.size() >= 2) {
  229. textEncoderOutputDim = std::max(textEncoderOutputDim, static_cast<int>(shape[1]));
  230. }
  231. }
  232. }
  233. // Determine architecture based on analysis with improved priority
  234. if (hasFluxStructure) {
  235. // Check for Chroma variant (unlocked Flux)
  236. if (lowerFilename.find("chroma") != std::string::npos) {
  237. return ModelArchitecture::FLUX_CHROMA;
  238. }
  239. // Check if it's Schnell or Dev based on step count hints
  240. auto stepsIt = metadata.find("diffusion_steps");
  241. if (stepsIt != metadata.end() && stepsIt->second.find("4") != std::string::npos) {
  242. return ModelArchitecture::FLUX_SCHNELL;
  243. }
  244. return ModelArchitecture::FLUX_DEV;
  245. }
  246. if (hasSD3Structure) {
  247. return ModelArchitecture::SD_3;
  248. }
  249. if (hasConditioner || hasTextEncoder2) {
  250. // SDXL architecture - check for refiner using multiple criteria
  251. bool hasRefinerMarkers = false;
  252. bool hasSmallUNet = false;
  253. // Check for refiner markers in tensor names
  254. for (const auto& [name, _] : tensorInfo) {
  255. if (name.find("refiner") != std::string::npos) {
  256. hasRefinerMarkers = true;
  257. break;
  258. }
  259. }
  260. // Check for smaller UNet channel counts (typical of refiner models)
  261. if (maxUNetChannels > 0 && maxUNetChannels < 2400) {
  262. hasSmallUNet = true;
  263. }
  264. // Additional check: look for refiner-specific metadata
  265. auto refinerIt = metadata.find("refiner");
  266. if (refinerIt != metadata.end() && refinerIt->second == "true") {
  267. hasRefinerMarkers = true;
  268. }
  269. // Return refiner if either marker is found, otherwise base
  270. if (hasRefinerMarkers || hasSmallUNet) {
  271. return ModelArchitecture::SDXL_REFINER;
  272. }
  273. return ModelArchitecture::SDXL_BASE;
  274. }
  275. // Check for Qwen2-VL specific patterns before falling back to dimension-based detection
  276. bool hasQwenPatterns = false;
  277. // Check metadata for Qwen pipeline class
  278. auto pipelineIt = metadata.find("_model_name");
  279. if (pipelineIt != metadata.end() && pipelineIt->second.find("QwenImagePipeline") != std::string::npos) {
  280. hasQwenPatterns = true;
  281. }
  282. // Check for Qwen-specific tensor patterns
  283. bool hasTransformerBlocks = false;
  284. bool hasImgMod = false;
  285. bool hasTxtMod = false;
  286. bool hasImgIn = false;
  287. bool hasTxtIn = false;
  288. bool hasProjOut = false;
  289. bool hasVisualBlocks = false;
  290. for (const auto& [name, shape] : tensorInfo) {
  291. // Check for transformer blocks
  292. if (name.find("transformer_blocks") != std::string::npos) {
  293. hasTransformerBlocks = true;
  294. }
  295. // Check for modulation patterns
  296. if (name.find("img_mod") != std::string::npos) {
  297. hasImgMod = true;
  298. }
  299. if (name.find("txt_mod") != std::string::npos) {
  300. hasTxtMod = true;
  301. }
  302. // Check for input patterns
  303. if (name.find("img_in") != std::string::npos) {
  304. hasImgIn = true;
  305. }
  306. if (name.find("txt_in") != std::string::npos) {
  307. hasTxtIn = true;
  308. }
  309. // Check for output projection
  310. if (name.find("proj_out") != std::string::npos) {
  311. hasProjOut = true;
  312. }
  313. // Check for visual blocks (Qwen2-VL structure)
  314. if (name.find("visual.blocks") != std::string::npos) {
  315. hasVisualBlocks = true;
  316. }
  317. }
  318. // Determine if this is a Qwen model based on multiple patterns
  319. if (hasTransformerBlocks && (hasImgMod || hasTxtMod) && (hasImgIn || hasTxtIn) && hasProjOut) {
  320. hasQwenPatterns = true;
  321. }
  322. // Additional check for visual blocks pattern
  323. if (hasVisualBlocks && (hasImgMod || hasTxtMod)) {
  324. hasQwenPatterns = true;
  325. }
  326. if (hasQwenPatterns) {
  327. return ModelArchitecture::QWEN2VL;
  328. }
  329. // Improved detection priority order
  330. // First, check for Flux-specific patterns even if text encoder dimension is 1280
  331. if (hasFluxStructure) {
  332. // This should have been caught earlier, but double-check for edge cases
  333. if (lowerFilename.find("chroma") != std::string::npos) {
  334. return ModelArchitecture::FLUX_CHROMA;
  335. }
  336. auto stepsIt = metadata.find("diffusion_steps");
  337. if (stepsIt != metadata.end() && stepsIt->second.find("4") != std::string::npos) {
  338. return ModelArchitecture::FLUX_SCHNELL;
  339. }
  340. return ModelArchitecture::FLUX_DEV;
  341. }
  342. // Check text encoder dimensions with enhanced logic for 1280 dimension
  343. if (textEncoderOutputDim == 768) {
  344. return ModelArchitecture::SD_1_5;
  345. }
  346. if (textEncoderOutputDim >= 1024 && textEncoderOutputDim < 1280) {
  347. return ModelArchitecture::SD_2_1;
  348. }
  349. if (textEncoderOutputDim == 1280) {
  350. // Enhanced 1280 dimension detection: distinguish between SDXL Base, SDXL Refiner, and Flux
  351. // Check if we already determined this is Flux (should have been caught earlier)
  352. if (hasFluxStructure) {
  353. if (lowerFilename.find("chroma") != std::string::npos) {
  354. return ModelArchitecture::FLUX_CHROMA;
  355. }
  356. auto stepsIt = metadata.find("diffusion_steps");
  357. if (stepsIt != metadata.end() && stepsIt->second.find("4") != std::string::npos) {
  358. return ModelArchitecture::FLUX_SCHNELL;
  359. }
  360. return ModelArchitecture::FLUX_DEV;
  361. }
  362. // Check for SDXL Refiner indicators
  363. bool hasRefinerMarkers = false;
  364. bool hasSmallUNet = false;
  365. for (const auto& [name, _] : tensorInfo) {
  366. if (name.find("refiner") != std::string::npos) {
  367. hasRefinerMarkers = true;
  368. break;
  369. }
  370. }
  371. if (maxUNetChannels > 0 && maxUNetChannels < 2400) {
  372. hasSmallUNet = true;
  373. }
  374. auto refinerIt = metadata.find("refiner");
  375. if (refinerIt != metadata.end() && refinerIt->second == "true") {
  376. hasRefinerMarkers = true;
  377. }
  378. if (hasRefinerMarkers || hasSmallUNet) {
  379. return ModelArchitecture::SDXL_REFINER;
  380. }
  381. // Default to SDXL Base for 1280 dimension
  382. return ModelArchitecture::SDXL_BASE;
  383. }
  384. // Only use UNet channel count as a last resort when text encoder dimensions are unclear
  385. if (maxUNetChannels >= 2048) {
  386. return ModelArchitecture::SDXL_BASE;
  387. }
  388. // Fallback detection based on UNet channels when text encoder info is unavailable
  389. if (maxUNetChannels == 1280) {
  390. return ModelArchitecture::SD_2_1;
  391. }
  392. if (maxUNetChannels <= 1280) {
  393. return ModelArchitecture::SD_1_5;
  394. }
  395. return ModelArchitecture::UNKNOWN;
  396. }
  397. std::string ModelDetector::getArchitectureName(ModelArchitecture arch) {
  398. switch (arch) {
  399. case ModelArchitecture::SD_1_5:
  400. return "Stable Diffusion 1.5";
  401. case ModelArchitecture::SD_2_1:
  402. return "Stable Diffusion 2.1";
  403. case ModelArchitecture::SDXL_BASE:
  404. return "Stable Diffusion XL Base";
  405. case ModelArchitecture::SDXL_REFINER:
  406. return "Stable Diffusion XL Refiner";
  407. case ModelArchitecture::FLUX_SCHNELL:
  408. return "Flux Schnell";
  409. case ModelArchitecture::FLUX_DEV:
  410. return "Flux Dev";
  411. case ModelArchitecture::FLUX_CHROMA:
  412. return "Flux Chroma (Unlocked)";
  413. case ModelArchitecture::SD_3:
  414. return "Stable Diffusion 3";
  415. case ModelArchitecture::QWEN2VL:
  416. return "Qwen2-VL";
  417. default:
  418. return "Unknown";
  419. }
  420. }
  421. std::map<std::string, std::string> ModelDetector::getRecommendedParams(ModelArchitecture arch) {
  422. std::map<std::string, std::string> params;
  423. switch (arch) {
  424. case ModelArchitecture::SD_1_5:
  425. params["width"] = "512";
  426. params["height"] = "512";
  427. params["cfg_scale"] = "7.5";
  428. params["steps"] = "20";
  429. params["sampler"] = "euler_a";
  430. break;
  431. case ModelArchitecture::SD_2_1:
  432. params["width"] = "768";
  433. params["height"] = "768";
  434. params["cfg_scale"] = "7.0";
  435. params["steps"] = "25";
  436. params["sampler"] = "euler_a";
  437. break;
  438. case ModelArchitecture::SDXL_BASE:
  439. case ModelArchitecture::SDXL_REFINER:
  440. params["width"] = "1024";
  441. params["height"] = "1024";
  442. params["cfg_scale"] = "7.0";
  443. params["steps"] = "30";
  444. params["sampler"] = "dpm++2m";
  445. break;
  446. case ModelArchitecture::FLUX_SCHNELL:
  447. params["width"] = "1024";
  448. params["height"] = "1024";
  449. params["cfg_scale"] = "1.0";
  450. params["steps"] = "4";
  451. params["sampler"] = "euler";
  452. break;
  453. case ModelArchitecture::FLUX_DEV:
  454. params["width"] = "1024";
  455. params["height"] = "1024";
  456. params["cfg_scale"] = "1.0";
  457. params["steps"] = "20";
  458. params["sampler"] = "euler";
  459. break;
  460. case ModelArchitecture::FLUX_CHROMA:
  461. params["width"] = "1024";
  462. params["height"] = "1024";
  463. params["cfg_scale"] = "1.0";
  464. params["steps"] = "20";
  465. params["sampler"] = "euler";
  466. break;
  467. case ModelArchitecture::SD_3:
  468. params["width"] = "1024";
  469. params["height"] = "1024";
  470. params["cfg_scale"] = "5.0";
  471. params["steps"] = "28";
  472. params["sampler"] = "dpm++2m";
  473. break;
  474. default:
  475. break;
  476. }
  477. return params;
  478. }
  479. bool ModelDetector::parseGGUFHeader(
  480. const std::string& filePath,
  481. std::map<std::string, std::string>& metadata,
  482. std::map<std::string, std::vector<int64_t>>& tensorInfo) {
  483. std::ifstream file(filePath, std::ios::binary);
  484. if (!file.is_open()) {
  485. return false;
  486. }
  487. // Read and verify magic number "GGUF"
  488. char magic[4];
  489. file.read(magic, 4);
  490. if (file.gcount() != 4 || std::memcmp(magic, "GGUF", 4) != 0) {
  491. return false;
  492. }
  493. // Read version (uint32)
  494. uint32_t version;
  495. file.read(reinterpret_cast<char*>(&version), 4);
  496. if (file.gcount() != 4) {
  497. return false;
  498. }
  499. // Read tensor count (uint64)
  500. uint64_t tensorCount;
  501. file.read(reinterpret_cast<char*>(&tensorCount), 8);
  502. if (file.gcount() != 8) {
  503. return false;
  504. }
  505. // Read metadata KV count (uint64)
  506. uint64_t metadataCount;
  507. file.read(reinterpret_cast<char*>(&metadataCount), 8);
  508. if (file.gcount() != 8) {
  509. return false;
  510. }
  511. // Helper function to read string
  512. auto readString = [&file]() -> std::string {
  513. uint64_t length;
  514. file.read(reinterpret_cast<char*>(&length), 8);
  515. if (file.gcount() != 8 || length == 0 || length > 10000) {
  516. return "";
  517. }
  518. std::vector<char> buffer(length);
  519. file.read(buffer.data(), length);
  520. if (file.gcount() != static_cast<std::streamsize>(length)) {
  521. return "";
  522. }
  523. return std::string(buffer.begin(), buffer.end());
  524. };
  525. // Read metadata key-value pairs
  526. for (uint64_t i = 0; i < metadataCount && file.good(); ++i) {
  527. std::string key = readString();
  528. if (key.empty())
  529. break;
  530. // Read value type (uint32)
  531. uint32_t valueType;
  532. file.read(reinterpret_cast<char*>(&valueType), 4);
  533. if (file.gcount() != 4)
  534. break;
  535. // Parse value based on type
  536. std::string value;
  537. switch (valueType) {
  538. case 8: // String
  539. value = readString();
  540. break;
  541. case 4: { // Uint32
  542. uint32_t val;
  543. file.read(reinterpret_cast<char*>(&val), 4);
  544. value = std::to_string(val);
  545. break;
  546. }
  547. case 5: { // Int32
  548. int32_t val;
  549. file.read(reinterpret_cast<char*>(&val), 4);
  550. value = std::to_string(val);
  551. break;
  552. }
  553. case 6: { // Float32
  554. float val;
  555. file.read(reinterpret_cast<char*>(&val), 4);
  556. value = std::to_string(val);
  557. break;
  558. }
  559. case 0: { // Uint8
  560. uint8_t val;
  561. file.read(reinterpret_cast<char*>(&val), 1);
  562. value = std::to_string(val);
  563. break;
  564. }
  565. case 1: { // Int8
  566. int8_t val;
  567. file.read(reinterpret_cast<char*>(&val), 1);
  568. value = std::to_string(val);
  569. break;
  570. }
  571. default:
  572. // Skip unknown types
  573. file.seekg(8, std::ios::cur);
  574. continue;
  575. }
  576. if (!value.empty()) {
  577. metadata[key] = value;
  578. }
  579. }
  580. // Read tensor information
  581. for (uint64_t i = 0; i < tensorCount && file.good(); ++i) {
  582. std::string tensorName = readString();
  583. if (tensorName.empty())
  584. break;
  585. // Read number of dimensions (uint32)
  586. uint32_t nDims;
  587. file.read(reinterpret_cast<char*>(&nDims), 4);
  588. if (file.gcount() != 4 || nDims > 10)
  589. break;
  590. // Read dimensions (uint64 array)
  591. std::vector<int64_t> shape(nDims);
  592. for (uint32_t d = 0; d < nDims; ++d) {
  593. uint64_t dim;
  594. file.read(reinterpret_cast<char*>(&dim), 8);
  595. if (file.gcount() != 8)
  596. break;
  597. shape[d] = static_cast<int64_t>(dim);
  598. }
  599. // Skip type (uint32) and offset (uint64)
  600. file.seekg(12, std::ios::cur);
  601. tensorInfo[tensorName] = shape;
  602. }
  603. return !tensorInfo.empty();
  604. }