// model_detector.cpp
  #include "model_detector.h"

  #include <algorithm>
  #include <cctype>
  #include <cstring>
  #include <fstream>
  #include <utility>

  #include <nlohmann/json.hpp>
  6. using json = nlohmann::json;
  // Reports whether `str` terminates with `suffix` (an empty suffix always
  // matches). Local stand-in for std::string::ends_with, which needs C++20.
  static bool endsWith(const std::string& str, const std::string& suffix) {
      const std::size_t textLength = str.size();
      const std::size_t tailLength = suffix.size();
      if (textLength < tailLength) {
          return false;
      }
      // Walk both strings from the back; only `tailLength` characters are compared.
      return std::equal(suffix.rbegin(), suffix.rend(), str.rbegin());
  }
  12. ModelDetectionResult ModelDetector::detectModel(const std::string& modelPath) {
  13. ModelDetectionResult result;
  14. std::map<std::string, std::vector<int64_t>> tensorInfo;
  15. // Determine file type and parse accordingly
  16. bool parsed = false;
  17. if (endsWith(modelPath, ".safetensors")) {
  18. parsed = parseSafetensorsHeader(modelPath, result.metadata, tensorInfo);
  19. } else if (endsWith(modelPath, ".gguf")) {
  20. parsed = parseGGUFHeader(modelPath, result.metadata, tensorInfo);
  21. } else if (endsWith(modelPath, ".ckpt") || endsWith(modelPath, ".pt")) {
  22. // PyTorch pickle files - these require the full PyTorch library to parse safely
  23. // For now, we cannot detect their architecture without loading the model
  24. // Return unknown architecture with a note in metadata
  25. result.metadata["format"] = "pytorch_pickle";
  26. result.metadata["note"] = "Architecture detection not supported for .ckpt/.pt files";
  27. return result;
  28. }
  29. if (!parsed) {
  30. return result; // Unknown if we can't parse
  31. }
  32. // Store tensor names for reference
  33. for (const auto& [name, _] : tensorInfo) {
  34. result.tensorNames.push_back(name);
  35. }
  36. // Analyze architecture (pass filename for special detection)
  37. std::string filename = modelPath.substr(modelPath.find_last_of("/\\") + 1);
  38. result.architecture = analyzeArchitecture(tensorInfo, result.metadata, filename);
  39. result.architectureName = getArchitectureName(result.architecture);
  40. // Set architecture-specific properties and required models
  41. switch (result.architecture) {
  42. case ModelArchitecture::SD_1_5:
  43. result.textEncoderDim = 768;
  44. result.unetChannels = 1280;
  45. result.needsVAE = true;
  46. result.recommendedVAE = "vae-ft-mse-840000-ema-pruned.safetensors";
  47. result.needsTAESD = true;
  48. result.suggestedParams["vae_flag"] = "--vae";
  49. break;
  50. case ModelArchitecture::SD_2_1:
  51. result.textEncoderDim = 1024;
  52. result.unetChannels = 1280;
  53. result.needsVAE = true;
  54. result.recommendedVAE = "vae-ft-ema-560000.safetensors";
  55. result.needsTAESD = true;
  56. result.suggestedParams["vae_flag"] = "--vae";
  57. break;
  58. case ModelArchitecture::SDXL_BASE:
  59. case ModelArchitecture::SDXL_REFINER:
  60. result.textEncoderDim = 1280;
  61. result.unetChannels = 2560;
  62. result.hasConditioner = true;
  63. result.needsVAE = true;
  64. result.recommendedVAE = "sdxl_vae.safetensors";
  65. result.needsTAESD = true;
  66. result.suggestedParams["vae_flag"] = "--vae";
  67. break;
  68. case ModelArchitecture::FLUX_SCHNELL:
  69. case ModelArchitecture::FLUX_DEV:
  70. result.textEncoderDim = 4096;
  71. result.needsVAE = true;
  72. result.recommendedVAE = "ae.safetensors";
  73. // Flux requires CLIP-L and T5XXL
  74. result.suggestedParams["vae_flag"] = "--vae";
  75. result.suggestedParams["clip_l_required"] = "clip_l.safetensors";
  76. result.suggestedParams["t5xxl_required"] = "t5xxl_fp16.safetensors";
  77. result.suggestedParams["clip_l_flag"] = "--clip-l";
  78. result.suggestedParams["t5xxl_flag"] = "--t5xxl";
  79. break;
  80. case ModelArchitecture::FLUX_CHROMA:
  81. result.textEncoderDim = 4096;
  82. result.needsVAE = true;
  83. result.recommendedVAE = "ae.safetensors";
  84. // Chroma (Flux Unlocked) requires VAE and T5XXL
  85. result.suggestedParams["vae_flag"] = "--vae";
  86. result.suggestedParams["t5xxl_required"] = "t5xxl_fp16.safetensors";
  87. result.suggestedParams["t5xxl_flag"] = "--t5xxl";
  88. break;
  89. case ModelArchitecture::SD_3:
  90. result.textEncoderDim = 4096;
  91. result.needsVAE = true;
  92. result.recommendedVAE = "sd3_vae.safetensors";
  93. // SD3 requires CLIP-L, CLIP-G, and T5XXL
  94. result.suggestedParams["vae_flag"] = "--vae";
  95. result.suggestedParams["clip_l_required"] = "clip_l.safetensors";
  96. result.suggestedParams["clip_g_required"] = "clip_g.safetensors";
  97. result.suggestedParams["t5xxl_required"] = "t5xxl_fp16.safetensors";
  98. result.suggestedParams["clip_l_flag"] = "--clip-l";
  99. result.suggestedParams["clip_g_flag"] = "--clip-g";
  100. result.suggestedParams["t5xxl_flag"] = "--t5xxl";
  101. break;
  102. case ModelArchitecture::QWEN2VL:
  103. // Qwen2-VL requires vision and language model components
  104. result.suggestedParams["qwen2vl_required"] = "qwen2vl.safetensors";
  105. result.suggestedParams["qwen2vl_vision_required"] = "qwen2vl_vision.safetensors";
  106. result.suggestedParams["qwen2vl_flag"] = "--qwen2vl";
  107. result.suggestedParams["qwen2vl_vision_flag"] = "--qwen2vl-vision";
  108. break;
  109. default:
  110. break;
  111. }
  112. // Merge with general recommended parameters (width, height, steps, etc.)
  113. auto generalParams = getRecommendedParams(result.architecture);
  114. for (const auto& [key, value] : generalParams) {
  115. // Only add if not already set (preserve architecture-specific flags)
  116. if (result.suggestedParams.find(key) == result.suggestedParams.end()) {
  117. result.suggestedParams[key] = value;
  118. }
  119. }
  120. return result;
  121. }
  122. bool ModelDetector::parseSafetensorsHeader(
  123. const std::string& filePath,
  124. std::map<std::string, std::string>& metadata,
  125. std::map<std::string, std::vector<int64_t>>& tensorInfo
  126. ) {
  127. std::ifstream file(filePath, std::ios::binary);
  128. if (!file.is_open()) {
  129. return false;
  130. }
  131. // Read header length (first 8 bytes, little-endian uint64)
  132. uint64_t headerLength = 0;
  133. file.read(reinterpret_cast<char*>(&headerLength), 8);
  134. if (file.gcount() != 8) {
  135. return false;
  136. }
  137. // Sanity check: header should be reasonable size (< 100MB)
  138. if (headerLength == 0 || headerLength > 100 * 1024 * 1024) {
  139. return false;
  140. }
  141. // Read header JSON
  142. std::vector<char> headerBuffer(headerLength);
  143. file.read(headerBuffer.data(), headerLength);
  144. if (file.gcount() != static_cast<std::streamsize>(headerLength)) {
  145. return false;
  146. }
  147. // Parse JSON
  148. try {
  149. json headerJson = json::parse(headerBuffer.begin(), headerBuffer.end());
  150. // Extract metadata if present
  151. if (headerJson.contains("__metadata__")) {
  152. auto metadataJson = headerJson["__metadata__"];
  153. for (auto it = metadataJson.begin(); it != metadataJson.end(); ++it) {
  154. metadata[it.key()] = it.value().get<std::string>();
  155. }
  156. }
  157. // Extract tensor information
  158. for (auto it = headerJson.begin(); it != headerJson.end(); ++it) {
  159. if (it.key() == "__metadata__") continue;
  160. if (it.value().contains("shape")) {
  161. std::vector<int64_t> shape;
  162. for (const auto& dim : it.value()["shape"]) {
  163. shape.push_back(dim.get<int64_t>());
  164. }
  165. tensorInfo[it.key()] = shape;
  166. }
  167. }
  168. return true;
  169. } catch (const std::exception& e) {
  170. return false;
  171. }
  172. }
  173. ModelArchitecture ModelDetector::analyzeArchitecture(
  174. const std::map<std::string, std::vector<int64_t>>& tensorInfo,
  175. const std::map<std::string, std::string>& metadata,
  176. const std::string& filename
  177. ) {
  178. // Check metadata first for explicit architecture hints
  179. auto modelTypeIt = metadata.find("modelspec.architecture");
  180. if (modelTypeIt != metadata.end()) {
  181. const std::string& archName = modelTypeIt->second;
  182. if (archName.find("stable-diffusion-xl") != std::string::npos) {
  183. return ModelArchitecture::SDXL_BASE;
  184. } else if (archName.find("stable-diffusion-v2") != std::string::npos) {
  185. return ModelArchitecture::SD_2_1;
  186. } else if (archName.find("stable-diffusion-v1") != std::string::npos) {
  187. return ModelArchitecture::SD_1_5;
  188. }
  189. }
  190. // Check filename for special variants
  191. std::string lowerFilename = filename;
  192. std::transform(lowerFilename.begin(), lowerFilename.end(), lowerFilename.begin(), ::tolower);
  193. // Analyze tensor structure for architecture detection
  194. bool hasConditioner = false;
  195. bool hasTextEncoder2 = false;
  196. bool hasFluxStructure = false;
  197. bool hasSD3Structure = false;
  198. int maxUNetChannels = 0;
  199. int textEncoderOutputDim = 0;
  200. for (const auto& [name, shape] : tensorInfo) {
  201. // Check for SDXL-specific components
  202. if (name.find("conditioner") != std::string::npos) {
  203. hasConditioner = true;
  204. }
  205. if (name.find("text_encoder_2") != std::string::npos ||
  206. name.find("cond_stage_model.1") != std::string::npos) {
  207. hasTextEncoder2 = true;
  208. }
  209. // Check for Flux-specific patterns
  210. if (name.find("double_blocks") != std::string::npos ||
  211. name.find("single_blocks") != std::string::npos) {
  212. hasFluxStructure = true;
  213. }
  214. // Check for SD3-specific patterns
  215. if (name.find("joint_blocks") != std::string::npos) {
  216. hasSD3Structure = true;
  217. }
  218. // Analyze UNet structure
  219. if (name.find("model.diffusion_model") != std::string::npos ||
  220. name.find("unet") != std::string::npos) {
  221. if (shape.size() >= 2) {
  222. maxUNetChannels = std::max(maxUNetChannels, static_cast<int>(shape[0]));
  223. }
  224. }
  225. // Check text encoder dimensions
  226. if (name.find("cond_stage_model") != std::string::npos ||
  227. name.find("text_encoder") != std::string::npos) {
  228. if (name.find("proj") != std::string::npos && shape.size() >= 2) {
  229. textEncoderOutputDim = std::max(textEncoderOutputDim, static_cast<int>(shape[1]));
  230. }
  231. }
  232. }
  233. // Determine architecture based on analysis
  234. if (hasFluxStructure) {
  235. // Check for Chroma variant (unlocked Flux)
  236. if (lowerFilename.find("chroma") != std::string::npos) {
  237. return ModelArchitecture::FLUX_CHROMA;
  238. }
  239. // Check if it's Schnell or Dev based on step count hints
  240. auto stepsIt = metadata.find("diffusion_steps");
  241. if (stepsIt != metadata.end() && stepsIt->second.find("4") != std::string::npos) {
  242. return ModelArchitecture::FLUX_SCHNELL;
  243. }
  244. return ModelArchitecture::FLUX_DEV;
  245. }
  246. if (hasSD3Structure) {
  247. return ModelArchitecture::SD_3;
  248. }
  249. if (hasConditioner || hasTextEncoder2) {
  250. // SDXL architecture
  251. bool hasRefinerMarkers = false;
  252. for (const auto& [name, _] : tensorInfo) {
  253. if (name.find("refiner") != std::string::npos) {
  254. hasRefinerMarkers = true;
  255. break;
  256. }
  257. }
  258. return hasRefinerMarkers ? ModelArchitecture::SDXL_REFINER : ModelArchitecture::SDXL_BASE;
  259. }
  260. if (maxUNetChannels >= 2048) {
  261. return ModelArchitecture::SDXL_BASE;
  262. }
  263. // Distinguish between SD1.x and SD2.x by text encoder dimension
  264. if (textEncoderOutputDim >= 1024 || maxUNetChannels == 1280) {
  265. return ModelArchitecture::SD_2_1;
  266. }
  267. if (textEncoderOutputDim == 768 || maxUNetChannels <= 1280) {
  268. return ModelArchitecture::SD_1_5;
  269. }
  270. return ModelArchitecture::UNKNOWN;
  271. }
  272. std::string ModelDetector::getArchitectureName(ModelArchitecture arch) {
  273. switch (arch) {
  274. case ModelArchitecture::SD_1_5: return "Stable Diffusion 1.5";
  275. case ModelArchitecture::SD_2_1: return "Stable Diffusion 2.1";
  276. case ModelArchitecture::SDXL_BASE: return "Stable Diffusion XL Base";
  277. case ModelArchitecture::SDXL_REFINER: return "Stable Diffusion XL Refiner";
  278. case ModelArchitecture::FLUX_SCHNELL: return "Flux Schnell";
  279. case ModelArchitecture::FLUX_DEV: return "Flux Dev";
  280. case ModelArchitecture::FLUX_CHROMA: return "Flux Chroma (Unlocked)";
  281. case ModelArchitecture::SD_3: return "Stable Diffusion 3";
  282. case ModelArchitecture::QWEN2VL: return "Qwen2-VL";
  283. default: return "Unknown";
  284. }
  285. }
  286. std::map<std::string, std::string> ModelDetector::getRecommendedParams(ModelArchitecture arch) {
  287. std::map<std::string, std::string> params;
  288. switch (arch) {
  289. case ModelArchitecture::SD_1_5:
  290. params["width"] = "512";
  291. params["height"] = "512";
  292. params["cfg_scale"] = "7.5";
  293. params["steps"] = "20";
  294. params["sampler"] = "euler_a";
  295. break;
  296. case ModelArchitecture::SD_2_1:
  297. params["width"] = "768";
  298. params["height"] = "768";
  299. params["cfg_scale"] = "7.0";
  300. params["steps"] = "25";
  301. params["sampler"] = "euler_a";
  302. break;
  303. case ModelArchitecture::SDXL_BASE:
  304. case ModelArchitecture::SDXL_REFINER:
  305. params["width"] = "1024";
  306. params["height"] = "1024";
  307. params["cfg_scale"] = "7.0";
  308. params["steps"] = "30";
  309. params["sampler"] = "dpm++2m";
  310. break;
  311. case ModelArchitecture::FLUX_SCHNELL:
  312. params["width"] = "1024";
  313. params["height"] = "1024";
  314. params["cfg_scale"] = "1.0";
  315. params["steps"] = "4";
  316. params["sampler"] = "euler";
  317. break;
  318. case ModelArchitecture::FLUX_DEV:
  319. params["width"] = "1024";
  320. params["height"] = "1024";
  321. params["cfg_scale"] = "1.0";
  322. params["steps"] = "20";
  323. params["sampler"] = "euler";
  324. break;
  325. case ModelArchitecture::FLUX_CHROMA:
  326. params["width"] = "1024";
  327. params["height"] = "1024";
  328. params["cfg_scale"] = "1.0";
  329. params["steps"] = "20";
  330. params["sampler"] = "euler";
  331. break;
  332. case ModelArchitecture::SD_3:
  333. params["width"] = "1024";
  334. params["height"] = "1024";
  335. params["cfg_scale"] = "5.0";
  336. params["steps"] = "28";
  337. params["sampler"] = "dpm++2m";
  338. break;
  339. default:
  340. break;
  341. }
  342. return params;
  343. }
  344. bool ModelDetector::parseGGUFHeader(
  345. const std::string& filePath,
  346. std::map<std::string, std::string>& metadata,
  347. std::map<std::string, std::vector<int64_t>>& tensorInfo
  348. ) {
  349. std::ifstream file(filePath, std::ios::binary);
  350. if (!file.is_open()) {
  351. return false;
  352. }
  353. // Read and verify magic number "GGUF"
  354. char magic[4];
  355. file.read(magic, 4);
  356. if (file.gcount() != 4 || std::memcmp(magic, "GGUF", 4) != 0) {
  357. return false;
  358. }
  359. // Read version (uint32)
  360. uint32_t version;
  361. file.read(reinterpret_cast<char*>(&version), 4);
  362. if (file.gcount() != 4) {
  363. return false;
  364. }
  365. // Read tensor count (uint64)
  366. uint64_t tensorCount;
  367. file.read(reinterpret_cast<char*>(&tensorCount), 8);
  368. if (file.gcount() != 8) {
  369. return false;
  370. }
  371. // Read metadata KV count (uint64)
  372. uint64_t metadataCount;
  373. file.read(reinterpret_cast<char*>(&metadataCount), 8);
  374. if (file.gcount() != 8) {
  375. return false;
  376. }
  377. // Helper function to read string
  378. auto readString = [&file]() -> std::string {
  379. uint64_t length;
  380. file.read(reinterpret_cast<char*>(&length), 8);
  381. if (file.gcount() != 8 || length == 0 || length > 10000) {
  382. return "";
  383. }
  384. std::vector<char> buffer(length);
  385. file.read(buffer.data(), length);
  386. if (file.gcount() != static_cast<std::streamsize>(length)) {
  387. return "";
  388. }
  389. return std::string(buffer.begin(), buffer.end());
  390. };
  391. // Read metadata key-value pairs
  392. for (uint64_t i = 0; i < metadataCount && file.good(); ++i) {
  393. std::string key = readString();
  394. if (key.empty()) break;
  395. // Read value type (uint32)
  396. uint32_t valueType;
  397. file.read(reinterpret_cast<char*>(&valueType), 4);
  398. if (file.gcount() != 4) break;
  399. // Parse value based on type
  400. std::string value;
  401. switch (valueType) {
  402. case 8: // String
  403. value = readString();
  404. break;
  405. case 4: { // Uint32
  406. uint32_t val;
  407. file.read(reinterpret_cast<char*>(&val), 4);
  408. value = std::to_string(val);
  409. break;
  410. }
  411. case 5: { // Int32
  412. int32_t val;
  413. file.read(reinterpret_cast<char*>(&val), 4);
  414. value = std::to_string(val);
  415. break;
  416. }
  417. case 6: { // Float32
  418. float val;
  419. file.read(reinterpret_cast<char*>(&val), 4);
  420. value = std::to_string(val);
  421. break;
  422. }
  423. case 0: { // Uint8
  424. uint8_t val;
  425. file.read(reinterpret_cast<char*>(&val), 1);
  426. value = std::to_string(val);
  427. break;
  428. }
  429. case 1: { // Int8
  430. int8_t val;
  431. file.read(reinterpret_cast<char*>(&val), 1);
  432. value = std::to_string(val);
  433. break;
  434. }
  435. default:
  436. // Skip unknown types
  437. file.seekg(8, std::ios::cur);
  438. continue;
  439. }
  440. if (!value.empty()) {
  441. metadata[key] = value;
  442. }
  443. }
  444. // Read tensor information
  445. for (uint64_t i = 0; i < tensorCount && file.good(); ++i) {
  446. std::string tensorName = readString();
  447. if (tensorName.empty()) break;
  448. // Read number of dimensions (uint32)
  449. uint32_t nDims;
  450. file.read(reinterpret_cast<char*>(&nDims), 4);
  451. if (file.gcount() != 4 || nDims > 10) break;
  452. // Read dimensions (uint64 array)
  453. std::vector<int64_t> shape(nDims);
  454. for (uint32_t d = 0; d < nDims; ++d) {
  455. uint64_t dim;
  456. file.read(reinterpret_cast<char*>(&dim), 8);
  457. if (file.gcount() != 8) break;
  458. shape[d] = static_cast<int64_t>(dim);
  459. }
  460. // Skip type (uint32) and offset (uint64)
  461. file.seekg(12, std::ios::cur);
  462. tensorInfo[tensorName] = shape;
  463. }
  464. return !tensorInfo.empty();
  465. }