// test_all_architectures.cpp
#include "model_detector.h"

#include <algorithm>
#include <cstddef>
#include <filesystem>
#include <iostream>
#include <map>
#include <string>
#include <system_error>
#include <utility>
#include <vector>
  6. namespace fs = std::filesystem;
  7. // Test results structure
  8. struct TestResult {
  9. std::string modelPath;
  10. std::string expectedArchitecture;
  11. ModelArchitecture detectedEnum;
  12. std::string detectedName;
  13. bool passed;
  14. std::string notes;
  15. };
  16. // Test function for each model
  17. void testModelArchitecture(const std::string& modelPath,
  18. ModelArchitecture expectedEnum,
  19. const std::string& expectedName,
  20. std::vector<TestResult>& results) {
  21. TestResult result;
  22. result.modelPath = modelPath;
  23. result.expectedArchitecture = expectedName;
  24. result.detectedEnum = ModelArchitecture::UNKNOWN;
  25. result.detectedName = "Unknown";
  26. result.passed = false;
  27. std::cout << "\n=== Testing: " << fs::path(modelPath).filename().string() << " ===" << std::endl;
  28. std::cout << "Expected: " << expectedName << std::endl;
  29. if (!fs::exists(modelPath)) {
  30. std::cout << "❌ Model file does not exist!" << std::endl;
  31. result.notes = "File not found";
  32. results.push_back(result);
  33. return;
  34. }
  35. try {
  36. ModelDetectionResult detectionResult = ModelDetector::detectModel(modelPath);
  37. result.detectedEnum = detectionResult.architecture;
  38. result.detectedName = detectionResult.architectureName;
  39. std::cout << "Detected: " << result.detectedName << std::endl;
  40. // Check if detection matches expected
  41. if (result.detectedEnum == expectedEnum) {
  42. std::cout << "✅ PASS: Correctly detected as " << expectedName << std::endl;
  43. result.passed = true;
  44. result.notes = "Correctly detected";
  45. } else {
  46. std::cout << "❌ FAIL: Expected " << expectedName << " but got " << result.detectedName << std::endl;
  47. result.notes = "Incorrect detection - expected " + expectedName + " but got " + result.detectedName;
  48. // Show some tensor names for debugging
  49. std::cout << " First 5 tensor names:" << std::endl;
  50. for (size_t i = 0; i < std::min(size_t(5), detectionResult.tensorNames.size()); ++i) {
  51. std::cout << " " << detectionResult.tensorNames[i] << std::endl;
  52. }
  53. }
  54. // Show key metadata if available
  55. if (!detectionResult.metadata.empty()) {
  56. std::cout << " Key metadata:" << std::endl;
  57. for (const auto& [key, value] : detectionResult.metadata) {
  58. if (key.find("architecture") != std::string::npos ||
  59. key.find("model") != std::string::npos ||
  60. key.find("_model_name") != std::string::npos) {
  61. std::cout << " " << key << ": " << value << std::endl;
  62. }
  63. }
  64. }
  65. } catch (const std::exception& e) {
  66. std::cout << "❌ ERROR: " << e.what() << std::endl;
  67. result.notes = "Exception: " + std::string(e.what());
  68. }
  69. results.push_back(result);
  70. }
  71. int main() {
  72. std::cout << "🧪 Comprehensive Model Architecture Detection Test" << std::endl;
  73. std::cout << "==================================================" << std::endl;
  74. std::cout << "Testing all model types to verify Qwen fix doesn't break other architectures" << std::endl;
  75. std::vector<TestResult> results;
  76. // Test Stable Diffusion 1.5 models
  77. std::cout << "\n📋 Testing Stable Diffusion 1.5 Models..." << std::endl;
  78. testModelArchitecture("/data/SD_MODELS/checkpoints/v1-5-pruned-emaonly.safetensors",
  79. ModelArchitecture::SD_1_5, "Stable Diffusion 1.5", results);
  80. testModelArchitecture("/data/SD_MODELS/checkpoints/sd_15_base.safetensors",
  81. ModelArchitecture::SD_1_5, "Stable Diffusion 1.5", results);
  82. // Test Stable Diffusion 2.1 models (if available)
  83. std::cout << "\n📋 Testing Stable Diffusion 2.1 Models..." << std::endl;
  84. // Note: No SD 2.1 models found in the current directory, but keeping the test structure
  85. // Test SDXL models
  86. std::cout << "\n📋 Testing SDXL Models..." << std::endl;
  87. testModelArchitecture("/data/SD_MODELS/checkpoints/sd_xl_base_1.0_0.9vae.safetensors",
  88. ModelArchitecture::SDXL_BASE, "Stable Diffusion XL Base", results);
  89. testModelArchitecture("/data/SD_MODELS/checkpoints/sd_xl_refiner_1.0_0.9vae.safetensors",
  90. ModelArchitecture::SDXL_REFINER, "Stable Diffusion XL Refiner", results);
  91. testModelArchitecture("/data/SD_MODELS/checkpoints/juggernautXL_v8Rundiffusion.safetensors",
  92. ModelArchitecture::SDXL_BASE, "Stable Diffusion XL Base", results);
  93. testModelArchitecture("/data/SD_MODELS/checkpoints/realDream_sdxl6.safetensors",
  94. ModelArchitecture::SDXL_BASE, "Stable Diffusion XL Base", results);
  95. // Test Flux models
  96. std::cout << "\n📋 Testing Flux Models..." << std::endl;
  97. testModelArchitecture("/data/SD_MODELS/checkpoints/chroma-unlocked-v40-detail-calibrated-Q4_0.gguf",
  98. ModelArchitecture::FLUX_CHROMA, "Flux Chroma (Unlocked)", results);
  99. testModelArchitecture("/data/SD_MODELS/checkpoints/flux1-kontext-dev-Q5_K_S.gguf",
  100. ModelArchitecture::FLUX_DEV, "Flux Dev", results);
  101. testModelArchitecture("/data/SD_MODELS/checkpoints/gonzalomoXLFluxPony_v20PonyDMD.safetensors",
  102. ModelArchitecture::FLUX_DEV, "Flux Dev", results);
  103. // Test SD3 models (if available)
  104. std::cout << "\n📋 Testing SD3 Models..." << std::endl;
  105. // Note: No SD3 models found in the current directory, but keeping the test structure
  106. // Test Qwen models (to verify the fix still works)
  107. std::cout << "\n📋 Testing Qwen Models (verifying fix)..." << std::endl;
  108. testModelArchitecture("/data/SD_MODELS/diffusion_models/Qwen-Image-Edit-2509-Q3_K_S.gguf",
  109. ModelArchitecture::QWEN2VL, "Qwen2-VL", results);
  110. testModelArchitecture("/data/SD_MODELS/diffusion_models/Qwen-Image-Pruning-13b-Q4_0.gguf",
  111. ModelArchitecture::QWEN2VL, "Qwen2-VL", results);
  112. testModelArchitecture("/data/SD_MODELS/diffusion_models/qwen-image-Q2_K.gguf",
  113. ModelArchitecture::QWEN2VL, "Qwen2-VL", results);
  114. // Summary by architecture
  115. std::map<ModelArchitecture, std::pair<int, int>> archStats; // {passed, total}
  116. std::cout << "\n📊 DETAILED RESULTS BY ARCHITECTURE:" << std::endl;
  117. std::cout << "======================================" << std::endl;
  118. for (const auto& result : results) {
  119. archStats[result.detectedEnum].first += result.passed ? 1 : 0;
  120. archStats[result.detectedEnum].second += 1;
  121. std::cout << "\n📁 Model: " << fs::path(result.modelPath).filename().string() << std::endl;
  122. std::cout << " Expected: " << result.expectedArchitecture << std::endl;
  123. std::cout << " Detected: " << result.detectedName << std::endl;
  124. std::cout << " Status: " << (result.passed ? "✅ PASS" : "❌ FAIL") << std::endl;
  125. if (!result.notes.empty()) {
  126. std::cout << " Notes: " << result.notes << std::endl;
  127. }
  128. }
  129. // Overall summary
  130. std::cout << "\n🎯 OVERALL TEST SUMMARY:" << std::endl;
  131. std::cout << "========================" << std::endl;
  132. int totalTests = 0;
  133. int totalPassed = 0;
  134. for (const auto& [arch, stats] : archStats) {
  135. std::string archName = ModelDetector::getArchitectureName(arch);
  136. int passed = stats.first;
  137. int total = stats.second;
  138. std::cout << archName << ": " << passed << "/" << total << " tests passed";
  139. if (passed == total && total > 0) {
  140. std::cout << " ✅";
  141. } else if (total > 0) {
  142. std::cout << " ❌";
  143. }
  144. std::cout << std::endl;
  145. totalTests += total;
  146. totalPassed += passed;
  147. }
  148. std::cout << "\n📈 FINAL RESULTS:" << std::endl;
  149. std::cout << " Total models tested: " << totalTests << std::endl;
  150. std::cout << " Successfully detected: " << totalPassed << std::endl;
  151. std::cout << " Success rate: " << (totalTests > 0 ? (totalPassed * 100 / totalTests) : 0) << "%" << std::endl;
  152. // Check for specific issues with the Qwen fix
  153. std::cout << "\n🔍 QWEN FIX IMPACT ANALYSIS:" << std::endl;
  154. std::cout << "=============================" << std::endl;
  155. bool qwenWorks = false;
  156. bool othersWork = true;
  157. for (const auto& result : results) {
  158. if (result.expectedArchitecture == "Qwen2-VL") {
  159. qwenWorks = result.passed;
  160. } else if (result.expectedArchitecture != "Unknown" && !result.passed) {
  161. othersWork = false;
  162. }
  163. }
  164. if (qwenWorks) {
  165. std::cout << "✅ Qwen detection fix is working correctly" << std::endl;
  166. } else {
  167. std::cout << "❌ Qwen detection fix is NOT working" << std::endl;
  168. }
  169. if (othersWork) {
  170. std::cout << "✅ Other model architectures are still working correctly" << std::endl;
  171. } else {
  172. std::cout << "❌ Some other model architectures are broken after the Qwen fix" << std::endl;
  173. }
  174. std::cout << "\n🏁 TEST COMPLETE!" << std::endl;
  175. return (totalPassed == totalTests && totalTests > 0) ? 0 : 1;
  176. }