// model_detector.cpp
#include "model_detector.h"

#include <algorithm>
#include <cctype>
#include <cstdint>
#include <cstring>
#include <fstream>
#include <iostream>

#include <nlohmann/json.hpp>
  7. // Helper function for C++17 compatibility (ends_with is C++20)
  8. static bool endsWith(const std::string& str, const std::string& suffix) {
  9. if (suffix.size() > str.size())
  10. return false;
  11. return str.compare(str.size() - suffix.size(), suffix.size(), suffix) == 0;
  12. }
  13. ModelDetectionResult ModelDetector::detectModel(const std::string& modelPath) {
  14. ModelDetectionResult result;
  15. std::map<std::string, std::vector<int64_t>> tensorInfo;
  16. // Determine file type and parse accordingly
  17. bool parsed = false;
  18. if (endsWith(modelPath, ".safetensors")) {
  19. parsed = parseSafetensorsHeader(modelPath, result.metadata, tensorInfo);
  20. } else if (endsWith(modelPath, ".gguf")) {
  21. parsed = parseGGUFHeader(modelPath, result.metadata, tensorInfo);
  22. } else if (endsWith(modelPath, ".ckpt") || endsWith(modelPath, ".pt")) {
  23. // PyTorch pickle files - these require the full PyTorch library to parse safely
  24. // For now, we cannot detect their architecture without loading the model
  25. // Return unknown architecture with a note in metadata
  26. result.metadata["format"] = "pytorch_pickle";
  27. result.metadata["note"] = "Architecture detection not supported for .ckpt/.pt files";
  28. return result;
  29. }
  30. if (!parsed) {
  31. return result; // Unknown if we can't parse
  32. }
  33. // Store tensor names for reference
  34. for (const auto& [name, _] : tensorInfo) {
  35. result.tensorNames.push_back(name);
  36. }
  37. // Analyze architecture (pass filename for special detection)
  38. std::string filename = modelPath.substr(modelPath.find_last_of("/\\") + 1);
  39. result.architecture = analyzeArchitecture(tensorInfo, result.metadata, filename);
  40. result.architectureName = getArchitectureName(result.architecture);
  41. // Set architecture-specific properties and required models
  42. switch (result.architecture) {
  43. case ModelArchitecture::SD_1_5:
  44. result.textEncoderDim = 768;
  45. result.unetChannels = 1280;
  46. result.needsVAE = true;
  47. result.recommendedVAE = "vae-ft-mse-840000-ema-pruned.safetensors";
  48. result.needsTAESD = true;
  49. result.suggestedParams["vae_flag"] = "--vae";
  50. break;
  51. case ModelArchitecture::SD_2_1:
  52. result.textEncoderDim = 1024;
  53. result.unetChannels = 1280;
  54. result.needsVAE = true;
  55. result.recommendedVAE = "vae-ft-ema-560000.safetensors";
  56. result.needsTAESD = true;
  57. result.suggestedParams["vae_flag"] = "--vae";
  58. break;
  59. case ModelArchitecture::SDXL_BASE:
  60. case ModelArchitecture::SDXL_REFINER:
  61. result.textEncoderDim = 1280;
  62. result.unetChannels = 2560;
  63. result.hasConditioner = true;
  64. result.needsVAE = true;
  65. result.recommendedVAE = "sdxl_vae.safetensors";
  66. result.needsTAESD = true;
  67. result.suggestedParams["vae_flag"] = "--vae";
  68. break;
  69. case ModelArchitecture::FLUX_SCHNELL:
  70. case ModelArchitecture::FLUX_DEV:
  71. result.textEncoderDim = 4096;
  72. result.needsVAE = true;
  73. result.recommendedVAE = "ae.safetensors";
  74. // Flux requires CLIP-L and T5XXL
  75. result.suggestedParams["vae_flag"] = "--vae";
  76. result.suggestedParams["clip_l_required"] = "clip_l.safetensors";
  77. result.suggestedParams["t5xxl_required"] = "t5xxl_fp16.safetensors";
  78. result.suggestedParams["clip_l_flag"] = "--clip-l";
  79. result.suggestedParams["t5xxl_flag"] = "--t5xxl";
  80. break;
  81. case ModelArchitecture::FLUX_CHROMA:
  82. result.textEncoderDim = 4096;
  83. result.needsVAE = true;
  84. result.recommendedVAE = "ae.safetensors";
  85. // Chroma (Flux Unlocked) requires VAE and T5XXL
  86. result.suggestedParams["vae_flag"] = "--vae";
  87. result.suggestedParams["t5xxl_required"] = "t5xxl_fp16.safetensors";
  88. result.suggestedParams["t5xxl_flag"] = "--t5xxl";
  89. break;
  90. case ModelArchitecture::SD_3:
  91. result.textEncoderDim = 4096;
  92. result.needsVAE = true;
  93. result.recommendedVAE = "sd3_vae.safetensors";
  94. // SD3 requires CLIP-L, CLIP-G, and T5XXL
  95. result.suggestedParams["vae_flag"] = "--vae";
  96. result.suggestedParams["clip_l_required"] = "clip_l.safetensors";
  97. result.suggestedParams["clip_g_required"] = "clip_g.safetensors";
  98. result.suggestedParams["t5xxl_required"] = "t5xxl_fp16.safetensors";
  99. result.suggestedParams["clip_l_flag"] = "--clip-l";
  100. result.suggestedParams["clip_g_flag"] = "--clip-g";
  101. result.suggestedParams["t5xxl_flag"] = "--t5xxl";
  102. break;
  103. case ModelArchitecture::QWEN2VL:
  104. // Qwen2-VL requires vision and language model components
  105. result.suggestedParams["qwen2vl_required"] = "qwen2vl.safetensors";
  106. result.suggestedParams["qwen2vl_vision_required"] = "qwen2vl_vision.safetensors";
  107. result.suggestedParams["qwen2vl_flag"] = "--qwen2vl";
  108. result.suggestedParams["qwen2vl_vision_flag"] = "--qwen2vl-vision";
  109. break;
  110. default:
  111. break;
  112. }
  113. // Merge with general recommended parameters (width, height, steps, etc.)
  114. auto generalParams = getRecommendedParams(result.architecture);
  115. for (const auto& [key, value] : generalParams) {
  116. // Only add if not already set (preserve architecture-specific flags)
  117. if (result.suggestedParams.find(key) == result.suggestedParams.end()) {
  118. result.suggestedParams[key] = value;
  119. }
  120. }
  121. return result;
  122. }
  123. bool ModelDetector::parseSafetensorsHeader(
  124. const std::string& filePath,
  125. std::map<std::string, std::string>& metadata,
  126. std::map<std::string, std::vector<int64_t>>& tensorInfo) {
  127. std::ifstream file(filePath, std::ios::binary);
  128. if (!file.is_open()) {
  129. return false;
  130. }
  131. // Read header length (first 8 bytes, little-endian uint64)
  132. uint64_t headerLength = 0;
  133. file.read(reinterpret_cast<char*>(&headerLength), 8);
  134. if (file.gcount() != 8) {
  135. return false;
  136. }
  137. // Sanity check: header should be reasonable size (< 100MB)
  138. if (headerLength == 0 || headerLength > 100 * 1024 * 1024) {
  139. return false;
  140. }
  141. // Read header JSON
  142. std::vector<char> headerBuffer(headerLength);
  143. file.read(headerBuffer.data(), headerLength);
  144. if (file.gcount() != static_cast<std::streamsize>(headerLength)) {
  145. return false;
  146. }
  147. // Parse JSON
  148. try {
  149. nlohmann::json headerJson = nlohmann::json::parse(headerBuffer.begin(), headerBuffer.end());
  150. // Extract metadata if present
  151. if (headerJson.contains("__metadata__")) {
  152. auto metadataJson = headerJson["__metadata__"];
  153. for (auto it = metadataJson.begin(); it != metadataJson.end(); ++it) {
  154. metadata[it.key()] = it.value().get<std::string>();
  155. }
  156. }
  157. // Extract tensor information
  158. for (auto it = headerJson.begin(); it != headerJson.end(); ++it) {
  159. if (it.key() == "__metadata__")
  160. continue;
  161. if (it.value().contains("shape")) {
  162. std::vector<int64_t> shape;
  163. for (const auto& dim : it.value()["shape"]) {
  164. shape.push_back(dim.get<int64_t>());
  165. }
  166. tensorInfo[it.key()] = shape;
  167. }
  168. }
  169. return true;
  170. } catch (const std::exception& e) {
  171. return false;
  172. }
  173. }
  174. ModelArchitecture ModelDetector::analyzeArchitecture(
  175. const std::map<std::string, std::vector<int64_t>>& tensorInfo,
  176. const std::map<std::string, std::string>& metadata,
  177. const std::string& filename) {
  178. // Check metadata first for explicit architecture hints
  179. auto modelTypeIt = metadata.find("modelspec.architecture");
  180. if (modelTypeIt != metadata.end()) {
  181. const std::string& archName = modelTypeIt->second;
  182. if (archName.find("stable-diffusion-xl") != std::string::npos) {
  183. return ModelArchitecture::SDXL_BASE;
  184. } else if (archName.find("stable-diffusion-v2") != std::string::npos) {
  185. return ModelArchitecture::SD_2_1;
  186. } else if (archName.find("stable-diffusion-v1") != std::string::npos) {
  187. return ModelArchitecture::SD_1_5;
  188. }
  189. }
  190. // Check filename for special variants
  191. std::string lowerFilename = filename;
  192. std::transform(lowerFilename.begin(), lowerFilename.end(), lowerFilename.begin(), ::tolower);
  193. // Analyze tensor structure for architecture detection
  194. bool hasConditioner = false;
  195. bool hasTextEncoder2 = false;
  196. bool hasFluxStructure = false;
  197. bool hasSD3Structure = false;
  198. int maxUNetChannels = 0;
  199. int textEncoderOutputDim = 0;
  200. for (const auto& [name, shape] : tensorInfo) {
  201. // Check for SDXL-specific components
  202. if (name.find("conditioner") != std::string::npos) {
  203. hasConditioner = true;
  204. }
  205. if (name.find("text_encoder_2") != std::string::npos ||
  206. name.find("cond_stage_model.1") != std::string::npos) {
  207. hasTextEncoder2 = true;
  208. }
  209. // Check for Flux-specific patterns
  210. if (name.find("double_blocks") != std::string::npos ||
  211. name.find("single_blocks") != std::string::npos) {
  212. hasFluxStructure = true;
  213. }
  214. // Check for SD3-specific patterns
  215. if (name.find("joint_blocks") != std::string::npos) {
  216. hasSD3Structure = true;
  217. }
  218. // Analyze UNet structure
  219. if (name.find("model.diffusion_model") != std::string::npos ||
  220. name.find("unet") != std::string::npos) {
  221. if (shape.size() >= 2) {
  222. maxUNetChannels = std::max(maxUNetChannels, static_cast<int>(shape[0]));
  223. }
  224. }
  225. // Check text encoder dimensions
  226. if (name.find("cond_stage_model") != std::string::npos ||
  227. name.find("text_encoder") != std::string::npos) {
  228. if (name.find("proj") != std::string::npos && shape.size() >= 2) {
  229. textEncoderOutputDim = std::max(textEncoderOutputDim, static_cast<int>(shape[1]));
  230. }
  231. }
  232. }
  233. // Determine architecture based on analysis
  234. if (hasFluxStructure) {
  235. // Check for Chroma variant (unlocked Flux)
  236. if (lowerFilename.find("chroma") != std::string::npos) {
  237. return ModelArchitecture::FLUX_CHROMA;
  238. }
  239. // Check if it's Schnell or Dev based on step count hints
  240. auto stepsIt = metadata.find("diffusion_steps");
  241. if (stepsIt != metadata.end() && stepsIt->second.find("4") != std::string::npos) {
  242. return ModelArchitecture::FLUX_SCHNELL;
  243. }
  244. return ModelArchitecture::FLUX_DEV;
  245. }
  246. if (hasSD3Structure) {
  247. return ModelArchitecture::SD_3;
  248. }
  249. if (hasConditioner || hasTextEncoder2) {
  250. // SDXL architecture
  251. bool hasRefinerMarkers = false;
  252. for (const auto& [name, _] : tensorInfo) {
  253. if (name.find("refiner") != std::string::npos) {
  254. hasRefinerMarkers = true;
  255. break;
  256. }
  257. }
  258. return hasRefinerMarkers ? ModelArchitecture::SDXL_REFINER : ModelArchitecture::SDXL_BASE;
  259. }
  260. if (maxUNetChannels >= 2048) {
  261. return ModelArchitecture::SDXL_BASE;
  262. }
  263. // Distinguish between SD1.x and SD2.x by text encoder dimension
  264. if (textEncoderOutputDim >= 1024 || maxUNetChannels == 1280) {
  265. return ModelArchitecture::SD_2_1;
  266. }
  267. if (textEncoderOutputDim == 768 || maxUNetChannels <= 1280) {
  268. return ModelArchitecture::SD_1_5;
  269. }
  270. return ModelArchitecture::UNKNOWN;
  271. }
  272. std::string ModelDetector::getArchitectureName(ModelArchitecture arch) {
  273. switch (arch) {
  274. case ModelArchitecture::SD_1_5:
  275. return "Stable Diffusion 1.5";
  276. case ModelArchitecture::SD_2_1:
  277. return "Stable Diffusion 2.1";
  278. case ModelArchitecture::SDXL_BASE:
  279. return "Stable Diffusion XL Base";
  280. case ModelArchitecture::SDXL_REFINER:
  281. return "Stable Diffusion XL Refiner";
  282. case ModelArchitecture::FLUX_SCHNELL:
  283. return "Flux Schnell";
  284. case ModelArchitecture::FLUX_DEV:
  285. return "Flux Dev";
  286. case ModelArchitecture::FLUX_CHROMA:
  287. return "Flux Chroma (Unlocked)";
  288. case ModelArchitecture::SD_3:
  289. return "Stable Diffusion 3";
  290. case ModelArchitecture::QWEN2VL:
  291. return "Qwen2-VL";
  292. default:
  293. return "Unknown";
  294. }
  295. }
  296. std::map<std::string, std::string> ModelDetector::getRecommendedParams(ModelArchitecture arch) {
  297. std::map<std::string, std::string> params;
  298. switch (arch) {
  299. case ModelArchitecture::SD_1_5:
  300. params["width"] = "512";
  301. params["height"] = "512";
  302. params["cfg_scale"] = "7.5";
  303. params["steps"] = "20";
  304. params["sampler"] = "euler_a";
  305. break;
  306. case ModelArchitecture::SD_2_1:
  307. params["width"] = "768";
  308. params["height"] = "768";
  309. params["cfg_scale"] = "7.0";
  310. params["steps"] = "25";
  311. params["sampler"] = "euler_a";
  312. break;
  313. case ModelArchitecture::SDXL_BASE:
  314. case ModelArchitecture::SDXL_REFINER:
  315. params["width"] = "1024";
  316. params["height"] = "1024";
  317. params["cfg_scale"] = "7.0";
  318. params["steps"] = "30";
  319. params["sampler"] = "dpm++2m";
  320. break;
  321. case ModelArchitecture::FLUX_SCHNELL:
  322. params["width"] = "1024";
  323. params["height"] = "1024";
  324. params["cfg_scale"] = "1.0";
  325. params["steps"] = "4";
  326. params["sampler"] = "euler";
  327. break;
  328. case ModelArchitecture::FLUX_DEV:
  329. params["width"] = "1024";
  330. params["height"] = "1024";
  331. params["cfg_scale"] = "1.0";
  332. params["steps"] = "20";
  333. params["sampler"] = "euler";
  334. break;
  335. case ModelArchitecture::FLUX_CHROMA:
  336. params["width"] = "1024";
  337. params["height"] = "1024";
  338. params["cfg_scale"] = "1.0";
  339. params["steps"] = "20";
  340. params["sampler"] = "euler";
  341. break;
  342. case ModelArchitecture::SD_3:
  343. params["width"] = "1024";
  344. params["height"] = "1024";
  345. params["cfg_scale"] = "5.0";
  346. params["steps"] = "28";
  347. params["sampler"] = "dpm++2m";
  348. break;
  349. default:
  350. break;
  351. }
  352. return params;
  353. }
  354. bool ModelDetector::parseGGUFHeader(
  355. const std::string& filePath,
  356. std::map<std::string, std::string>& metadata,
  357. std::map<std::string, std::vector<int64_t>>& tensorInfo) {
  358. std::ifstream file(filePath, std::ios::binary);
  359. if (!file.is_open()) {
  360. return false;
  361. }
  362. // Read and verify magic number "GGUF"
  363. char magic[4];
  364. file.read(magic, 4);
  365. if (file.gcount() != 4 || std::memcmp(magic, "GGUF", 4) != 0) {
  366. return false;
  367. }
  368. // Read version (uint32)
  369. uint32_t version;
  370. file.read(reinterpret_cast<char*>(&version), 4);
  371. if (file.gcount() != 4) {
  372. return false;
  373. }
  374. // Read tensor count (uint64)
  375. uint64_t tensorCount;
  376. file.read(reinterpret_cast<char*>(&tensorCount), 8);
  377. if (file.gcount() != 8) {
  378. return false;
  379. }
  380. // Read metadata KV count (uint64)
  381. uint64_t metadataCount;
  382. file.read(reinterpret_cast<char*>(&metadataCount), 8);
  383. if (file.gcount() != 8) {
  384. return false;
  385. }
  386. // Helper function to read string
  387. auto readString = [&file]() -> std::string {
  388. uint64_t length;
  389. file.read(reinterpret_cast<char*>(&length), 8);
  390. if (file.gcount() != 8 || length == 0 || length > 10000) {
  391. return "";
  392. }
  393. std::vector<char> buffer(length);
  394. file.read(buffer.data(), length);
  395. if (file.gcount() != static_cast<std::streamsize>(length)) {
  396. return "";
  397. }
  398. return std::string(buffer.begin(), buffer.end());
  399. };
  400. // Read metadata key-value pairs
  401. for (uint64_t i = 0; i < metadataCount && file.good(); ++i) {
  402. std::string key = readString();
  403. if (key.empty())
  404. break;
  405. // Read value type (uint32)
  406. uint32_t valueType;
  407. file.read(reinterpret_cast<char*>(&valueType), 4);
  408. if (file.gcount() != 4)
  409. break;
  410. // Parse value based on type
  411. std::string value;
  412. switch (valueType) {
  413. case 8: // String
  414. value = readString();
  415. break;
  416. case 4: { // Uint32
  417. uint32_t val;
  418. file.read(reinterpret_cast<char*>(&val), 4);
  419. value = std::to_string(val);
  420. break;
  421. }
  422. case 5: { // Int32
  423. int32_t val;
  424. file.read(reinterpret_cast<char*>(&val), 4);
  425. value = std::to_string(val);
  426. break;
  427. }
  428. case 6: { // Float32
  429. float val;
  430. file.read(reinterpret_cast<char*>(&val), 4);
  431. value = std::to_string(val);
  432. break;
  433. }
  434. case 0: { // Uint8
  435. uint8_t val;
  436. file.read(reinterpret_cast<char*>(&val), 1);
  437. value = std::to_string(val);
  438. break;
  439. }
  440. case 1: { // Int8
  441. int8_t val;
  442. file.read(reinterpret_cast<char*>(&val), 1);
  443. value = std::to_string(val);
  444. break;
  445. }
  446. default:
  447. // Skip unknown types
  448. file.seekg(8, std::ios::cur);
  449. continue;
  450. }
  451. if (!value.empty()) {
  452. metadata[key] = value;
  453. }
  454. }
  455. // Read tensor information
  456. for (uint64_t i = 0; i < tensorCount && file.good(); ++i) {
  457. std::string tensorName = readString();
  458. if (tensorName.empty())
  459. break;
  460. // Read number of dimensions (uint32)
  461. uint32_t nDims;
  462. file.read(reinterpret_cast<char*>(&nDims), 4);
  463. if (file.gcount() != 4 || nDims > 10)
  464. break;
  465. // Read dimensions (uint64 array)
  466. std::vector<int64_t> shape(nDims);
  467. for (uint32_t d = 0; d < nDims; ++d) {
  468. uint64_t dim;
  469. file.read(reinterpret_cast<char*>(&dim), 8);
  470. if (file.gcount() != 8)
  471. break;
  472. shape[d] = static_cast<int64_t>(dim);
  473. }
  474. // Skip type (uint32) and offset (uint64)
  475. file.seekg(12, std::ios::cur);
  476. tensorInfo[tensorName] = shape;
  477. }
  478. return !tensorInfo.empty();
  479. }