| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214 |
#include "model_detector.h"

#include <algorithm>
#include <filesystem>
#include <iostream>
#include <map>
#include <string>
#include <system_error>
#include <utility>
#include <vector>

namespace fs = std::filesystem;
- // Test results structure
- struct TestResult {
- std::string modelPath;
- std::string expectedArchitecture;
- ModelArchitecture detectedEnum;
- std::string detectedName;
- bool passed;
- std::string notes;
- };
- // Test function for each model
- void testModelArchitecture(const std::string& modelPath,
- ModelArchitecture expectedEnum,
- const std::string& expectedName,
- std::vector<TestResult>& results) {
-
- TestResult result;
- result.modelPath = modelPath;
- result.expectedArchitecture = expectedName;
- result.detectedEnum = ModelArchitecture::UNKNOWN;
- result.detectedName = "Unknown";
- result.passed = false;
-
- std::cout << "\n=== Testing: " << fs::path(modelPath).filename().string() << " ===" << std::endl;
- std::cout << "Expected: " << expectedName << std::endl;
-
- if (!fs::exists(modelPath)) {
- std::cout << "❌ Model file does not exist!" << std::endl;
- result.notes = "File not found";
- results.push_back(result);
- return;
- }
-
- try {
- ModelDetectionResult detectionResult = ModelDetector::detectModel(modelPath);
- result.detectedEnum = detectionResult.architecture;
- result.detectedName = detectionResult.architectureName;
-
- std::cout << "Detected: " << result.detectedName << std::endl;
-
- // Check if detection matches expected
- if (result.detectedEnum == expectedEnum) {
- std::cout << "✅ PASS: Correctly detected as " << expectedName << std::endl;
- result.passed = true;
- result.notes = "Correctly detected";
- } else {
- std::cout << "❌ FAIL: Expected " << expectedName << " but got " << result.detectedName << std::endl;
- result.notes = "Incorrect detection - expected " + expectedName + " but got " + result.detectedName;
-
- // Show some tensor names for debugging
- std::cout << " First 5 tensor names:" << std::endl;
- for (size_t i = 0; i < std::min(size_t(5), detectionResult.tensorNames.size()); ++i) {
- std::cout << " " << detectionResult.tensorNames[i] << std::endl;
- }
- }
-
- // Show key metadata if available
- if (!detectionResult.metadata.empty()) {
- std::cout << " Key metadata:" << std::endl;
- for (const auto& [key, value] : detectionResult.metadata) {
- if (key.find("architecture") != std::string::npos ||
- key.find("model") != std::string::npos ||
- key.find("_model_name") != std::string::npos) {
- std::cout << " " << key << ": " << value << std::endl;
- }
- }
- }
-
- } catch (const std::exception& e) {
- std::cout << "❌ ERROR: " << e.what() << std::endl;
- result.notes = "Exception: " + std::string(e.what());
- }
-
- results.push_back(result);
- }
- int main() {
- std::cout << "🧪 Comprehensive Model Architecture Detection Test" << std::endl;
- std::cout << "==================================================" << std::endl;
- std::cout << "Testing all model types to verify Qwen fix doesn't break other architectures" << std::endl;
-
- std::vector<TestResult> results;
-
- // Test Stable Diffusion 1.5 models
- std::cout << "\n📋 Testing Stable Diffusion 1.5 Models..." << std::endl;
- testModelArchitecture("/data/SD_MODELS/checkpoints/v1-5-pruned-emaonly.safetensors",
- ModelArchitecture::SD_1_5, "Stable Diffusion 1.5", results);
- testModelArchitecture("/data/SD_MODELS/checkpoints/sd_15_base.safetensors",
- ModelArchitecture::SD_1_5, "Stable Diffusion 1.5", results);
-
- // Test Stable Diffusion 2.1 models (if available)
- std::cout << "\n📋 Testing Stable Diffusion 2.1 Models..." << std::endl;
- // Note: No SD 2.1 models found in the current directory, but keeping the test structure
-
- // Test SDXL models
- std::cout << "\n📋 Testing SDXL Models..." << std::endl;
- testModelArchitecture("/data/SD_MODELS/checkpoints/sd_xl_base_1.0_0.9vae.safetensors",
- ModelArchitecture::SDXL_BASE, "Stable Diffusion XL Base", results);
- testModelArchitecture("/data/SD_MODELS/checkpoints/sd_xl_refiner_1.0_0.9vae.safetensors",
- ModelArchitecture::SDXL_REFINER, "Stable Diffusion XL Refiner", results);
- testModelArchitecture("/data/SD_MODELS/checkpoints/juggernautXL_v8Rundiffusion.safetensors",
- ModelArchitecture::SDXL_BASE, "Stable Diffusion XL Base", results);
- testModelArchitecture("/data/SD_MODELS/checkpoints/realDream_sdxl6.safetensors",
- ModelArchitecture::SDXL_BASE, "Stable Diffusion XL Base", results);
-
- // Test Flux models
- std::cout << "\n📋 Testing Flux Models..." << std::endl;
- testModelArchitecture("/data/SD_MODELS/checkpoints/chroma-unlocked-v40-detail-calibrated-Q4_0.gguf",
- ModelArchitecture::FLUX_CHROMA, "Flux Chroma (Unlocked)", results);
- testModelArchitecture("/data/SD_MODELS/checkpoints/flux1-kontext-dev-Q5_K_S.gguf",
- ModelArchitecture::FLUX_DEV, "Flux Dev", results);
- testModelArchitecture("/data/SD_MODELS/checkpoints/gonzalomoXLFluxPony_v20PonyDMD.safetensors",
- ModelArchitecture::FLUX_DEV, "Flux Dev", results);
-
- // Test SD3 models (if available)
- std::cout << "\n📋 Testing SD3 Models..." << std::endl;
- // Note: No SD3 models found in the current directory, but keeping the test structure
-
- // Test Qwen models (to verify the fix still works)
- std::cout << "\n📋 Testing Qwen Models (verifying fix)..." << std::endl;
- testModelArchitecture("/data/SD_MODELS/diffusion_models/Qwen-Image-Edit-2509-Q3_K_S.gguf",
- ModelArchitecture::QWEN2VL, "Qwen2-VL", results);
- testModelArchitecture("/data/SD_MODELS/diffusion_models/Qwen-Image-Pruning-13b-Q4_0.gguf",
- ModelArchitecture::QWEN2VL, "Qwen2-VL", results);
- testModelArchitecture("/data/SD_MODELS/diffusion_models/qwen-image-Q2_K.gguf",
- ModelArchitecture::QWEN2VL, "Qwen2-VL", results);
-
- // Summary by architecture
- std::map<ModelArchitecture, std::pair<int, int>> archStats; // {passed, total}
-
- std::cout << "\n📊 DETAILED RESULTS BY ARCHITECTURE:" << std::endl;
- std::cout << "======================================" << std::endl;
-
- for (const auto& result : results) {
- archStats[result.detectedEnum].first += result.passed ? 1 : 0;
- archStats[result.detectedEnum].second += 1;
-
- std::cout << "\n📁 Model: " << fs::path(result.modelPath).filename().string() << std::endl;
- std::cout << " Expected: " << result.expectedArchitecture << std::endl;
- std::cout << " Detected: " << result.detectedName << std::endl;
- std::cout << " Status: " << (result.passed ? "✅ PASS" : "❌ FAIL") << std::endl;
- if (!result.notes.empty()) {
- std::cout << " Notes: " << result.notes << std::endl;
- }
- }
-
- // Overall summary
- std::cout << "\n🎯 OVERALL TEST SUMMARY:" << std::endl;
- std::cout << "========================" << std::endl;
-
- int totalTests = 0;
- int totalPassed = 0;
-
- for (const auto& [arch, stats] : archStats) {
- std::string archName = ModelDetector::getArchitectureName(arch);
- int passed = stats.first;
- int total = stats.second;
-
- std::cout << archName << ": " << passed << "/" << total << " tests passed";
- if (passed == total && total > 0) {
- std::cout << " ✅";
- } else if (total > 0) {
- std::cout << " ❌";
- }
- std::cout << std::endl;
-
- totalTests += total;
- totalPassed += passed;
- }
-
- std::cout << "\n📈 FINAL RESULTS:" << std::endl;
- std::cout << " Total models tested: " << totalTests << std::endl;
- std::cout << " Successfully detected: " << totalPassed << std::endl;
- std::cout << " Success rate: " << (totalTests > 0 ? (totalPassed * 100 / totalTests) : 0) << "%" << std::endl;
-
- // Check for specific issues with the Qwen fix
- std::cout << "\n🔍 QWEN FIX IMPACT ANALYSIS:" << std::endl;
- std::cout << "=============================" << std::endl;
-
- bool qwenWorks = false;
- bool othersWork = true;
-
- for (const auto& result : results) {
- if (result.expectedArchitecture == "Qwen2-VL") {
- qwenWorks = result.passed;
- } else if (result.expectedArchitecture != "Unknown" && !result.passed) {
- othersWork = false;
- }
- }
-
- if (qwenWorks) {
- std::cout << "✅ Qwen detection fix is working correctly" << std::endl;
- } else {
- std::cout << "❌ Qwen detection fix is NOT working" << std::endl;
- }
-
- if (othersWork) {
- std::cout << "✅ Other model architectures are still working correctly" << std::endl;
- } else {
- std::cout << "❌ Some other model architectures are broken after the Qwen fix" << std::endl;
- }
-
- std::cout << "\n🏁 TEST COMPLETE!" << std::endl;
-
- return (totalPassed == totalTests && totalTests > 0) ? 0 : 1;
- }
|