stable_diffusion_wrapper.cpp 50 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203120412051206120712081209
  1. #include "stable_diffusion_wrapper.h"
  2. #include <algorithm>
  3. #include <chrono>
  4. #include <cstring>
  5. #include <filesystem>
  6. #include <thread>
  7. #include "logger.h"
  8. #include "model_detector.h"
  9. extern "C" {
  10. #include "stable-diffusion.h"
  11. }
  12. class StableDiffusionWrapper::Impl {
  13. public:
  14. sd_ctx_t* sdContext = nullptr;
  15. std::string lastError;
  16. std::mutex contextMutex;
  17. bool verbose = false;
  18. std::string currentModelPath;
  19. StableDiffusionWrapper::GenerationParams currentModelParams;
  20. Impl() {
  21. // Initialize any required resources
  22. }
  23. ~Impl() {
  24. unloadModel();
  25. }
  26. bool loadModel(const std::string& modelPath, const StableDiffusionWrapper::GenerationParams& params) {
  27. std::lock_guard<std::mutex> lock(contextMutex);
  28. // Store verbose flag for use in other functions
  29. verbose = params.verbose;
  30. // Unload any existing model
  31. if (sdContext) {
  32. free_sd_ctx(sdContext);
  33. sdContext = nullptr;
  34. }
  35. // Initialize context parameters
  36. sd_ctx_params_t ctxParams;
  37. sd_ctx_params_init(&ctxParams);
  38. ctxParams.free_params_immediately = false; // avoid segfault when reusing
  39. // Get absolute path for logging
  40. std::filesystem::path absModelPath = std::filesystem::absolute(modelPath);
  41. if (params.verbose) {
  42. LOG_DEBUG("Loading model from absolute path: " + std::filesystem::absolute(modelPath).string());
  43. }
  44. // Create persistent string copies to fix lifetime issues
  45. // These strings will remain valid for the entire lifetime of the context
  46. std::string persistentModelPath = modelPath;
  47. std::string persistentClipLPath = params.clipLPath;
  48. std::string persistentClipGPath = params.clipGPath;
  49. std::string persistentVaePath = params.vaePath;
  50. std::string persistentTaesdPath = params.taesdPath;
  51. std::string persistentControlNetPath = params.controlNetPath;
  52. std::string persistentLoraModelDir = params.loraModelDir;
  53. std::string persistentEmbeddingDir = params.embeddingDir;
  54. // Use folder-based path selection with enhanced logic
  55. bool useDiffusionModelPath = false;
  56. std::string detectionSource = "folder";
  57. // Check if model is in diffusion_models directory by examining the path
  58. std::filesystem::path modelFilePath(modelPath);
  59. std::filesystem::path parentDir = modelFilePath.parent_path();
  60. std::string parentDirName = parentDir.filename().string();
  61. std::string modelFileName = modelFilePath.filename().string();
  62. // Convert to lowercase for comparison
  63. std::transform(parentDirName.begin(), parentDirName.end(), parentDirName.begin(), ::tolower);
  64. std::transform(modelFileName.begin(), modelFileName.end(), modelFileName.begin(), ::tolower);
  65. // Variables for fallback detection
  66. ModelDetectionResult detectionResult;
  67. bool detectionSuccessful = false;
  68. bool isQwenModel = false;
  69. // Check if this is a Qwen model based on filename
  70. if (modelFileName.find("qwen") != std::string::npos) {
  71. isQwenModel = true;
  72. if (params.verbose) {
  73. LOG_DEBUG("Detected Qwen model from filename: " + modelFileName);
  74. }
  75. }
  76. // Enhanced path selection logic
  77. if (parentDirName == "diffusion_models" || parentDirName == "diffusion") {
  78. useDiffusionModelPath = true;
  79. if (params.verbose) {
  80. LOG_DEBUG("Model is in " + parentDirName + " directory, using diffusion_model_path");
  81. }
  82. } else if (parentDirName == "checkpoints" || parentDirName == "stable-diffusion") {
  83. useDiffusionModelPath = false;
  84. if (params.verbose) {
  85. LOG_DEBUG("Model is in " + parentDirName + " directory, using model_path");
  86. }
  87. } else if (parentDirName == "sd_models" || parentDirName.empty()) {
  88. // Handle models in root /data/SD_MODELS/ directory
  89. if (isQwenModel) {
  90. // Qwen models should use diffusion_model_path regardless of directory
  91. useDiffusionModelPath = true;
  92. detectionSource = "qwen_root_detection";
  93. if (params.verbose) {
  94. LOG_DEBUG("Qwen model in root directory, preferring diffusion_model_path");
  95. }
  96. } else {
  97. // For non-Qwen models in root, try architecture detection
  98. if (params.verbose) {
  99. LOG_DEBUG("Model is in root directory '" + parentDirName + "', attempting architecture detection");
  100. }
  101. detectionSource = "architecture_fallback";
  102. try {
  103. detectionResult = ModelDetector::detectModel(modelPath);
  104. detectionSuccessful = true;
  105. if (params.verbose) {
  106. LOG_DEBUG("Architecture detection found: " + detectionResult.architectureName);
  107. }
  108. } catch (const std::exception& e) {
  109. LOG_ERROR("Warning: Architecture detection failed: " + std::string(e.what()) + ". Using default loading method.");
  110. detectionResult.architecture = ModelArchitecture::UNKNOWN;
  111. detectionResult.architectureName = "Unknown";
  112. }
  113. if (detectionSuccessful) {
  114. switch (detectionResult.architecture) {
  115. case ModelArchitecture::FLUX_SCHNELL:
  116. case ModelArchitecture::FLUX_DEV:
  117. case ModelArchitecture::FLUX_CHROMA:
  118. case ModelArchitecture::SD_3:
  119. case ModelArchitecture::QWEN2VL:
  120. // Modern architectures use diffusion_model_path
  121. useDiffusionModelPath = true;
  122. break;
  123. case ModelArchitecture::SD_1_5:
  124. case ModelArchitecture::SD_2_1:
  125. case ModelArchitecture::SDXL_BASE:
  126. case ModelArchitecture::SDXL_REFINER:
  127. // Traditional SD models use model_path
  128. useDiffusionModelPath = false;
  129. break;
  130. case ModelArchitecture::UNKNOWN:
  131. default:
  132. // Unknown architectures fall back to model_path for backward compatibility
  133. useDiffusionModelPath = false;
  134. if (params.verbose) {
  135. LOG_WARNING("Warning: Unknown model architecture detected, using default model_path for backward compatibility");
  136. }
  137. break;
  138. }
  139. } else {
  140. useDiffusionModelPath = false; // Default fallback
  141. detectionSource = "default_fallback";
  142. }
  143. }
  144. } else {
  145. // Unknown directory - try architecture detection
  146. if (params.verbose) {
  147. LOG_DEBUG("Model is in unknown directory '" + parentDirName + "', attempting architecture detection as fallback");
  148. }
  149. detectionSource = "architecture_fallback";
  150. try {
  151. detectionResult = ModelDetector::detectModel(modelPath);
  152. detectionSuccessful = true;
  153. if (params.verbose) {
  154. LOG_DEBUG("Fallback detection found architecture: " + detectionResult.architectureName);
  155. }
  156. } catch (const std::exception& e) {
  157. LOG_ERROR("Warning: Fallback model detection failed: " + std::string(e.what()) + ". Using default loading method.");
  158. detectionResult.architecture = ModelArchitecture::UNKNOWN;
  159. detectionResult.architectureName = "Unknown";
  160. }
  161. if (detectionSuccessful) {
  162. switch (detectionResult.architecture) {
  163. case ModelArchitecture::FLUX_SCHNELL:
  164. case ModelArchitecture::FLUX_DEV:
  165. case ModelArchitecture::FLUX_CHROMA:
  166. case ModelArchitecture::SD_3:
  167. case ModelArchitecture::QWEN2VL:
  168. // Modern architectures use diffusion_model_path
  169. useDiffusionModelPath = true;
  170. break;
  171. case ModelArchitecture::SD_1_5:
  172. case ModelArchitecture::SD_2_1:
  173. case ModelArchitecture::SDXL_BASE:
  174. case ModelArchitecture::SDXL_REFINER:
  175. // Traditional SD models use model_path
  176. useDiffusionModelPath = false;
  177. break;
  178. case ModelArchitecture::UNKNOWN:
  179. default:
  180. // Unknown architectures fall back to model_path for backward compatibility
  181. useDiffusionModelPath = false;
  182. if (params.verbose) {
  183. LOG_WARNING("Warning: Unknown model architecture detected, using default model_path for backward compatibility");
  184. }
  185. break;
  186. }
  187. } else {
  188. useDiffusionModelPath = false; // Default fallback
  189. detectionSource = "default_fallback";
  190. }
  191. }
  192. // Set the appropriate model path based on folder location or fallback detection
  193. if (useDiffusionModelPath) {
  194. ctxParams.diffusion_model_path = persistentModelPath.c_str();
  195. ctxParams.model_path = nullptr; // Clear the traditional path
  196. if (params.verbose) {
  197. LOG_DEBUG("Using diffusion_model_path (source: " + detectionSource + ")");
  198. }
  199. } else {
  200. ctxParams.model_path = persistentModelPath.c_str();
  201. ctxParams.diffusion_model_path = nullptr; // Clear the modern path
  202. if (params.verbose) {
  203. LOG_DEBUG("Using model_path (source: " + detectionSource + ")");
  204. }
  205. }
  206. // Set optional model paths using persistent strings to fix lifetime issues
  207. if (!persistentClipLPath.empty()) {
  208. ctxParams.clip_l_path = persistentClipLPath.c_str();
  209. if (params.verbose) {
  210. LOG_DEBUG("Using CLIP-L path: " + std::filesystem::absolute(persistentClipLPath).string());
  211. }
  212. }
  213. if (!persistentClipGPath.empty()) {
  214. ctxParams.clip_g_path = persistentClipGPath.c_str();
  215. if (params.verbose) {
  216. LOG_DEBUG("Using CLIP-G path: " + std::filesystem::absolute(persistentClipGPath).string());
  217. }
  218. }
  219. if (!persistentVaePath.empty()) {
  220. // Check if VAE file exists before setting it
  221. if (std::filesystem::exists(persistentVaePath)) {
  222. ctxParams.vae_path = persistentVaePath.c_str();
  223. if (params.verbose) {
  224. LOG_DEBUG("Using VAE path: " + std::filesystem::absolute(persistentVaePath).string());
  225. }
  226. } else {
  227. if (params.verbose) {
  228. LOG_DEBUG("VAE file not found: " + std::filesystem::absolute(persistentVaePath).string() + " - continuing without VAE");
  229. }
  230. ctxParams.vae_path = nullptr;
  231. }
  232. }
  233. if (!persistentTaesdPath.empty()) {
  234. ctxParams.taesd_path = persistentTaesdPath.c_str();
  235. if (params.verbose) {
  236. LOG_DEBUG("Using TAESD path: " + std::filesystem::absolute(persistentTaesdPath).string());
  237. }
  238. }
  239. if (!persistentControlNetPath.empty()) {
  240. ctxParams.control_net_path = persistentControlNetPath.c_str();
  241. if (params.verbose) {
  242. LOG_DEBUG("Using ControlNet path: " + std::filesystem::absolute(persistentControlNetPath).string());
  243. }
  244. }
  245. if (!persistentLoraModelDir.empty()) {
  246. ctxParams.lora_model_dir = persistentLoraModelDir.c_str();
  247. if (params.verbose) {
  248. LOG_DEBUG("Using LoRA model directory: " + std::filesystem::absolute(persistentLoraModelDir).string());
  249. }
  250. }
  251. if (!persistentEmbeddingDir.empty()) {
  252. ctxParams.embedding_dir = persistentEmbeddingDir.c_str();
  253. if (params.verbose) {
  254. LOG_DEBUG("Using embedding directory: " + std::filesystem::absolute(persistentEmbeddingDir).string());
  255. }
  256. }
  257. // Set performance parameters
  258. ctxParams.n_threads = params.nThreads;
  259. ctxParams.offload_params_to_cpu = params.offloadParamsToCpu;
  260. ctxParams.keep_clip_on_cpu = params.clipOnCpu;
  261. ctxParams.keep_vae_on_cpu = params.vaeOnCpu;
  262. ctxParams.diffusion_flash_attn = params.diffusionFlashAttn;
  263. ctxParams.diffusion_conv_direct = params.diffusionConvDirect;
  264. ctxParams.vae_conv_direct = params.vaeConvDirect;
  265. // Set model type
  266. ctxParams.wtype = StableDiffusionWrapper::stringToModelType(params.modelType);
  267. // Create the stable-diffusion context
  268. if (params.verbose) {
  269. LOG_DEBUG("Attempting to create stable-diffusion context with selected parameters...");
  270. }
  271. sdContext = new_sd_ctx(&ctxParams);
  272. if (!sdContext) {
  273. lastError = "Failed to create stable-diffusion context";
  274. LOG_ERROR("Error: " + lastError + " with initial attempt");
  275. // If we used diffusion_model_path and it failed, try fallback to model_path
  276. if (useDiffusionModelPath) {
  277. if (params.verbose) {
  278. LOG_WARNING("Warning: Failed to load with diffusion_model_path. Attempting fallback to model_path...");
  279. }
  280. // Re-initialize context parameters
  281. sd_ctx_params_init(&ctxParams);
  282. // Set fallback model path using persistent string
  283. ctxParams.model_path = persistentModelPath.c_str();
  284. ctxParams.diffusion_model_path = nullptr;
  285. // Re-apply other parameters using persistent strings
  286. if (!persistentClipLPath.empty()) {
  287. ctxParams.clip_l_path = persistentClipLPath.c_str();
  288. }
  289. if (!persistentClipGPath.empty()) {
  290. ctxParams.clip_g_path = persistentClipGPath.c_str();
  291. }
  292. if (!persistentVaePath.empty()) {
  293. // Check if VAE file exists before setting it
  294. if (std::filesystem::exists(persistentVaePath)) {
  295. ctxParams.vae_path = persistentVaePath.c_str();
  296. } else {
  297. ctxParams.vae_path = nullptr;
  298. }
  299. }
  300. if (!persistentTaesdPath.empty()) {
  301. ctxParams.taesd_path = persistentTaesdPath.c_str();
  302. }
  303. if (!persistentControlNetPath.empty()) {
  304. ctxParams.control_net_path = persistentControlNetPath.c_str();
  305. }
  306. if (!persistentLoraModelDir.empty()) {
  307. ctxParams.lora_model_dir = persistentLoraModelDir.c_str();
  308. }
  309. if (!persistentEmbeddingDir.empty()) {
  310. ctxParams.embedding_dir = persistentEmbeddingDir.c_str();
  311. }
  312. // Re-apply performance parameters
  313. ctxParams.n_threads = params.nThreads;
  314. ctxParams.offload_params_to_cpu = params.offloadParamsToCpu;
  315. ctxParams.keep_clip_on_cpu = params.clipOnCpu;
  316. ctxParams.keep_vae_on_cpu = params.vaeOnCpu;
  317. ctxParams.diffusion_flash_attn = params.diffusionFlashAttn;
  318. ctxParams.diffusion_conv_direct = params.diffusionConvDirect;
  319. ctxParams.vae_conv_direct = params.vaeConvDirect;
  320. // Re-apply model type
  321. ctxParams.wtype = StableDiffusionWrapper::stringToModelType(params.modelType);
  322. if (params.verbose) {
  323. LOG_DEBUG("Attempting to create context with fallback model_path...");
  324. }
  325. // Try creating context again with fallback
  326. sdContext = new_sd_ctx(&ctxParams);
  327. if (!sdContext) {
  328. lastError = "Failed to create stable-diffusion context with both diffusion_model_path and model_path fallback";
  329. LOG_ERROR("Error: " + lastError);
  330. // Additional fallback: try with minimal parameters for GGUF models
  331. if (modelFileName.find(".gguf") != std::string::npos || modelFileName.find(".ggml") != std::string::npos) {
  332. if (params.verbose) {
  333. LOG_DEBUG("Detected GGUF/GGML model, attempting minimal parameter fallback...");
  334. }
  335. // Re-initialize with minimal parameters
  336. sd_ctx_params_init(&ctxParams);
  337. ctxParams.model_path = persistentModelPath.c_str();
  338. ctxParams.diffusion_model_path = nullptr;
  339. // Set only essential parameters for GGUF
  340. ctxParams.n_threads = params.nThreads;
  341. ctxParams.wtype = StableDiffusionWrapper::stringToModelType(params.modelType);
  342. if (params.verbose) {
  343. LOG_DEBUG("Attempting to create context with minimal GGUF parameters...");
  344. }
  345. sdContext = new_sd_ctx(&ctxParams);
  346. if (!sdContext) {
  347. lastError = "Failed to create stable-diffusion context even with minimal GGUF parameters";
  348. LOG_ERROR("Error: " + lastError);
  349. return false;
  350. }
  351. if (params.verbose) {
  352. LOG_DEBUG("Successfully loaded GGUF model with minimal parameters: " + absModelPath.string());
  353. }
  354. } else {
  355. return false;
  356. }
  357. } else {
  358. if (params.verbose) {
  359. LOG_DEBUG("Successfully loaded model with fallback to model_path: " + absModelPath.string());
  360. }
  361. }
  362. } else {
  363. // Try minimal fallback for non-diffusion_model_path failures
  364. if (modelFileName.find(".gguf") != std::string::npos || modelFileName.find(".ggml") != std::string::npos) {
  365. if (params.verbose) {
  366. LOG_DEBUG("Detected GGUF/GGML model, attempting minimal parameter fallback...");
  367. }
  368. // Re-initialize with minimal parameters
  369. sd_ctx_params_init(&ctxParams);
  370. ctxParams.model_path = persistentModelPath.c_str();
  371. // Set only essential parameters for GGUF
  372. ctxParams.n_threads = params.nThreads;
  373. ctxParams.wtype = StableDiffusionWrapper::stringToModelType(params.modelType);
  374. if (params.verbose) {
  375. LOG_DEBUG("Attempting to create context with minimal GGUF parameters...");
  376. }
  377. sdContext = new_sd_ctx(&ctxParams);
  378. if (!sdContext) {
  379. lastError = "Failed to create stable-diffusion context even with minimal GGUF parameters";
  380. LOG_ERROR("Error: " + lastError);
  381. return false;
  382. }
  383. if (params.verbose) {
  384. LOG_DEBUG("Successfully loaded GGUF model with minimal parameters: " + absModelPath.string());
  385. }
  386. } else {
  387. LOG_ERROR("Error: " + lastError);
  388. return false;
  389. }
  390. }
  391. }
  392. // Log successful loading with detection information
  393. if (params.verbose) {
  394. LOG_DEBUG("Successfully loaded model: " + absModelPath.string());
  395. LOG_DEBUG(" Detection source: " + detectionSource);
  396. LOG_DEBUG(" Loading method: " + std::string(useDiffusionModelPath ? "diffusion_model_path" : "model_path"));
  397. LOG_DEBUG(" Parent directory: " + parentDirName);
  398. LOG_DEBUG(" Model filename: " + modelFileName);
  399. }
  400. // Log additional model properties if architecture detection was performed
  401. if (detectionSuccessful && params.verbose) {
  402. LOG_DEBUG(" Architecture: " + detectionResult.architectureName);
  403. if (detectionResult.textEncoderDim > 0) {
  404. LOG_DEBUG(" Text encoder dimension: " + std::to_string(detectionResult.textEncoderDim));
  405. }
  406. if (detectionResult.needsVAE) {
  407. LOG_DEBUG(" Requires VAE: " + (detectionResult.recommendedVAE.empty() ? std::string("Yes") : detectionResult.recommendedVAE));
  408. }
  409. }
  410. // Store current model info for potential reload after upscaling
  411. currentModelPath = modelPath;
  412. currentModelParams = params;
  413. return true;
  414. }
  415. void unloadModel() {
  416. std::lock_guard<std::mutex> lock(contextMutex);
  417. if (sdContext) {
  418. free_sd_ctx(sdContext);
  419. sdContext = nullptr;
  420. if (verbose) {
  421. LOG_DEBUG("Unloaded stable-diffusion model");
  422. }
  423. }
  424. // Clear stored model info
  425. currentModelPath.clear();
  426. currentModelParams = StableDiffusionWrapper::GenerationParams();
  427. }
  428. bool isModelLoaded() const {
  429. return sdContext != nullptr;
  430. }
  431. std::vector<StableDiffusionWrapper::GeneratedImage> generateImage(
  432. const StableDiffusionWrapper::GenerationParams& params,
  433. StableDiffusionWrapper::ProgressCallback progressCallback,
  434. void* userData) {
  435. std::vector<StableDiffusionWrapper::GeneratedImage> results;
  436. if (!sdContext) {
  437. lastError = "No model loaded";
  438. return results;
  439. }
  440. auto startTime = std::chrono::high_resolution_clock::now();
  441. // Initialize generation parameters
  442. sd_img_gen_params_t genParams;
  443. sd_img_gen_params_init(&genParams);
  444. // Set basic parameters
  445. genParams.prompt = params.prompt.c_str();
  446. genParams.negative_prompt = params.negativePrompt.c_str();
  447. genParams.width = params.width;
  448. genParams.height = params.height;
  449. genParams.sample_params.sample_steps = params.steps;
  450. genParams.seed = params.seed;
  451. genParams.batch_count = params.batchCount;
  452. // Set sampling parameters
  453. genParams.sample_params.sample_method = StableDiffusionWrapper::stringToSamplingMethod(params.samplingMethod);
  454. genParams.sample_params.scheduler = StableDiffusionWrapper::stringToScheduler(params.scheduler);
  455. genParams.sample_params.guidance.txt_cfg = params.cfgScale;
  456. // Set advanced parameters
  457. genParams.clip_skip = params.clipSkip;
  458. genParams.strength = params.strength;
  459. // Set progress callback if provided
  460. // Track callback data to ensure proper cleanup
  461. std::pair<StableDiffusionWrapper::ProgressCallback, void*>* callbackData = nullptr;
  462. if (progressCallback) {
  463. callbackData = new std::pair<StableDiffusionWrapper::ProgressCallback, void*>(progressCallback, userData);
  464. sd_set_progress_callback([](int step, int steps, float time, void* data) {
  465. auto* callbackData = static_cast<std::pair<StableDiffusionWrapper::ProgressCallback, void*>*>(data);
  466. if (callbackData) {
  467. callbackData->first(step, steps, time, callbackData->second);
  468. }
  469. },
  470. callbackData);
  471. }
  472. // Generate the image
  473. LOG_DEBUG("[TIMING_ANALYSIS] Starting generate_image() call");
  474. auto generationCallStart = std::chrono::high_resolution_clock::now();
  475. sd_image_t* sdImages = generate_image(sdContext, &genParams);
  476. auto generationCallEnd = std::chrono::high_resolution_clock::now();
  477. auto generationCallTime = std::chrono::duration_cast<std::chrono::milliseconds>(generationCallEnd - generationCallStart).count();
  478. LOG_DEBUG("[TIMING_ANALYSIS] generate_image() call completed in " + std::to_string(generationCallTime) + "ms");
  479. // Clear and clean up progress callback - FIX: Wait for any pending callbacks
  480. sd_set_progress_callback(nullptr, nullptr);
  481. // Add a small delay to ensure any in-flight callbacks complete before cleanup
  482. std::this_thread::sleep_for(std::chrono::milliseconds(10));
  483. if (callbackData) {
  484. delete callbackData;
  485. callbackData = nullptr;
  486. }
  487. auto endTime = std::chrono::high_resolution_clock::now();
  488. auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(endTime - startTime);
  489. if (!sdImages) {
  490. lastError = "Failed to generate image";
  491. return results;
  492. }
  493. // Convert stable-diffusion images to our format
  494. for (int i = 0; i < params.batchCount; i++) {
  495. StableDiffusionWrapper::GeneratedImage image;
  496. image.width = sdImages[i].width;
  497. image.height = sdImages[i].height;
  498. image.channels = sdImages[i].channel;
  499. image.seed = params.seed;
  500. image.generationTime = duration.count();
  501. // Copy image data
  502. if (sdImages[i].data && sdImages[i].width > 0 && sdImages[i].height > 0 && sdImages[i].channel > 0) {
  503. size_t dataSize = sdImages[i].width * sdImages[i].height * sdImages[i].channel;
  504. image.data.resize(dataSize);
  505. std::memcpy(image.data.data(), sdImages[i].data, dataSize);
  506. }
  507. results.push_back(image);
  508. }
  509. // Free the generated images
  510. // Clean up each image's data array
  511. for (int i = 0; i < params.batchCount; i++) {
  512. if (sdImages[i].data) {
  513. free(sdImages[i].data);
  514. sdImages[i].data = nullptr;
  515. }
  516. }
  517. // Free the image array itself
  518. free(sdImages);
  519. return results;
  520. }
  521. std::vector<StableDiffusionWrapper::GeneratedImage> generateImageImg2Img(
  522. const StableDiffusionWrapper::GenerationParams& params,
  523. const std::vector<uint8_t>& inputData,
  524. int inputWidth,
  525. int inputHeight,
  526. StableDiffusionWrapper::ProgressCallback progressCallback,
  527. void* userData) {
  528. std::vector<StableDiffusionWrapper::GeneratedImage> results;
  529. if (!sdContext) {
  530. lastError = "No model loaded";
  531. return results;
  532. }
  533. auto startTime = std::chrono::high_resolution_clock::now();
  534. // Initialize generation parameters
  535. sd_img_gen_params_t genParams;
  536. sd_img_gen_params_init(&genParams);
  537. // Set basic parameters
  538. genParams.prompt = params.prompt.c_str();
  539. genParams.negative_prompt = params.negativePrompt.c_str();
  540. genParams.width = params.width;
  541. genParams.height = params.height;
  542. genParams.sample_params.sample_steps = params.steps;
  543. genParams.seed = params.seed;
  544. genParams.batch_count = params.batchCount;
  545. genParams.strength = params.strength;
  546. // Set sampling parameters
  547. genParams.sample_params.sample_method = StableDiffusionWrapper::stringToSamplingMethod(params.samplingMethod);
  548. genParams.sample_params.scheduler = StableDiffusionWrapper::stringToScheduler(params.scheduler);
  549. genParams.sample_params.guidance.txt_cfg = params.cfgScale;
  550. // Set advanced parameters
  551. genParams.clip_skip = params.clipSkip;
  552. // Set input image
  553. sd_image_t initImage;
  554. initImage.width = inputWidth;
  555. initImage.height = inputHeight;
  556. initImage.channel = 3; // RGB
  557. initImage.data = const_cast<uint8_t*>(inputData.data());
  558. genParams.init_image = initImage;
  559. // Set progress callback if provided
  560. // Track callback data to ensure proper cleanup
  561. std::pair<StableDiffusionWrapper::ProgressCallback, void*>* callbackData = nullptr;
  562. if (progressCallback) {
  563. callbackData = new std::pair<StableDiffusionWrapper::ProgressCallback, void*>(progressCallback, userData);
  564. sd_set_progress_callback([](int step, int steps, float time, void* data) {
  565. auto* callbackData = static_cast<std::pair<StableDiffusionWrapper::ProgressCallback, void*>*>(data);
  566. if (callbackData) {
  567. callbackData->first(step, steps, time, callbackData->second);
  568. }
  569. },
  570. callbackData);
  571. }
  572. // Generate the image
  573. sd_image_t* sdImages = generate_image(sdContext, &genParams);
  574. // Clear and clean up progress callback - FIX: Wait for any pending callbacks
  575. sd_set_progress_callback(nullptr, nullptr);
  576. // Add a small delay to ensure any in-flight callbacks complete before cleanup
  577. std::this_thread::sleep_for(std::chrono::milliseconds(10));
  578. if (callbackData) {
  579. delete callbackData;
  580. callbackData = nullptr;
  581. }
  582. auto endTime = std::chrono::high_resolution_clock::now();
  583. auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(endTime - startTime);
  584. if (!sdImages) {
  585. lastError = "Failed to generate image";
  586. return results;
  587. }
  588. // Convert stable-diffusion images to our format
  589. for (int i = 0; i < params.batchCount; i++) {
  590. StableDiffusionWrapper::GeneratedImage image;
  591. image.width = sdImages[i].width;
  592. image.height = sdImages[i].height;
  593. image.channels = sdImages[i].channel;
  594. image.seed = params.seed;
  595. image.generationTime = duration.count();
  596. // Copy image data
  597. if (sdImages[i].data && sdImages[i].width > 0 && sdImages[i].height > 0 && sdImages[i].channel > 0) {
  598. size_t dataSize = sdImages[i].width * sdImages[i].height * sdImages[i].channel;
  599. image.data.resize(dataSize);
  600. std::memcpy(image.data.data(), sdImages[i].data, dataSize);
  601. }
  602. results.push_back(image);
  603. }
  604. // Free the generated images
  605. // Clean up each image's data array
  606. for (int i = 0; i < params.batchCount; i++) {
  607. if (sdImages[i].data) {
  608. free(sdImages[i].data);
  609. sdImages[i].data = nullptr;
  610. }
  611. }
  612. // Free the image array itself
  613. free(sdImages);
  614. return results;
  615. }
  616. std::vector<StableDiffusionWrapper::GeneratedImage> generateImageControlNet(
  617. const StableDiffusionWrapper::GenerationParams& params,
  618. const std::vector<uint8_t>& controlData,
  619. int controlWidth,
  620. int controlHeight,
  621. StableDiffusionWrapper::ProgressCallback progressCallback,
  622. void* userData) {
  623. std::vector<StableDiffusionWrapper::GeneratedImage> results;
  624. if (!sdContext) {
  625. lastError = "No model loaded";
  626. return results;
  627. }
  628. auto startTime = std::chrono::high_resolution_clock::now();
  629. // Initialize generation parameters
  630. sd_img_gen_params_t genParams;
  631. sd_img_gen_params_init(&genParams);
  632. // Set basic parameters
  633. genParams.prompt = params.prompt.c_str();
  634. genParams.negative_prompt = params.negativePrompt.c_str();
  635. genParams.width = params.width;
  636. genParams.height = params.height;
  637. genParams.sample_params.sample_steps = params.steps;
  638. genParams.seed = params.seed;
  639. genParams.batch_count = params.batchCount;
  640. genParams.control_strength = params.controlStrength;
  641. // Set sampling parameters
  642. genParams.sample_params.sample_method = StableDiffusionWrapper::stringToSamplingMethod(params.samplingMethod);
  643. genParams.sample_params.scheduler = StableDiffusionWrapper::stringToScheduler(params.scheduler);
  644. genParams.sample_params.guidance.txt_cfg = params.cfgScale;
  645. // Set advanced parameters
  646. genParams.clip_skip = params.clipSkip;
  647. // Set control image
  648. sd_image_t controlImage;
  649. controlImage.width = controlWidth;
  650. controlImage.height = controlHeight;
  651. controlImage.channel = 3; // RGB
  652. controlImage.data = const_cast<uint8_t*>(controlData.data());
  653. genParams.control_image = controlImage;
  654. // Set progress callback if provided
  655. // Track callback data to ensure proper cleanup
  656. std::pair<StableDiffusionWrapper::ProgressCallback, void*>* callbackData = nullptr;
  657. if (progressCallback) {
  658. callbackData = new std::pair<StableDiffusionWrapper::ProgressCallback, void*>(progressCallback, userData);
  659. sd_set_progress_callback([](int step, int steps, float time, void* data) {
  660. auto* callbackData = static_cast<std::pair<StableDiffusionWrapper::ProgressCallback, void*>*>(data);
  661. if (callbackData) {
  662. callbackData->first(step, steps, time, callbackData->second);
  663. }
  664. },
  665. callbackData);
  666. }
  667. // Generate the image
  668. sd_image_t* sdImages = generate_image(sdContext, &genParams);
  669. // Clear and clean up progress callback - FIX: Wait for any pending callbacks
  670. sd_set_progress_callback(nullptr, nullptr);
  671. // Add a small delay to ensure any in-flight callbacks complete before cleanup
  672. std::this_thread::sleep_for(std::chrono::milliseconds(10));
  673. if (callbackData) {
  674. delete callbackData;
  675. callbackData = nullptr;
  676. }
  677. auto endTime = std::chrono::high_resolution_clock::now();
  678. auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(endTime - startTime);
  679. if (!sdImages) {
  680. lastError = "Failed to generate image";
  681. return results;
  682. }
  683. // Convert stable-diffusion images to our format
  684. for (int i = 0; i < params.batchCount; i++) {
  685. StableDiffusionWrapper::GeneratedImage image;
  686. image.width = sdImages[i].width;
  687. image.height = sdImages[i].height;
  688. image.channels = sdImages[i].channel;
  689. image.seed = params.seed;
  690. image.generationTime = duration.count();
  691. // Copy image data
  692. if (sdImages[i].data && sdImages[i].width > 0 && sdImages[i].height > 0 && sdImages[i].channel > 0) {
  693. size_t dataSize = sdImages[i].width * sdImages[i].height * sdImages[i].channel;
  694. image.data.resize(dataSize);
  695. std::memcpy(image.data.data(), sdImages[i].data, dataSize);
  696. }
  697. results.push_back(image);
  698. }
  699. // Free the generated images
  700. // Clean up each image's data array
  701. for (int i = 0; i < params.batchCount; i++) {
  702. if (sdImages[i].data) {
  703. free(sdImages[i].data);
  704. sdImages[i].data = nullptr;
  705. }
  706. }
  707. // Free the image array itself
  708. free(sdImages);
  709. return results;
  710. }
  711. std::vector<StableDiffusionWrapper::GeneratedImage> generateImageInpainting(
  712. const StableDiffusionWrapper::GenerationParams& params,
  713. const std::vector<uint8_t>& inputData,
  714. int inputWidth,
  715. int inputHeight,
  716. const std::vector<uint8_t>& maskData,
  717. int maskWidth,
  718. int maskHeight,
  719. StableDiffusionWrapper::ProgressCallback progressCallback,
  720. void* userData) {
  721. std::vector<StableDiffusionWrapper::GeneratedImage> results;
  722. if (!sdContext) {
  723. lastError = "No model loaded";
  724. return results;
  725. }
  726. auto startTime = std::chrono::high_resolution_clock::now();
  727. // Initialize generation parameters
  728. sd_img_gen_params_t genParams;
  729. sd_img_gen_params_init(&genParams);
  730. // Set basic parameters
  731. genParams.prompt = params.prompt.c_str();
  732. genParams.negative_prompt = params.negativePrompt.c_str();
  733. genParams.width = params.width;
  734. genParams.height = params.height;
  735. genParams.sample_params.sample_steps = params.steps;
  736. genParams.seed = params.seed;
  737. genParams.batch_count = params.batchCount;
  738. genParams.strength = params.strength;
  739. // Set sampling parameters
  740. genParams.sample_params.sample_method = StableDiffusionWrapper::stringToSamplingMethod(params.samplingMethod);
  741. genParams.sample_params.scheduler = StableDiffusionWrapper::stringToScheduler(params.scheduler);
  742. genParams.sample_params.guidance.txt_cfg = params.cfgScale;
  743. // Set advanced parameters
  744. genParams.clip_skip = params.clipSkip;
  745. // Set input image
  746. sd_image_t initImage;
  747. initImage.width = inputWidth;
  748. initImage.height = inputHeight;
  749. initImage.channel = 3; // RGB
  750. initImage.data = const_cast<uint8_t*>(inputData.data());
  751. genParams.init_image = initImage;
  752. // Set mask image
  753. sd_image_t maskImage;
  754. maskImage.width = maskWidth;
  755. maskImage.height = maskHeight;
  756. maskImage.channel = 1; // Grayscale mask
  757. maskImage.data = const_cast<uint8_t*>(maskData.data());
  758. genParams.mask_image = maskImage;
  759. // Set progress callback if provided
  760. // Track callback data to ensure proper cleanup
  761. std::pair<StableDiffusionWrapper::ProgressCallback, void*>* callbackData = nullptr;
  762. if (progressCallback) {
  763. callbackData = new std::pair<StableDiffusionWrapper::ProgressCallback, void*>(progressCallback, userData);
  764. sd_set_progress_callback([](int step, int steps, float time, void* data) {
  765. auto* callbackData = static_cast<std::pair<StableDiffusionWrapper::ProgressCallback, void*>*>(data);
  766. if (callbackData) {
  767. callbackData->first(step, steps, time, callbackData->second);
  768. }
  769. },
  770. callbackData);
  771. }
  772. // Generate the image
  773. sd_image_t* sdImages = generate_image(sdContext, &genParams);
  774. // Clear and clean up progress callback - FIX: Wait for any pending callbacks
  775. sd_set_progress_callback(nullptr, nullptr);
  776. // Add a small delay to ensure any in-flight callbacks complete before cleanup
  777. std::this_thread::sleep_for(std::chrono::milliseconds(10));
  778. if (callbackData) {
  779. delete callbackData;
  780. callbackData = nullptr;
  781. }
  782. auto endTime = std::chrono::high_resolution_clock::now();
  783. auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(endTime - startTime);
  784. if (!sdImages) {
  785. lastError = "Failed to generate image";
  786. return results;
  787. }
  788. // Convert stable-diffusion images to our format
  789. for (int i = 0; i < params.batchCount; i++) {
  790. StableDiffusionWrapper::GeneratedImage image;
  791. image.width = sdImages[i].width;
  792. image.height = sdImages[i].height;
  793. image.channels = sdImages[i].channel;
  794. image.seed = params.seed;
  795. image.generationTime = duration.count();
  796. // Copy image data
  797. if (sdImages[i].data && sdImages[i].width > 0 && sdImages[i].height > 0 && sdImages[i].channel > 0) {
  798. size_t dataSize = sdImages[i].width * sdImages[i].height * sdImages[i].channel;
  799. image.data.resize(dataSize);
  800. std::memcpy(image.data.data(), sdImages[i].data, dataSize);
  801. }
  802. results.push_back(image);
  803. }
  804. // Free the generated images
  805. // Clean up each image's data array
  806. for (int i = 0; i < params.batchCount; i++) {
  807. if (sdImages[i].data) {
  808. free(sdImages[i].data);
  809. sdImages[i].data = nullptr;
  810. }
  811. }
  812. // Free the image array itself
  813. free(sdImages);
  814. return results;
  815. }
  816. StableDiffusionWrapper::GeneratedImage upscaleImage(
  817. const std::string& esrganPath,
  818. const std::vector<uint8_t>& inputData,
  819. int inputWidth,
  820. int inputHeight,
  821. int inputChannels,
  822. uint32_t upscaleFactor,
  823. int nThreads,
  824. bool offloadParamsToCpu,
  825. bool direct) {
  826. StableDiffusionWrapper::GeneratedImage result;
  827. auto startTime = std::chrono::high_resolution_clock::now();
  828. // Unload stable diffusion checkpoint before loading upscaler to prevent memory conflicts
  829. {
  830. std::lock_guard<std::mutex> lock(contextMutex);
  831. if (sdContext) {
  832. if (verbose) {
  833. LOG_DEBUG("Unloading stable diffusion checkpoint before loading upscaler model");
  834. }
  835. free_sd_ctx(sdContext);
  836. sdContext = nullptr;
  837. }
  838. }
  839. // Create upscaler context
  840. upscaler_ctx_t* upscalerCtx = new_upscaler_ctx(
  841. esrganPath.c_str(),
  842. offloadParamsToCpu,
  843. direct,
  844. nThreads);
  845. if (!upscalerCtx) {
  846. lastError = "Failed to create upscaler context";
  847. return result;
  848. }
  849. // Prepare input image
  850. sd_image_t inputImage;
  851. inputImage.width = inputWidth;
  852. inputImage.height = inputHeight;
  853. inputImage.channel = inputChannels;
  854. inputImage.data = const_cast<uint8_t*>(inputData.data());
  855. // Perform upscaling
  856. sd_image_t upscaled = upscale(upscalerCtx, inputImage, upscaleFactor);
  857. auto endTime = std::chrono::high_resolution_clock::now();
  858. auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(endTime - startTime);
  859. if (!upscaled.data) {
  860. lastError = "Failed to upscale image";
  861. free_upscaler_ctx(upscalerCtx);
  862. return result;
  863. }
  864. // Convert to our format
  865. result.width = upscaled.width;
  866. result.height = upscaled.height;
  867. result.channels = upscaled.channel;
  868. result.seed = 0; // No seed for upscaling
  869. result.generationTime = duration.count();
  870. // Copy image data
  871. if (upscaled.data && upscaled.width > 0 && upscaled.height > 0 && upscaled.channel > 0) {
  872. size_t dataSize = upscaled.width * upscaled.height * upscaled.channel;
  873. result.data.resize(dataSize);
  874. std::memcpy(result.data.data(), upscaled.data, dataSize);
  875. }
  876. // Clean up
  877. free_upscaler_ctx(upscalerCtx);
  878. return result;
  879. }
  880. std::string getLastError() const {
  881. return lastError;
  882. }
  883. };
  884. // Static helper functions
  885. sample_method_t StableDiffusionWrapper::stringToSamplingMethod(const std::string& method) {
  886. std::string lowerMethod = method;
  887. std::transform(lowerMethod.begin(), lowerMethod.end(), lowerMethod.begin(), ::tolower);
  888. if (lowerMethod == "euler") {
  889. return EULER;
  890. } else if (lowerMethod == "euler_a") {
  891. return EULER_A;
  892. } else if (lowerMethod == "heun") {
  893. return HEUN;
  894. } else if (lowerMethod == "dpm2") {
  895. return DPM2;
  896. } else if (lowerMethod == "dpmpp2s_a") {
  897. return DPMPP2S_A;
  898. } else if (lowerMethod == "dpmpp2m") {
  899. return DPMPP2M;
  900. } else if (lowerMethod == "dpmpp2mv2") {
  901. return DPMPP2Mv2;
  902. } else if (lowerMethod == "ipndm") {
  903. return IPNDM;
  904. } else if (lowerMethod == "ipndm_v") {
  905. return IPNDM_V;
  906. } else if (lowerMethod == "lcm") {
  907. return LCM;
  908. } else if (lowerMethod == "ddim_trailing") {
  909. return DDIM_TRAILING;
  910. } else if (lowerMethod == "tcd") {
  911. return TCD;
  912. } else {
  913. return SAMPLE_METHOD_DEFAULT;
  914. }
  915. }
  916. scheduler_t StableDiffusionWrapper::stringToScheduler(const std::string& scheduler) {
  917. std::string lowerScheduler = scheduler;
  918. std::transform(lowerScheduler.begin(), lowerScheduler.end(), lowerScheduler.begin(), ::tolower);
  919. if (lowerScheduler == "discrete") {
  920. return DISCRETE;
  921. } else if (lowerScheduler == "karras") {
  922. return KARRAS;
  923. } else if (lowerScheduler == "exponential") {
  924. return EXPONENTIAL;
  925. } else if (lowerScheduler == "ays") {
  926. return AYS;
  927. } else if (lowerScheduler == "gits") {
  928. return GITS;
  929. } else if (lowerScheduler == "smoothstep") {
  930. return SMOOTHSTEP;
  931. } else if (lowerScheduler == "sgm_uniform") {
  932. return SGM_UNIFORM;
  933. } else if (lowerScheduler == "simple") {
  934. return SIMPLE;
  935. } else {
  936. return DEFAULT;
  937. }
  938. }
  939. sd_type_t StableDiffusionWrapper::stringToModelType(const std::string& type) {
  940. std::string lowerType = type;
  941. std::transform(lowerType.begin(), lowerType.end(), lowerType.begin(), ::tolower);
  942. if (lowerType == "f32") {
  943. return SD_TYPE_F32;
  944. } else if (lowerType == "f16") {
  945. return SD_TYPE_F16;
  946. } else if (lowerType == "q4_0") {
  947. return SD_TYPE_Q4_0;
  948. } else if (lowerType == "q4_1") {
  949. return SD_TYPE_Q4_1;
  950. } else if (lowerType == "q5_0") {
  951. return SD_TYPE_Q5_0;
  952. } else if (lowerType == "q5_1") {
  953. return SD_TYPE_Q5_1;
  954. } else if (lowerType == "q8_0") {
  955. return SD_TYPE_Q8_0;
  956. } else if (lowerType == "q8_1") {
  957. return SD_TYPE_Q8_1;
  958. } else if (lowerType == "q2_k") {
  959. return SD_TYPE_Q2_K;
  960. } else if (lowerType == "q3_k") {
  961. return SD_TYPE_Q3_K;
  962. } else if (lowerType == "q4_k") {
  963. return SD_TYPE_Q4_K;
  964. } else if (lowerType == "q5_k") {
  965. return SD_TYPE_Q5_K;
  966. } else if (lowerType == "q6_k") {
  967. return SD_TYPE_Q6_K;
  968. } else if (lowerType == "q8_k") {
  969. return SD_TYPE_Q8_K;
  970. } else {
  971. return SD_TYPE_F16; // Default to F16
  972. }
  973. }
  974. // Public interface implementation
  975. StableDiffusionWrapper::StableDiffusionWrapper() : pImpl(std::make_unique<Impl>()) {
  976. // wrapperMutex is automatically initialized by std::mutex default constructor
  977. }
  978. StableDiffusionWrapper::~StableDiffusionWrapper() = default;
  979. bool StableDiffusionWrapper::loadModel(const std::string& modelPath, const GenerationParams& params) {
  980. std::lock_guard<std::mutex> lock(wrapperMutex);
  981. return pImpl->loadModel(modelPath, params);
  982. }
  983. void StableDiffusionWrapper::unloadModel() {
  984. std::lock_guard<std::mutex> lock(wrapperMutex);
  985. pImpl->unloadModel();
  986. }
  987. bool StableDiffusionWrapper::isModelLoaded() const {
  988. std::lock_guard<std::mutex> lock(wrapperMutex);
  989. return pImpl->isModelLoaded();
  990. }
  991. std::vector<StableDiffusionWrapper::GeneratedImage> StableDiffusionWrapper::generateImage(
  992. const GenerationParams& params,
  993. ProgressCallback progressCallback,
  994. void* userData) {
  995. std::lock_guard<std::mutex> lock(wrapperMutex);
  996. return pImpl->generateImage(params, progressCallback, userData);
  997. }
  998. std::vector<StableDiffusionWrapper::GeneratedImage> StableDiffusionWrapper::generateImageImg2Img(
  999. const GenerationParams& params,
  1000. const std::vector<uint8_t>& inputData,
  1001. int inputWidth,
  1002. int inputHeight,
  1003. ProgressCallback progressCallback,
  1004. void* userData) {
  1005. std::lock_guard<std::mutex> lock(wrapperMutex);
  1006. return pImpl->generateImageImg2Img(params, inputData, inputWidth, inputHeight, progressCallback, userData);
  1007. }
  1008. std::vector<StableDiffusionWrapper::GeneratedImage> StableDiffusionWrapper::generateImageControlNet(
  1009. const GenerationParams& params,
  1010. const std::vector<uint8_t>& controlData,
  1011. int controlWidth,
  1012. int controlHeight,
  1013. ProgressCallback progressCallback,
  1014. void* userData) {
  1015. std::lock_guard<std::mutex> lock(wrapperMutex);
  1016. return pImpl->generateImageControlNet(params, controlData, controlWidth, controlHeight, progressCallback, userData);
  1017. }
  1018. std::vector<StableDiffusionWrapper::GeneratedImage> StableDiffusionWrapper::generateImageInpainting(
  1019. const GenerationParams& params,
  1020. const std::vector<uint8_t>& inputData,
  1021. int inputWidth,
  1022. int inputHeight,
  1023. const std::vector<uint8_t>& maskData,
  1024. int maskWidth,
  1025. int maskHeight,
  1026. ProgressCallback progressCallback,
  1027. void* userData) {
  1028. std::lock_guard<std::mutex> lock(wrapperMutex);
  1029. return pImpl->generateImageInpainting(params, inputData, inputWidth, inputHeight, maskData, maskWidth, maskHeight, progressCallback, userData);
  1030. }
  1031. StableDiffusionWrapper::GeneratedImage StableDiffusionWrapper::upscaleImage(
  1032. const std::string& esrganPath,
  1033. const std::vector<uint8_t>& inputData,
  1034. int inputWidth,
  1035. int inputHeight,
  1036. int inputChannels,
  1037. uint32_t upscaleFactor,
  1038. int nThreads,
  1039. bool offloadParamsToCpu,
  1040. bool direct) {
  1041. std::lock_guard<std::mutex> lock(wrapperMutex);
  1042. return pImpl->upscaleImage(esrganPath, inputData, inputWidth, inputHeight, inputChannels, upscaleFactor, nThreads, offloadParamsToCpu, direct);
  1043. }
  1044. std::string StableDiffusionWrapper::getLastError() const {
  1045. std::lock_guard<std::mutex> lock(wrapperMutex);
  1046. return pImpl->getLastError();
  1047. }