#ifndef STABLE_DIFFUSION_WRAPPER_H #define STABLE_DIFFUSION_WRAPPER_H #include #include #include #include #include #include // Include stable-diffusion.h to get the enum definitions extern "C" { #include "stable-diffusion.h" } /** * @brief Wrapper class for stable-diffusion.cpp functionality * * This class provides a C++ interface to the stable-diffusion.cpp library, * handling model loading, image generation, and resource management. */ class StableDiffusionWrapper { public: /** * @brief Generation parameters structure */ struct GenerationParams { // Basic parameters std::string prompt; ///< Text prompt for generation std::string negativePrompt; ///< Negative prompt (optional) // Image parameters int width; ///< Image width int height; ///< Image height int batchCount; ///< Number of images to generate // Sampling parameters int steps; ///< Number of diffusion steps float cfgScale; ///< CFG scale std::string samplingMethod; ///< Sampling method std::string scheduler; ///< Scheduler // Seed control int64_t seed; ///< Seed for generation // Model paths std::string modelPath; ///< Path to main model std::string clipLPath; ///< Path to CLIP-L model std::string clipGPath; ///< Path to CLIP-G model std::string vaePath; ///< Path to VAE model std::string taesdPath; ///< Path to TAESD model std::string controlNetPath; ///< Path to ControlNet model std::string embeddingDir; ///< Path to embeddings directory std::string loraModelDir; ///< Path to LoRA model directory // Advanced parameters int clipSkip; ///< CLIP skip layers float strength; ///< Strength for img2img float controlStrength; ///< ControlNet strength // Performance parameters int nThreads; ///< Number of threads (-1 for auto) bool offloadParamsToCpu; ///< Offload parameters to CPU bool clipOnCpu; ///< Keep CLIP on CPU bool vaeOnCpu; ///< Keep VAE on CPU bool diffusionFlashAttn; ///< Use flash attention bool diffusionConvDirect; ///< Use direct convolution bool vaeConvDirect; ///< Use direct VAE convolution // Model type std::string modelType; ///< Model type (f32, f16, q4_0, etc.) // Verbose output bool verbose = false; ///< Whether to print verbose model loading info // Constructor with default values GenerationParams() : width(512), height(512), batchCount(1), steps(20), cfgScale(7.5f), samplingMethod("euler"), scheduler("default"), seed(42), clipSkip(-1), strength(0.75f), controlStrength(0.9f), nThreads(-1), offloadParamsToCpu(false), clipOnCpu(false), vaeOnCpu(false), diffusionFlashAttn(false), diffusionConvDirect(false), vaeConvDirect(false), modelType("f16"), verbose(false) {} }; /** * @brief Generated image structure */ struct GeneratedImage { std::vector data; ///< Image data (RGB) int width; ///< Image width int height; ///< Image height int channels; ///< Number of channels (usually 3) int64_t seed; ///< Seed used for generation uint64_t generationTime; ///< Time taken for generation in milliseconds }; /** * @brief Progress callback function type * * @param step Current step * @param steps Total steps * @param time Time taken for current step * @param userData User data pointer */ using ProgressCallback = std::function; /** * @brief Construct a new Stable Diffusion Wrapper object */ StableDiffusionWrapper(); /** * @brief Destroy the Stable Diffusion Wrapper object */ ~StableDiffusionWrapper(); /** * @brief Load a stable-diffusion model * * @param modelPath Path to the model file * @param params Additional loading parameters * @return true if model was loaded successfully, false otherwise */ bool loadModel(const std::string& modelPath, const GenerationParams& params = GenerationParams{}); /** * @brief Unload the current model */ void unloadModel(); /** * @brief Check if a model is currently loaded * * @return true if a model is loaded, false otherwise */ bool isModelLoaded() const; /** * @brief Generate an image from text prompt * * @param params Generation parameters * @param progressCallback Optional progress callback * @param userData User data for progress callback * @return std::vector Generated images */ std::vector generateImage( const GenerationParams& params, ProgressCallback progressCallback = nullptr, void* userData = nullptr ); /** * @brief Generate an image from text prompt with image input (img2img) * * @param params Generation parameters * @param inputData Input image data * @param inputWidth Input image width * @param inputHeight Input image height * @param progressCallback Optional progress callback * @param userData User data for progress callback * @return std::vector Generated images */ std::vector generateImageImg2Img( const GenerationParams& params, const std::vector& inputData, int inputWidth, int inputHeight, ProgressCallback progressCallback = nullptr, void* userData = nullptr ); /** * @brief Generate an image with ControlNet * * @param params Generation parameters * @param controlData Control image data * @param controlWidth Control image width * @param controlHeight Control image height * @param progressCallback Optional progress callback * @param userData User data for progress callback * @return std::vector Generated images */ std::vector generateImageControlNet( const GenerationParams& params, const std::vector& controlData, int controlWidth, int controlHeight, ProgressCallback progressCallback = nullptr, void* userData = nullptr ); /** * @brief Generate an image with inpainting * * @param params Generation parameters * @param inputData Input image data * @param inputWidth Input image width * @param inputHeight Input image height * @param maskData Mask image data (grayscale, where white=keep, black=inpaint) * @param maskWidth Mask image width * @param maskHeight Mask image height * @param progressCallback Optional progress callback * @param userData User data for progress callback * @return std::vector Generated images */ std::vector generateImageInpainting( const GenerationParams& params, const std::vector& inputData, int inputWidth, int inputHeight, const std::vector& maskData, int maskWidth, int maskHeight, ProgressCallback progressCallback = nullptr, void* userData = nullptr ); /** * @brief Upscale an image using ESRGAN * * @param esrganPath Path to ESRGAN model * @param inputData Input image data * @param inputWidth Input image width * @param inputHeight Input image height * @param inputChannels Number of channels in input image * @param upscaleFactor Upscale factor (usually 2 or 4) * @param nThreads Number of threads (-1 for auto) * @param offloadParamsToCpu Offload parameters to CPU * @param direct Use direct mode * @return GeneratedImage Upscaled image */ GeneratedImage upscaleImage( const std::string& esrganPath, const std::vector& inputData, int inputWidth, int inputHeight, int inputChannels, uint32_t upscaleFactor, int nThreads = -1, bool offloadParamsToCpu = false, bool direct = false ); /** * @brief Get the last error message * * @return std::string Last error message */ std::string getLastError() const; /** * @brief Convert sampling method string to enum * * @param method Sampling method string * @return sample_method_t Sampling method enum */ static sample_method_t stringToSamplingMethod(const std::string& method); /** * @brief Convert scheduler string to enum * * @param scheduler Scheduler string * @return scheduler_t Scheduler enum */ static scheduler_t stringToScheduler(const std::string& scheduler); /** * @brief Convert model type string to enum * * @param type Model type string * @return sd_type_t Model type enum */ static sd_type_t stringToModelType(const std::string& type); private: class Impl; std::unique_ptr pImpl; // Pimpl idiom // Thread safety mutable std::mutex wrapperMutex; }; #endif // STABLE_DIFFUSION_WRAPPER_H