// stable_diffusion_wrapper.h

  1. #ifndef STABLE_DIFFUSION_WRAPPER_H
  2. #define STABLE_DIFFUSION_WRAPPER_H
  3. #include <string>
  4. #include <vector>
  5. #include <memory>
  6. #include <mutex>
  7. #include <functional>
  8. #include <cstdint>
  9. // Include stable-diffusion.h to get the enum definitions
  10. extern "C" {
  11. #include "stable-diffusion.h"
  12. }
  13. /**
  14. * @brief Wrapper class for stable-diffusion.cpp functionality
  15. *
  16. * This class provides a C++ interface to the stable-diffusion.cpp library,
  17. * handling model loading, image generation, and resource management.
  18. */
  19. class StableDiffusionWrapper {
  20. public:
  21. /**
  22. * @brief Generation parameters structure
  23. */
  24. struct GenerationParams {
  25. // Basic parameters
  26. std::string prompt; ///< Text prompt for generation
  27. std::string negativePrompt; ///< Negative prompt (optional)
  28. // Image parameters
  29. int width; ///< Image width
  30. int height; ///< Image height
  31. int batchCount; ///< Number of images to generate
  32. // Sampling parameters
  33. int steps; ///< Number of diffusion steps
  34. float cfgScale; ///< CFG scale
  35. std::string samplingMethod; ///< Sampling method
  36. std::string scheduler; ///< Scheduler
  37. // Seed control
  38. int64_t seed; ///< Seed for generation
  39. // Model paths
  40. std::string modelPath; ///< Path to main model
  41. std::string clipLPath; ///< Path to CLIP-L model
  42. std::string clipGPath; ///< Path to CLIP-G model
  43. std::string vaePath; ///< Path to VAE model
  44. std::string taesdPath; ///< Path to TAESD model
  45. std::string controlNetPath; ///< Path to ControlNet model
  46. std::string embeddingDir; ///< Path to embeddings directory
  47. std::string loraModelDir; ///< Path to LoRA model directory
  48. // Advanced parameters
  49. int clipSkip; ///< CLIP skip layers
  50. float strength; ///< Strength for img2img
  51. float controlStrength; ///< ControlNet strength
  52. // Performance parameters
  53. int nThreads; ///< Number of threads (-1 for auto)
  54. bool offloadParamsToCpu; ///< Offload parameters to CPU
  55. bool clipOnCpu; ///< Keep CLIP on CPU
  56. bool vaeOnCpu; ///< Keep VAE on CPU
  57. bool diffusionFlashAttn; ///< Use flash attention
  58. bool diffusionConvDirect; ///< Use direct convolution
  59. bool vaeConvDirect; ///< Use direct VAE convolution
  60. // Model type
  61. std::string modelType; ///< Model type (f32, f16, q4_0, etc.)
  62. // Verbose output
  63. bool verbose = false; ///< Whether to print verbose model loading info
  64. // Constructor with default values
  65. GenerationParams()
  66. : width(512), height(512), batchCount(1), steps(20), cfgScale(7.5f),
  67. samplingMethod("euler"), scheduler("default"), seed(42),
  68. clipSkip(-1), strength(0.75f), controlStrength(0.9f),
  69. nThreads(-1), offloadParamsToCpu(false), clipOnCpu(false),
  70. vaeOnCpu(false), diffusionFlashAttn(false),
  71. diffusionConvDirect(false), vaeConvDirect(false),
  72. modelType("f16"), verbose(false) {}
  73. };
  74. /**
  75. * @brief Generated image structure
  76. */
  77. struct GeneratedImage {
  78. std::vector<uint8_t> data; ///< Image data (RGB)
  79. int width; ///< Image width
  80. int height; ///< Image height
  81. int channels; ///< Number of channels (usually 3)
  82. int64_t seed; ///< Seed used for generation
  83. uint64_t generationTime; ///< Time taken for generation in milliseconds
  84. };
  85. /**
  86. * @brief Progress callback function type
  87. *
  88. * @param step Current step
  89. * @param steps Total steps
  90. * @param time Time taken for current step
  91. * @param userData User data pointer
  92. */
  93. using ProgressCallback = std::function<void(int step, int steps, float time, void* userData)>;
  94. /**
  95. * @brief Construct a new Stable Diffusion Wrapper object
  96. */
  97. StableDiffusionWrapper();
  98. /**
  99. * @brief Destroy the Stable Diffusion Wrapper object
  100. */
  101. ~StableDiffusionWrapper();
  102. /**
  103. * @brief Load a stable-diffusion model
  104. *
  105. * @param modelPath Path to the model file
  106. * @param params Additional loading parameters
  107. * @return true if model was loaded successfully, false otherwise
  108. */
  109. bool loadModel(const std::string& modelPath, const GenerationParams& params = GenerationParams{});
  110. /**
  111. * @brief Unload the current model
  112. */
  113. void unloadModel();
  114. /**
  115. * @brief Check if a model is currently loaded
  116. *
  117. * @return true if a model is loaded, false otherwise
  118. */
  119. bool isModelLoaded() const;
  120. /**
  121. * @brief Generate an image from text prompt
  122. *
  123. * @param params Generation parameters
  124. * @param progressCallback Optional progress callback
  125. * @param userData User data for progress callback
  126. * @return std::vector<GeneratedImage> Generated images
  127. */
  128. std::vector<GeneratedImage> generateImage(
  129. const GenerationParams& params,
  130. ProgressCallback progressCallback = nullptr,
  131. void* userData = nullptr
  132. );
  133. /**
  134. * @brief Generate an image from text prompt with image input (img2img)
  135. *
  136. * @param params Generation parameters
  137. * @param inputData Input image data
  138. * @param inputWidth Input image width
  139. * @param inputHeight Input image height
  140. * @param progressCallback Optional progress callback
  141. * @param userData User data for progress callback
  142. * @return std::vector<GeneratedImage> Generated images
  143. */
  144. std::vector<GeneratedImage> generateImageImg2Img(
  145. const GenerationParams& params,
  146. const std::vector<uint8_t>& inputData,
  147. int inputWidth,
  148. int inputHeight,
  149. ProgressCallback progressCallback = nullptr,
  150. void* userData = nullptr
  151. );
  152. /**
  153. * @brief Generate an image with ControlNet
  154. *
  155. * @param params Generation parameters
  156. * @param controlData Control image data
  157. * @param controlWidth Control image width
  158. * @param controlHeight Control image height
  159. * @param progressCallback Optional progress callback
  160. * @param userData User data for progress callback
  161. * @return std::vector<GeneratedImage> Generated images
  162. */
  163. std::vector<GeneratedImage> generateImageControlNet(
  164. const GenerationParams& params,
  165. const std::vector<uint8_t>& controlData,
  166. int controlWidth,
  167. int controlHeight,
  168. ProgressCallback progressCallback = nullptr,
  169. void* userData = nullptr
  170. );
  171. /**
  172. * @brief Generate an image with inpainting
  173. *
  174. * @param params Generation parameters
  175. * @param inputData Input image data
  176. * @param inputWidth Input image width
  177. * @param inputHeight Input image height
  178. * @param maskData Mask image data (grayscale, where white=keep, black=inpaint)
  179. * @param maskWidth Mask image width
  180. * @param maskHeight Mask image height
  181. * @param progressCallback Optional progress callback
  182. * @param userData User data for progress callback
  183. * @return std::vector<GeneratedImage> Generated images
  184. */
  185. std::vector<GeneratedImage> generateImageInpainting(
  186. const GenerationParams& params,
  187. const std::vector<uint8_t>& inputData,
  188. int inputWidth,
  189. int inputHeight,
  190. const std::vector<uint8_t>& maskData,
  191. int maskWidth,
  192. int maskHeight,
  193. ProgressCallback progressCallback = nullptr,
  194. void* userData = nullptr
  195. );
  196. /**
  197. * @brief Upscale an image using ESRGAN
  198. *
  199. * @param esrganPath Path to ESRGAN model
  200. * @param inputData Input image data
  201. * @param inputWidth Input image width
  202. * @param inputHeight Input image height
  203. * @param inputChannels Number of channels in input image
  204. * @param upscaleFactor Upscale factor (usually 2 or 4)
  205. * @param nThreads Number of threads (-1 for auto)
  206. * @param offloadParamsToCpu Offload parameters to CPU
  207. * @param direct Use direct mode
  208. * @param progressCallback Optional progress callback
  209. * @param userData User data for progress callback
  210. * @return GeneratedImage Upscaled image
  211. */
  212. GeneratedImage upscaleImage(
  213. const std::string& esrganPath,
  214. const std::vector<uint8_t>& inputData,
  215. int inputWidth,
  216. int inputHeight,
  217. int inputChannels,
  218. uint32_t upscaleFactor,
  219. int nThreads = -1,
  220. bool offloadParamsToCpu = false,
  221. bool direct = false,
  222. ProgressCallback progressCallback = nullptr,
  223. void* userData = nullptr);
  224. // File-based overload for upscaler input
  225. GeneratedImage upscaleImage(
  226. const std::string& esrganPath,
  227. const std::string& inputImagePath,
  228. uint32_t upscaleFactor,
  229. int nThreads = -1,
  230. bool offloadParamsToCpu = false,
  231. bool direct = false,
  232. ProgressCallback progressCallback = nullptr,
  233. void* userData = nullptr);
  234. /**
  235. * @brief Get the last error message
  236. *
  237. * @return std::string Last error message
  238. */
  239. std::string getLastError() const;
  240. /**
  241. * @brief Convert sampling method string to enum
  242. *
  243. * @param method Sampling method string
  244. * @return sample_method_t Sampling method enum
  245. */
  246. static sample_method_t stringToSamplingMethod(const std::string& method);
  247. /**
  248. * @brief Convert scheduler string to enum
  249. *
  250. * @param scheduler Scheduler string
  251. * @return scheduler_t Scheduler enum
  252. */
  253. static scheduler_t stringToScheduler(const std::string& scheduler);
  254. /**
  255. * @brief Convert model type string to enum
  256. *
  257. * @param type Model type string
  258. * @return sd_type_t Model type enum
  259. */
  260. static sd_type_t stringToModelType(const std::string& type);
  261. private:
  262. class Impl;
  263. std::unique_ptr<Impl> pImpl; // Pimpl idiom
  264. // Thread safety
  265. mutable std::mutex wrapperMutex;
  266. };
  267. #endif // STABLE_DIFFUSION_WRAPPER_H