stable_diffusion_wrapper.h 8.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256
  1. #ifndef STABLE_DIFFUSION_WRAPPER_H
  2. #define STABLE_DIFFUSION_WRAPPER_H
  3. #include <string>
  4. #include <vector>
  5. #include <memory>
  6. #include <mutex>
  7. #include <functional>
  8. #include <cstdint>
  9. // Include stable-diffusion.h to get the enum definitions
  10. extern "C" {
  11. #include "stable-diffusion.h"
  12. }
  13. /**
  14. * @brief Wrapper class for stable-diffusion.cpp functionality
  15. *
  16. * This class provides a C++ interface to the stable-diffusion.cpp library,
  17. * handling model loading, image generation, and resource management.
  18. */
  19. class StableDiffusionWrapper {
  20. public:
  21. /**
  22. * @brief Generation parameters structure
  23. */
  24. struct GenerationParams {
  25. // Basic parameters
  26. std::string prompt; ///< Text prompt for generation
  27. std::string negativePrompt; ///< Negative prompt (optional)
  28. // Image parameters
  29. int width; ///< Image width
  30. int height; ///< Image height
  31. int batchCount; ///< Number of images to generate
  32. // Sampling parameters
  33. int steps; ///< Number of diffusion steps
  34. float cfgScale; ///< CFG scale
  35. std::string samplingMethod; ///< Sampling method
  36. std::string scheduler; ///< Scheduler
  37. // Seed control
  38. int64_t seed; ///< Seed for generation
  39. // Model paths
  40. std::string modelPath; ///< Path to main model
  41. std::string clipLPath; ///< Path to CLIP-L model
  42. std::string clipGPath; ///< Path to CLIP-G model
  43. std::string vaePath; ///< Path to VAE model
  44. std::string taesdPath; ///< Path to TAESD model
  45. std::string controlNetPath; ///< Path to ControlNet model
  46. std::string embeddingDir; ///< Path to embeddings directory
  47. std::string loraModelDir; ///< Path to LoRA model directory
  48. // Advanced parameters
  49. int clipSkip; ///< CLIP skip layers
  50. float strength; ///< Strength for img2img
  51. float controlStrength; ///< ControlNet strength
  52. // Performance parameters
  53. int nThreads; ///< Number of threads (-1 for auto)
  54. bool offloadParamsToCpu; ///< Offload parameters to CPU
  55. bool clipOnCpu; ///< Keep CLIP on CPU
  56. bool vaeOnCpu; ///< Keep VAE on CPU
  57. bool diffusionFlashAttn; ///< Use flash attention
  58. bool diffusionConvDirect; ///< Use direct convolution
  59. bool vaeConvDirect; ///< Use direct VAE convolution
  60. // Model type
  61. std::string modelType; ///< Model type (f32, f16, q4_0, etc.)
  62. // Constructor with default values
  63. GenerationParams()
  64. : width(512), height(512), batchCount(1), steps(20), cfgScale(7.5f),
  65. samplingMethod("euler"), scheduler("default"), seed(42),
  66. clipSkip(-1), strength(0.75f), controlStrength(0.9f),
  67. nThreads(-1), offloadParamsToCpu(false), clipOnCpu(false),
  68. vaeOnCpu(false), diffusionFlashAttn(false),
  69. diffusionConvDirect(false), vaeConvDirect(false),
  70. modelType("f16") {}
  71. };
  72. /**
  73. * @brief Generated image structure
  74. */
  75. struct GeneratedImage {
  76. std::vector<uint8_t> data; ///< Image data (RGB)
  77. int width; ///< Image width
  78. int height; ///< Image height
  79. int channels; ///< Number of channels (usually 3)
  80. int64_t seed; ///< Seed used for generation
  81. uint64_t generationTime; ///< Time taken for generation in milliseconds
  82. };
  83. /**
  84. * @brief Progress callback function type
  85. *
  86. * @param step Current step
  87. * @param steps Total steps
  88. * @param time Time taken for current step
  89. * @param userData User data pointer
  90. */
  91. using ProgressCallback = std::function<void(int step, int steps, float time, void* userData)>;
  92. /**
  93. * @brief Construct a new Stable Diffusion Wrapper object
  94. */
  95. StableDiffusionWrapper();
  96. /**
  97. * @brief Destroy the Stable Diffusion Wrapper object
  98. */
  99. ~StableDiffusionWrapper();
  100. /**
  101. * @brief Load a stable-diffusion model
  102. *
  103. * @param modelPath Path to the model file
  104. * @param params Additional loading parameters
  105. * @return true if model was loaded successfully, false otherwise
  106. */
  107. bool loadModel(const std::string& modelPath, const GenerationParams& params = GenerationParams{});
  108. /**
  109. * @brief Unload the current model
  110. */
  111. void unloadModel();
  112. /**
  113. * @brief Check if a model is currently loaded
  114. *
  115. * @return true if a model is loaded, false otherwise
  116. */
  117. bool isModelLoaded() const;
  118. /**
  119. * @brief Generate an image from text prompt
  120. *
  121. * @param params Generation parameters
  122. * @param progressCallback Optional progress callback
  123. * @param userData User data for progress callback
  124. * @return std::vector<GeneratedImage> Generated images
  125. */
  126. std::vector<GeneratedImage> generateImage(
  127. const GenerationParams& params,
  128. ProgressCallback progressCallback = nullptr,
  129. void* userData = nullptr
  130. );
  131. /**
  132. * @brief Generate an image from text prompt with image input (img2img)
  133. *
  134. * @param params Generation parameters
  135. * @param inputData Input image data
  136. * @param inputWidth Input image width
  137. * @param inputHeight Input image height
  138. * @param progressCallback Optional progress callback
  139. * @param userData User data for progress callback
  140. * @return std::vector<GeneratedImage> Generated images
  141. */
  142. std::vector<GeneratedImage> generateImageImg2Img(
  143. const GenerationParams& params,
  144. const std::vector<uint8_t>& inputData,
  145. int inputWidth,
  146. int inputHeight,
  147. ProgressCallback progressCallback = nullptr,
  148. void* userData = nullptr
  149. );
  150. /**
  151. * @brief Generate an image with ControlNet
  152. *
  153. * @param params Generation parameters
  154. * @param controlData Control image data
  155. * @param controlWidth Control image width
  156. * @param controlHeight Control image height
  157. * @param progressCallback Optional progress callback
  158. * @param userData User data for progress callback
  159. * @return std::vector<GeneratedImage> Generated images
  160. */
  161. std::vector<GeneratedImage> generateImageControlNet(
  162. const GenerationParams& params,
  163. const std::vector<uint8_t>& controlData,
  164. int controlWidth,
  165. int controlHeight,
  166. ProgressCallback progressCallback = nullptr,
  167. void* userData = nullptr
  168. );
  169. /**
  170. * @brief Upscale an image using ESRGAN
  171. *
  172. * @param esrganPath Path to ESRGAN model
  173. * @param inputData Input image data
  174. * @param inputWidth Input image width
  175. * @param inputHeight Input image height
  176. * @param inputChannels Number of channels in input image
  177. * @param upscaleFactor Upscale factor (usually 2 or 4)
  178. * @param nThreads Number of threads (-1 for auto)
  179. * @param offloadParamsToCpu Offload parameters to CPU
  180. * @param direct Use direct mode
  181. * @return GeneratedImage Upscaled image
  182. */
  183. GeneratedImage upscaleImage(
  184. const std::string& esrganPath,
  185. const std::vector<uint8_t>& inputData,
  186. int inputWidth,
  187. int inputHeight,
  188. int inputChannels,
  189. uint32_t upscaleFactor,
  190. int nThreads = -1,
  191. bool offloadParamsToCpu = false,
  192. bool direct = false
  193. );
  194. /**
  195. * @brief Get the last error message
  196. *
  197. * @return std::string Last error message
  198. */
  199. std::string getLastError() const;
  200. /**
  201. * @brief Convert sampling method string to enum
  202. *
  203. * @param method Sampling method string
  204. * @return sample_method_t Sampling method enum
  205. */
  206. static sample_method_t stringToSamplingMethod(const std::string& method);
  207. /**
  208. * @brief Convert scheduler string to enum
  209. *
  210. * @param scheduler Scheduler string
  211. * @return scheduler_t Scheduler enum
  212. */
  213. static scheduler_t stringToScheduler(const std::string& scheduler);
  214. /**
  215. * @brief Convert model type string to enum
  216. *
  217. * @param type Model type string
  218. * @return sd_type_t Model type enum
  219. */
  220. static sd_type_t stringToModelType(const std::string& type);
  221. private:
  222. class Impl;
  223. std::unique_ptr<Impl> pImpl; // Pimpl idiom
  224. // Thread safety
  225. mutable std::mutex wrapperMutex;
  226. };
  227. #endif // STABLE_DIFFUSION_WRAPPER_H