| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299 |
- #ifndef STABLE_DIFFUSION_WRAPPER_H
- #define STABLE_DIFFUSION_WRAPPER_H
- #include <string>
- #include <vector>
- #include <memory>
- #include <mutex>
- #include <functional>
- #include <cstdint>
- // Include stable-diffusion.h to get the enum definitions
- extern "C" {
- #include "stable-diffusion.h"
- }
- /**
- * @brief Wrapper class for stable-diffusion.cpp functionality
- *
- * This class provides a C++ interface to the stable-diffusion.cpp library,
- * handling model loading, image generation, and resource management.
- */
- class StableDiffusionWrapper {
- public:
- /**
- * @brief Generation parameters structure
- */
- struct GenerationParams {
- // Basic parameters
- std::string prompt; ///< Text prompt for generation
- std::string negativePrompt; ///< Negative prompt (optional)
- // Image parameters
- int width; ///< Image width
- int height; ///< Image height
- int batchCount; ///< Number of images to generate
- // Sampling parameters
- int steps; ///< Number of diffusion steps
- float cfgScale; ///< CFG scale
- std::string samplingMethod; ///< Sampling method
- std::string scheduler; ///< Scheduler
- // Seed control
- int64_t seed; ///< Seed for generation
- // Model paths
- std::string modelPath; ///< Path to main model
- std::string clipLPath; ///< Path to CLIP-L model
- std::string clipGPath; ///< Path to CLIP-G model
- std::string vaePath; ///< Path to VAE model
- std::string taesdPath; ///< Path to TAESD model
- std::string controlNetPath; ///< Path to ControlNet model
- std::string embeddingDir; ///< Path to embeddings directory
- std::string loraModelDir; ///< Path to LoRA model directory
- // Advanced parameters
- int clipSkip; ///< CLIP skip layers
- float strength; ///< Strength for img2img
- float controlStrength; ///< ControlNet strength
- // Performance parameters
- int nThreads; ///< Number of threads (-1 for auto)
- bool offloadParamsToCpu; ///< Offload parameters to CPU
- bool clipOnCpu; ///< Keep CLIP on CPU
- bool vaeOnCpu; ///< Keep VAE on CPU
- bool diffusionFlashAttn; ///< Use flash attention
- bool diffusionConvDirect; ///< Use direct convolution
- bool vaeConvDirect; ///< Use direct VAE convolution
- // Model type
- std::string modelType; ///< Model type (f32, f16, q4_0, etc.)
- // Verbose output
- bool verbose = false; ///< Whether to print verbose model loading info
- // Constructor with default values
- GenerationParams()
- : width(512), height(512), batchCount(1), steps(20), cfgScale(7.5f),
- samplingMethod("euler"), scheduler("default"), seed(42),
- clipSkip(-1), strength(0.75f), controlStrength(0.9f),
- nThreads(-1), offloadParamsToCpu(false), clipOnCpu(false),
- vaeOnCpu(false), diffusionFlashAttn(false),
- diffusionConvDirect(false), vaeConvDirect(false),
- modelType("f16"), verbose(false) {}
- };
- /**
- * @brief Generated image structure
- */
- struct GeneratedImage {
- std::vector<uint8_t> data; ///< Image data (RGB)
- int width; ///< Image width
- int height; ///< Image height
- int channels; ///< Number of channels (usually 3)
- int64_t seed; ///< Seed used for generation
- uint64_t generationTime; ///< Time taken for generation in milliseconds
- };
- /**
- * @brief Progress callback function type
- *
- * @param step Current step
- * @param steps Total steps
- * @param time Time taken for current step
- * @param userData User data pointer
- */
- using ProgressCallback = std::function<void(int step, int steps, float time, void* userData)>;
- /**
- * @brief Construct a new Stable Diffusion Wrapper object
- */
- StableDiffusionWrapper();
- /**
- * @brief Destroy the Stable Diffusion Wrapper object
- */
- ~StableDiffusionWrapper();
- /**
- * @brief Load a stable-diffusion model
- *
- * @param modelPath Path to the model file
- * @param params Additional loading parameters
- * @return true if model was loaded successfully, false otherwise
- */
- bool loadModel(const std::string& modelPath, const GenerationParams& params = GenerationParams{});
- /**
- * @brief Unload the current model
- */
- void unloadModel();
- /**
- * @brief Check if a model is currently loaded
- *
- * @return true if a model is loaded, false otherwise
- */
- bool isModelLoaded() const;
- /**
- * @brief Generate an image from text prompt
- *
- * @param params Generation parameters
- * @param progressCallback Optional progress callback
- * @param userData User data for progress callback
- * @return std::vector<GeneratedImage> Generated images
- */
- std::vector<GeneratedImage> generateImage(
- const GenerationParams& params,
- ProgressCallback progressCallback = nullptr,
- void* userData = nullptr
- );
- /**
- * @brief Generate an image from text prompt with image input (img2img)
- *
- * @param params Generation parameters
- * @param inputData Input image data
- * @param inputWidth Input image width
- * @param inputHeight Input image height
- * @param progressCallback Optional progress callback
- * @param userData User data for progress callback
- * @return std::vector<GeneratedImage> Generated images
- */
- std::vector<GeneratedImage> generateImageImg2Img(
- const GenerationParams& params,
- const std::vector<uint8_t>& inputData,
- int inputWidth,
- int inputHeight,
- ProgressCallback progressCallback = nullptr,
- void* userData = nullptr
- );
- /**
- * @brief Generate an image with ControlNet
- *
- * @param params Generation parameters
- * @param controlData Control image data
- * @param controlWidth Control image width
- * @param controlHeight Control image height
- * @param progressCallback Optional progress callback
- * @param userData User data for progress callback
- * @return std::vector<GeneratedImage> Generated images
- */
- std::vector<GeneratedImage> generateImageControlNet(
- const GenerationParams& params,
- const std::vector<uint8_t>& controlData,
- int controlWidth,
- int controlHeight,
- ProgressCallback progressCallback = nullptr,
- void* userData = nullptr
- );
- /**
- * @brief Generate an image with inpainting
- *
- * @param params Generation parameters
- * @param inputData Input image data
- * @param inputWidth Input image width
- * @param inputHeight Input image height
- * @param maskData Mask image data (grayscale, where white=keep, black=inpaint)
- * @param maskWidth Mask image width
- * @param maskHeight Mask image height
- * @param progressCallback Optional progress callback
- * @param userData User data for progress callback
- * @return std::vector<GeneratedImage> Generated images
- */
- std::vector<GeneratedImage> generateImageInpainting(
- const GenerationParams& params,
- const std::vector<uint8_t>& inputData,
- int inputWidth,
- int inputHeight,
- const std::vector<uint8_t>& maskData,
- int maskWidth,
- int maskHeight,
- ProgressCallback progressCallback = nullptr,
- void* userData = nullptr
- );
- /**
- * @brief Upscale an image using ESRGAN
- *
- * @param esrganPath Path to ESRGAN model
- * @param inputData Input image data
- * @param inputWidth Input image width
- * @param inputHeight Input image height
- * @param inputChannels Number of channels in input image
- * @param upscaleFactor Upscale factor (usually 2 or 4)
- * @param nThreads Number of threads (-1 for auto)
- * @param offloadParamsToCpu Offload parameters to CPU
- * @param direct Use direct mode
- * @param progressCallback Optional progress callback
- * @param userData User data for progress callback
- * @return GeneratedImage Upscaled image
- */
- GeneratedImage upscaleImage(
- const std::string& esrganPath,
- const std::vector<uint8_t>& inputData,
- int inputWidth,
- int inputHeight,
- int inputChannels,
- uint32_t upscaleFactor,
- int nThreads = -1,
- bool offloadParamsToCpu = false,
- bool direct = false,
- ProgressCallback progressCallback = nullptr,
- void* userData = nullptr);
- // File-based overload for upscaler input
- GeneratedImage upscaleImage(
- const std::string& esrganPath,
- const std::string& inputImagePath,
- uint32_t upscaleFactor,
- int nThreads = -1,
- bool offloadParamsToCpu = false,
- bool direct = false,
- ProgressCallback progressCallback = nullptr,
- void* userData = nullptr);
- /**
- * @brief Get the last error message
- *
- * @return std::string Last error message
- */
- std::string getLastError() const;
- /**
- * @brief Convert sampling method string to enum
- *
- * @param method Sampling method string
- * @return sample_method_t Sampling method enum
- */
- static sample_method_t stringToSamplingMethod(const std::string& method);
- /**
- * @brief Convert scheduler string to enum
- *
- * @param scheduler Scheduler string
- * @return scheduler_t Scheduler enum
- */
- static scheduler_t stringToScheduler(const std::string& scheduler);
- /**
- * @brief Convert model type string to enum
- *
- * @param type Model type string
- * @return sd_type_t Model type enum
- */
- static sd_type_t stringToModelType(const std::string& type);
- private:
- class Impl;
- std::unique_ptr<Impl> pImpl; // Pimpl idiom
- // Thread safety
- mutable std::mutex wrapperMutex;
- };
- #endif // STABLE_DIFFUSION_WRAPPER_H
|