stable_diffusion_wrapper.cpp 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660
#include "stable_diffusion_wrapper.h"

#include <algorithm>
#include <cctype>
#include <chrono>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <memory>
#include <mutex>
#include <random>
#include <string>
#include <utility>
#include <vector>

extern "C" {
#include "stable-diffusion.h"
}
  10. class StableDiffusionWrapper::Impl {
  11. public:
  12. sd_ctx_t* sdContext = nullptr;
  13. std::string lastError;
  14. std::mutex contextMutex;
  15. Impl() {
  16. // Initialize any required resources
  17. }
  18. ~Impl() {
  19. unloadModel();
  20. }
  21. bool loadModel(const std::string& modelPath, const StableDiffusionWrapper::GenerationParams& params) {
  22. std::lock_guard<std::mutex> lock(contextMutex);
  23. // Unload any existing model
  24. if (sdContext) {
  25. free_sd_ctx(sdContext);
  26. sdContext = nullptr;
  27. }
  28. // Initialize context parameters
  29. sd_ctx_params_t ctxParams;
  30. sd_ctx_params_init(&ctxParams);
  31. // Set model path
  32. ctxParams.model_path = modelPath.c_str();
  33. // Set optional model paths if provided
  34. if (!params.clipLPath.empty()) {
  35. ctxParams.clip_l_path = params.clipLPath.c_str();
  36. }
  37. if (!params.clipGPath.empty()) {
  38. ctxParams.clip_g_path = params.clipGPath.c_str();
  39. }
  40. if (!params.vaePath.empty()) {
  41. ctxParams.vae_path = params.vaePath.c_str();
  42. }
  43. if (!params.taesdPath.empty()) {
  44. ctxParams.taesd_path = params.taesdPath.c_str();
  45. }
  46. if (!params.controlNetPath.empty()) {
  47. ctxParams.control_net_path = params.controlNetPath.c_str();
  48. }
  49. if (!params.loraModelDir.empty()) {
  50. ctxParams.lora_model_dir = params.loraModelDir.c_str();
  51. }
  52. if (!params.embeddingDir.empty()) {
  53. ctxParams.embedding_dir = params.embeddingDir.c_str();
  54. }
  55. // Set performance parameters
  56. ctxParams.n_threads = params.nThreads;
  57. ctxParams.offload_params_to_cpu = params.offloadParamsToCpu;
  58. ctxParams.keep_clip_on_cpu = params.clipOnCpu;
  59. ctxParams.keep_vae_on_cpu = params.vaeOnCpu;
  60. ctxParams.diffusion_flash_attn = params.diffusionFlashAttn;
  61. ctxParams.diffusion_conv_direct = params.diffusionConvDirect;
  62. ctxParams.vae_conv_direct = params.vaeConvDirect;
  63. // Set model type
  64. ctxParams.wtype = StableDiffusionWrapper::stringToModelType(params.modelType);
  65. // Create the stable-diffusion context
  66. sdContext = new_sd_ctx(&ctxParams);
  67. if (!sdContext) {
  68. lastError = "Failed to create stable-diffusion context";
  69. return false;
  70. }
  71. std::cout << "Successfully loaded model: " << modelPath << std::endl;
  72. return true;
  73. }
  74. void unloadModel() {
  75. std::lock_guard<std::mutex> lock(contextMutex);
  76. if (sdContext) {
  77. free_sd_ctx(sdContext);
  78. sdContext = nullptr;
  79. std::cout << "Unloaded stable-diffusion model" << std::endl;
  80. }
  81. }
  82. bool isModelLoaded() const {
  83. return sdContext != nullptr;
  84. }
  85. std::vector<StableDiffusionWrapper::GeneratedImage> generateImage(
  86. const StableDiffusionWrapper::GenerationParams& params,
  87. StableDiffusionWrapper::ProgressCallback progressCallback,
  88. void* userData) {
  89. std::vector<StableDiffusionWrapper::GeneratedImage> results;
  90. if (!sdContext) {
  91. lastError = "No model loaded";
  92. return results;
  93. }
  94. auto startTime = std::chrono::high_resolution_clock::now();
  95. // Initialize generation parameters
  96. sd_img_gen_params_t genParams;
  97. sd_img_gen_params_init(&genParams);
  98. // Set basic parameters
  99. genParams.prompt = params.prompt.c_str();
  100. genParams.negative_prompt = params.negativePrompt.c_str();
  101. genParams.width = params.width;
  102. genParams.height = params.height;
  103. genParams.sample_params.sample_steps = params.steps;
  104. genParams.seed = params.seed;
  105. genParams.batch_count = params.batchCount;
  106. // Set sampling parameters
  107. genParams.sample_params.sample_method = StableDiffusionWrapper::stringToSamplingMethod(params.samplingMethod);
  108. genParams.sample_params.scheduler = StableDiffusionWrapper::stringToScheduler(params.scheduler);
  109. genParams.sample_params.guidance.txt_cfg = params.cfgScale;
  110. // Set advanced parameters
  111. genParams.clip_skip = params.clipSkip;
  112. genParams.strength = params.strength;
  113. // Set progress callback if provided
  114. // Track callback data to ensure proper cleanup
  115. std::pair<StableDiffusionWrapper::ProgressCallback, void*>* callbackData = nullptr;
  116. if (progressCallback) {
  117. callbackData = new std::pair<StableDiffusionWrapper::ProgressCallback, void*>(progressCallback, userData);
  118. sd_set_progress_callback([](int step, int steps, float time, void* data) {
  119. auto* callbackData = static_cast<std::pair<StableDiffusionWrapper::ProgressCallback, void*>*>(data);
  120. if (callbackData) {
  121. callbackData->first(step, steps, time, callbackData->second);
  122. }
  123. }, callbackData);
  124. }
  125. // Generate the image
  126. sd_image_t* sdImages = generate_image(sdContext, &genParams);
  127. // Clear and clean up progress callback
  128. sd_set_progress_callback(nullptr, nullptr);
  129. if (callbackData) {
  130. delete callbackData;
  131. callbackData = nullptr;
  132. }
  133. auto endTime = std::chrono::high_resolution_clock::now();
  134. auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(endTime - startTime);
  135. if (!sdImages) {
  136. lastError = "Failed to generate image";
  137. return results;
  138. }
  139. // Convert stable-diffusion images to our format
  140. for (int i = 0; i < params.batchCount; i++) {
  141. StableDiffusionWrapper::GeneratedImage image;
  142. image.width = sdImages[i].width;
  143. image.height = sdImages[i].height;
  144. image.channels = sdImages[i].channel;
  145. image.seed = params.seed;
  146. image.generationTime = duration.count();
  147. // Copy image data
  148. if (sdImages[i].data && sdImages[i].width > 0 && sdImages[i].height > 0 && sdImages[i].channel > 0) {
  149. size_t dataSize = sdImages[i].width * sdImages[i].height * sdImages[i].channel;
  150. image.data.resize(dataSize);
  151. std::memcpy(image.data.data(), sdImages[i].data, dataSize);
  152. }
  153. results.push_back(image);
  154. }
  155. // Free the generated images
  156. // Clean up each image's data array
  157. for (int i = 0; i < params.batchCount; i++) {
  158. if (sdImages[i].data) {
  159. free(sdImages[i].data);
  160. sdImages[i].data = nullptr;
  161. }
  162. }
  163. // Free the image array itself
  164. free(sdImages);
  165. return results;
  166. }
  167. std::vector<StableDiffusionWrapper::GeneratedImage> generateImageImg2Img(
  168. const StableDiffusionWrapper::GenerationParams& params,
  169. const std::vector<uint8_t>& inputData,
  170. int inputWidth,
  171. int inputHeight,
  172. StableDiffusionWrapper::ProgressCallback progressCallback,
  173. void* userData) {
  174. std::vector<StableDiffusionWrapper::GeneratedImage> results;
  175. if (!sdContext) {
  176. lastError = "No model loaded";
  177. return results;
  178. }
  179. auto startTime = std::chrono::high_resolution_clock::now();
  180. // Initialize generation parameters
  181. sd_img_gen_params_t genParams;
  182. sd_img_gen_params_init(&genParams);
  183. // Set basic parameters
  184. genParams.prompt = params.prompt.c_str();
  185. genParams.negative_prompt = params.negativePrompt.c_str();
  186. genParams.width = params.width;
  187. genParams.height = params.height;
  188. genParams.sample_params.sample_steps = params.steps;
  189. genParams.seed = params.seed;
  190. genParams.batch_count = params.batchCount;
  191. genParams.strength = params.strength;
  192. // Set sampling parameters
  193. genParams.sample_params.sample_method = StableDiffusionWrapper::stringToSamplingMethod(params.samplingMethod);
  194. genParams.sample_params.scheduler = StableDiffusionWrapper::stringToScheduler(params.scheduler);
  195. genParams.sample_params.guidance.txt_cfg = params.cfgScale;
  196. // Set advanced parameters
  197. genParams.clip_skip = params.clipSkip;
  198. // Set input image
  199. sd_image_t initImage;
  200. initImage.width = inputWidth;
  201. initImage.height = inputHeight;
  202. initImage.channel = 3; // RGB
  203. initImage.data = const_cast<uint8_t*>(inputData.data());
  204. genParams.init_image = initImage;
  205. // Set progress callback if provided
  206. // Track callback data to ensure proper cleanup
  207. std::pair<StableDiffusionWrapper::ProgressCallback, void*>* callbackData = nullptr;
  208. if (progressCallback) {
  209. callbackData = new std::pair<StableDiffusionWrapper::ProgressCallback, void*>(progressCallback, userData);
  210. sd_set_progress_callback([](int step, int steps, float time, void* data) {
  211. auto* callbackData = static_cast<std::pair<StableDiffusionWrapper::ProgressCallback, void*>*>(data);
  212. if (callbackData) {
  213. callbackData->first(step, steps, time, callbackData->second);
  214. }
  215. }, callbackData);
  216. }
  217. // Generate the image
  218. sd_image_t* sdImages = generate_image(sdContext, &genParams);
  219. // Clear and clean up progress callback
  220. sd_set_progress_callback(nullptr, nullptr);
  221. if (callbackData) {
  222. delete callbackData;
  223. callbackData = nullptr;
  224. }
  225. auto endTime = std::chrono::high_resolution_clock::now();
  226. auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(endTime - startTime);
  227. if (!sdImages) {
  228. lastError = "Failed to generate image";
  229. return results;
  230. }
  231. // Convert stable-diffusion images to our format
  232. for (int i = 0; i < params.batchCount; i++) {
  233. StableDiffusionWrapper::GeneratedImage image;
  234. image.width = sdImages[i].width;
  235. image.height = sdImages[i].height;
  236. image.channels = sdImages[i].channel;
  237. image.seed = params.seed;
  238. image.generationTime = duration.count();
  239. // Copy image data
  240. if (sdImages[i].data && sdImages[i].width > 0 && sdImages[i].height > 0 && sdImages[i].channel > 0) {
  241. size_t dataSize = sdImages[i].width * sdImages[i].height * sdImages[i].channel;
  242. image.data.resize(dataSize);
  243. std::memcpy(image.data.data(), sdImages[i].data, dataSize);
  244. }
  245. results.push_back(image);
  246. }
  247. // Free the generated images
  248. // Clean up each image's data array
  249. for (int i = 0; i < params.batchCount; i++) {
  250. if (sdImages[i].data) {
  251. free(sdImages[i].data);
  252. sdImages[i].data = nullptr;
  253. }
  254. }
  255. // Free the image array itself
  256. free(sdImages);
  257. return results;
  258. }
  259. std::vector<StableDiffusionWrapper::GeneratedImage> generateImageControlNet(
  260. const StableDiffusionWrapper::GenerationParams& params,
  261. const std::vector<uint8_t>& controlData,
  262. int controlWidth,
  263. int controlHeight,
  264. StableDiffusionWrapper::ProgressCallback progressCallback,
  265. void* userData) {
  266. std::vector<StableDiffusionWrapper::GeneratedImage> results;
  267. if (!sdContext) {
  268. lastError = "No model loaded";
  269. return results;
  270. }
  271. auto startTime = std::chrono::high_resolution_clock::now();
  272. // Initialize generation parameters
  273. sd_img_gen_params_t genParams;
  274. sd_img_gen_params_init(&genParams);
  275. // Set basic parameters
  276. genParams.prompt = params.prompt.c_str();
  277. genParams.negative_prompt = params.negativePrompt.c_str();
  278. genParams.width = params.width;
  279. genParams.height = params.height;
  280. genParams.sample_params.sample_steps = params.steps;
  281. genParams.seed = params.seed;
  282. genParams.batch_count = params.batchCount;
  283. genParams.control_strength = params.controlStrength;
  284. // Set sampling parameters
  285. genParams.sample_params.sample_method = StableDiffusionWrapper::stringToSamplingMethod(params.samplingMethod);
  286. genParams.sample_params.scheduler = StableDiffusionWrapper::stringToScheduler(params.scheduler);
  287. genParams.sample_params.guidance.txt_cfg = params.cfgScale;
  288. // Set advanced parameters
  289. genParams.clip_skip = params.clipSkip;
  290. // Set control image
  291. sd_image_t controlImage;
  292. controlImage.width = controlWidth;
  293. controlImage.height = controlHeight;
  294. controlImage.channel = 3; // RGB
  295. controlImage.data = const_cast<uint8_t*>(controlData.data());
  296. genParams.control_image = controlImage;
  297. // Set progress callback if provided
  298. // Track callback data to ensure proper cleanup
  299. std::pair<StableDiffusionWrapper::ProgressCallback, void*>* callbackData = nullptr;
  300. if (progressCallback) {
  301. callbackData = new std::pair<StableDiffusionWrapper::ProgressCallback, void*>(progressCallback, userData);
  302. sd_set_progress_callback([](int step, int steps, float time, void* data) {
  303. auto* callbackData = static_cast<std::pair<StableDiffusionWrapper::ProgressCallback, void*>*>(data);
  304. if (callbackData) {
  305. callbackData->first(step, steps, time, callbackData->second);
  306. }
  307. }, callbackData);
  308. }
  309. // Generate the image
  310. sd_image_t* sdImages = generate_image(sdContext, &genParams);
  311. // Clear and clean up progress callback
  312. sd_set_progress_callback(nullptr, nullptr);
  313. if (callbackData) {
  314. delete callbackData;
  315. callbackData = nullptr;
  316. }
  317. auto endTime = std::chrono::high_resolution_clock::now();
  318. auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(endTime - startTime);
  319. if (!sdImages) {
  320. lastError = "Failed to generate image";
  321. return results;
  322. }
  323. // Convert stable-diffusion images to our format
  324. for (int i = 0; i < params.batchCount; i++) {
  325. StableDiffusionWrapper::GeneratedImage image;
  326. image.width = sdImages[i].width;
  327. image.height = sdImages[i].height;
  328. image.channels = sdImages[i].channel;
  329. image.seed = params.seed;
  330. image.generationTime = duration.count();
  331. // Copy image data
  332. if (sdImages[i].data && sdImages[i].width > 0 && sdImages[i].height > 0 && sdImages[i].channel > 0) {
  333. size_t dataSize = sdImages[i].width * sdImages[i].height * sdImages[i].channel;
  334. image.data.resize(dataSize);
  335. std::memcpy(image.data.data(), sdImages[i].data, dataSize);
  336. }
  337. results.push_back(image);
  338. }
  339. // Free the generated images
  340. // Clean up each image's data array
  341. for (int i = 0; i < params.batchCount; i++) {
  342. if (sdImages[i].data) {
  343. free(sdImages[i].data);
  344. sdImages[i].data = nullptr;
  345. }
  346. }
  347. // Free the image array itself
  348. free(sdImages);
  349. return results;
  350. }
  351. StableDiffusionWrapper::GeneratedImage upscaleImage(
  352. const std::string& esrganPath,
  353. const std::vector<uint8_t>& inputData,
  354. int inputWidth,
  355. int inputHeight,
  356. int inputChannels,
  357. uint32_t upscaleFactor,
  358. int nThreads,
  359. bool offloadParamsToCpu,
  360. bool direct) {
  361. StableDiffusionWrapper::GeneratedImage result;
  362. auto startTime = std::chrono::high_resolution_clock::now();
  363. // Create upscaler context
  364. upscaler_ctx_t* upscalerCtx = new_upscaler_ctx(
  365. esrganPath.c_str(),
  366. offloadParamsToCpu,
  367. direct,
  368. nThreads
  369. );
  370. if (!upscalerCtx) {
  371. lastError = "Failed to create upscaler context";
  372. return result;
  373. }
  374. // Prepare input image
  375. sd_image_t inputImage;
  376. inputImage.width = inputWidth;
  377. inputImage.height = inputHeight;
  378. inputImage.channel = inputChannels;
  379. inputImage.data = const_cast<uint8_t*>(inputData.data());
  380. // Perform upscaling
  381. sd_image_t upscaled = upscale(upscalerCtx, inputImage, upscaleFactor);
  382. auto endTime = std::chrono::high_resolution_clock::now();
  383. auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(endTime - startTime);
  384. if (!upscaled.data) {
  385. lastError = "Failed to upscale image";
  386. free_upscaler_ctx(upscalerCtx);
  387. return result;
  388. }
  389. // Convert to our format
  390. result.width = upscaled.width;
  391. result.height = upscaled.height;
  392. result.channels = upscaled.channel;
  393. result.seed = 0; // No seed for upscaling
  394. result.generationTime = duration.count();
  395. // Copy image data
  396. if (upscaled.data && upscaled.width > 0 && upscaled.height > 0 && upscaled.channel > 0) {
  397. size_t dataSize = upscaled.width * upscaled.height * upscaled.channel;
  398. result.data.resize(dataSize);
  399. std::memcpy(result.data.data(), upscaled.data, dataSize);
  400. }
  401. // Clean up
  402. free_upscaler_ctx(upscalerCtx);
  403. return result;
  404. }
  405. std::string getLastError() const {
  406. return lastError;
  407. }
  408. };
  409. // Static helper functions
  410. sample_method_t StableDiffusionWrapper::stringToSamplingMethod(const std::string& method) {
  411. std::string lowerMethod = method;
  412. std::transform(lowerMethod.begin(), lowerMethod.end(), lowerMethod.begin(), ::tolower);
  413. if (lowerMethod == "euler") {
  414. return EULER;
  415. } else if (lowerMethod == "euler_a") {
  416. return EULER_A;
  417. } else if (lowerMethod == "heun") {
  418. return HEUN;
  419. } else if (lowerMethod == "dpm2") {
  420. return DPM2;
  421. } else if (lowerMethod == "dpmpp2s_a") {
  422. return DPMPP2S_A;
  423. } else if (lowerMethod == "dpmpp2m") {
  424. return DPMPP2M;
  425. } else if (lowerMethod == "dpmpp2mv2") {
  426. return DPMPP2Mv2;
  427. } else if (lowerMethod == "ipndm") {
  428. return IPNDM;
  429. } else if (lowerMethod == "ipndm_v") {
  430. return IPNDM_V;
  431. } else if (lowerMethod == "lcm") {
  432. return LCM;
  433. } else if (lowerMethod == "ddim_trailing") {
  434. return DDIM_TRAILING;
  435. } else if (lowerMethod == "tcd") {
  436. return TCD;
  437. } else {
  438. return SAMPLE_METHOD_DEFAULT;
  439. }
  440. }
  441. scheduler_t StableDiffusionWrapper::stringToScheduler(const std::string& scheduler) {
  442. std::string lowerScheduler = scheduler;
  443. std::transform(lowerScheduler.begin(), lowerScheduler.end(), lowerScheduler.begin(), ::tolower);
  444. if (lowerScheduler == "discrete") {
  445. return DISCRETE;
  446. } else if (lowerScheduler == "karras") {
  447. return KARRAS;
  448. } else if (lowerScheduler == "exponential") {
  449. return EXPONENTIAL;
  450. } else if (lowerScheduler == "ays") {
  451. return AYS;
  452. } else if (lowerScheduler == "gits") {
  453. return GITS;
  454. } else if (lowerScheduler == "smoothstep") {
  455. return SMOOTHSTEP;
  456. } else if (lowerScheduler == "sgm_uniform") {
  457. return SGM_UNIFORM;
  458. } else if (lowerScheduler == "simple") {
  459. return SIMPLE;
  460. } else {
  461. return DEFAULT;
  462. }
  463. }
  464. sd_type_t StableDiffusionWrapper::stringToModelType(const std::string& type) {
  465. std::string lowerType = type;
  466. std::transform(lowerType.begin(), lowerType.end(), lowerType.begin(), ::tolower);
  467. if (lowerType == "f32") {
  468. return SD_TYPE_F32;
  469. } else if (lowerType == "f16") {
  470. return SD_TYPE_F16;
  471. } else if (lowerType == "q4_0") {
  472. return SD_TYPE_Q4_0;
  473. } else if (lowerType == "q4_1") {
  474. return SD_TYPE_Q4_1;
  475. } else if (lowerType == "q5_0") {
  476. return SD_TYPE_Q5_0;
  477. } else if (lowerType == "q5_1") {
  478. return SD_TYPE_Q5_1;
  479. } else if (lowerType == "q8_0") {
  480. return SD_TYPE_Q8_0;
  481. } else if (lowerType == "q8_1") {
  482. return SD_TYPE_Q8_1;
  483. } else if (lowerType == "q2_k") {
  484. return SD_TYPE_Q2_K;
  485. } else if (lowerType == "q3_k") {
  486. return SD_TYPE_Q3_K;
  487. } else if (lowerType == "q4_k") {
  488. return SD_TYPE_Q4_K;
  489. } else if (lowerType == "q5_k") {
  490. return SD_TYPE_Q5_K;
  491. } else if (lowerType == "q6_k") {
  492. return SD_TYPE_Q6_K;
  493. } else if (lowerType == "q8_k") {
  494. return SD_TYPE_Q8_K;
  495. } else {
  496. return SD_TYPE_F16; // Default to F16
  497. }
  498. }
  499. // Public interface implementation
  500. StableDiffusionWrapper::StableDiffusionWrapper() : pImpl(std::make_unique<Impl>()) {
  501. // wrapperMutex is automatically initialized by std::mutex default constructor
  502. }
  503. StableDiffusionWrapper::~StableDiffusionWrapper() = default;
  504. bool StableDiffusionWrapper::loadModel(const std::string& modelPath, const GenerationParams& params) {
  505. std::lock_guard<std::mutex> lock(wrapperMutex);
  506. return pImpl->loadModel(modelPath, params);
  507. }
  508. void StableDiffusionWrapper::unloadModel() {
  509. std::lock_guard<std::mutex> lock(wrapperMutex);
  510. pImpl->unloadModel();
  511. }
  512. bool StableDiffusionWrapper::isModelLoaded() const {
  513. std::lock_guard<std::mutex> lock(wrapperMutex);
  514. return pImpl->isModelLoaded();
  515. }
  516. std::vector<StableDiffusionWrapper::GeneratedImage> StableDiffusionWrapper::generateImage(
  517. const GenerationParams& params,
  518. ProgressCallback progressCallback,
  519. void* userData) {
  520. std::lock_guard<std::mutex> lock(wrapperMutex);
  521. return pImpl->generateImage(params, progressCallback, userData);
  522. }
  523. std::vector<StableDiffusionWrapper::GeneratedImage> StableDiffusionWrapper::generateImageImg2Img(
  524. const GenerationParams& params,
  525. const std::vector<uint8_t>& inputData,
  526. int inputWidth,
  527. int inputHeight,
  528. ProgressCallback progressCallback,
  529. void* userData) {
  530. std::lock_guard<std::mutex> lock(wrapperMutex);
  531. return pImpl->generateImageImg2Img(params, inputData, inputWidth, inputHeight, progressCallback, userData);
  532. }
  533. std::vector<StableDiffusionWrapper::GeneratedImage> StableDiffusionWrapper::generateImageControlNet(
  534. const GenerationParams& params,
  535. const std::vector<uint8_t>& controlData,
  536. int controlWidth,
  537. int controlHeight,
  538. ProgressCallback progressCallback,
  539. void* userData) {
  540. std::lock_guard<std::mutex> lock(wrapperMutex);
  541. return pImpl->generateImageControlNet(params, controlData, controlWidth, controlHeight, progressCallback, userData);
  542. }
  543. StableDiffusionWrapper::GeneratedImage StableDiffusionWrapper::upscaleImage(
  544. const std::string& esrganPath,
  545. const std::vector<uint8_t>& inputData,
  546. int inputWidth,
  547. int inputHeight,
  548. int inputChannels,
  549. uint32_t upscaleFactor,
  550. int nThreads,
  551. bool offloadParamsToCpu,
  552. bool direct) {
  553. std::lock_guard<std::mutex> lock(wrapperMutex);
  554. return pImpl->upscaleImage(esrganPath, inputData, inputWidth, inputHeight, inputChannels, upscaleFactor, nThreads, offloadParamsToCpu, direct);
  555. }
  556. std::string StableDiffusionWrapper::getLastError() const {
  557. std::lock_guard<std::mutex> lock(wrapperMutex);
  558. return pImpl->getLastError();
  559. }