#include "jwt_auth.h"

// NOTE(review): the original include targets were lost (angle-bracket contents
// were stripped from this file); the list below is reconstructed from the
// names this file uses — verify against the build.
#include <algorithm>
#include <cctype>
#include <chrono>
#include <cstddef>
#include <cstdint>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>

#include <nlohmann/json.hpp>
#include <openssl/crypto.h>
#include <openssl/evp.h>
#include <openssl/hmac.h>
#include <openssl/rand.h>
#include <openssl/sha.h>

namespace {

// Base64url alphabet (RFC 4648 section 5): '+' and '/' replaced by '-' and '_'.
constexpr char kBase64UrlChars[] =
    "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_";

// Returns the 6-bit value of base64url character `c`, or -1 if `c` is not in
// the alphabet.
int base64UrlDigit(char c) {
    if (c >= 'A' && c <= 'Z') return c - 'A';
    if (c >= 'a' && c <= 'z') return c - 'a' + 26;
    if (c >= '0' && c <= '9') return c - '0' + 52;
    if (c == '-') return 62;
    if (c == '_') return 63;
    return -1;
}

// Locale-independent test for ASCII [A-Za-z0-9]. Unlike std::isalnum(char),
// this is well defined for negative char values.
bool isAsciiAlnum(char c) {
    return (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z') || (c >= '0' && c <= '9');
}

}  // namespace

// Creates an HS256 JWT signer/validator.
//   secret                 HMAC-SHA256 key. If empty, a random 64-character
//                          alphanumeric secret is generated for this instance.
//   tokenExpirationMinutes lifetime of tokens created by generateToken().
//   issuer, audience       written into the "iss"/"aud" claims of new tokens
//                          and compared against them in validateToken().
// Throws std::runtime_error if a secret must be generated and the OpenSSL
// random generator fails: signing with an empty key would make every token
// forgeable.
JWTAuth::JWTAuth(const std::string& secret, int tokenExpirationMinutes,
                 const std::string& issuer, const std::string& audience)
    : m_secret(secret),
      m_tokenExpirationMinutes(tokenExpirationMinutes),
      m_issuer(issuer),
      m_audience(audience) {
    if (m_secret.empty()) {
        m_secret = generateRandomString(64);
        // generateRandomString() returns "" when RAND_bytes fails.
        if (m_secret.empty()) {
            throw std::runtime_error("JWTAuth: failed to generate a random signing secret");
        }
    }
}

JWTAuth::~JWTAuth() = default;

// Builds and signs a compact JWT ("header.payload.signature") for the given
// identity. The token is issued now and expires after m_tokenExpirationMinutes.
// Returns "" if any step throws.
std::string JWTAuth::generateToken(const std::string& userId,
                                   const std::string& username,
                                   const std::string& role,
                                   const std::vector<std::string>& permissions) {
    try {
        Claims claims;
        claims.userId = userId;
        claims.username = username;
        claims.role = role;
        claims.permissions = permissions;
        claims.issuedAt = getCurrentTimestamp();
        // Widen before multiplying so a large expiration cannot overflow int.
        claims.expiresAt =
            claims.issuedAt + static_cast<int64_t>(m_tokenExpirationMinutes) * 60;
        claims.issuer = m_issuer;
        claims.audience = m_audience;

        const std::string header = createHeader();
        const std::string payload = createPayload(claims);
        const std::string signature = createSignature(header, payload);

        return header + "." + payload + "." + signature;
    } catch (const std::exception&) {
        return "";
    }
}

// Validates a compact JWT. The checks run in this order:
//   1. structure (three '.'-separated parts)
//   2. HMAC-SHA256 signature
//   3. payload parse (a non-empty "sub" is required)
//   4. expiry
//   5. issuer and audience (checked only when present in the token)
// On success, result.success is true and the identity fields are filled.
// Otherwise errorMessage/errorCode describe the first failed check.
JWTAuth::AuthResult JWTAuth::validateToken(const std::string& token) {
    AuthResult result;
    result.success = false;

    try {
        const std::vector<std::string> parts = splitToken(token);
        if (parts.size() != 3) {
            result.errorMessage = "Invalid token format";
            result.errorCode = "INVALID_TOKEN_FORMAT";
            return result;
        }

        const std::string& header = parts[0];
        const std::string& payload = parts[1];
        const std::string& signature = parts[2];

        if (!verifySignature(header, payload, signature)) {
            result.errorMessage = "Invalid token signature";
            result.errorCode = "INVALID_SIGNATURE";
            return result;
        }

        const Claims claims = parsePayload(token);
        if (claims.userId.empty()) {
            result.errorMessage = "Invalid token payload";
            result.errorCode = "INVALID_PAYLOAD";
            return result;
        }

        if (getCurrentTimestamp() >= claims.expiresAt) {
            result.errorMessage = "Token has expired";
            result.errorCode = "TOKEN_EXPIRED";
            return result;
        }

        if (!claims.issuer.empty() && claims.issuer != m_issuer) {
            result.errorMessage = "Invalid token issuer";
            result.errorCode = "INVALID_ISSUER";
            return result;
        }

        // A token minted for a different audience (e.g. by another service
        // sharing the secret) must not be accepted here (RFC 7519 "aud").
        if (!claims.audience.empty() && claims.audience != m_audience) {
            result.errorMessage = "Invalid token audience";
            result.errorCode = "INVALID_AUDIENCE";
            return result;
        }

        result.success = true;
        result.userId = claims.userId;
        result.username = claims.username;
        result.role = claims.role;
        result.permissions = claims.permissions;
    } catch (const std::exception& e) {
        result.errorMessage = "Token validation failed: " + std::string(e.what());
        result.errorCode = "VALIDATION_ERROR";
    }

    return result;
}

// Issues a fresh token (new iat/exp) for the identity in `token`, if `token`
// passes validateToken(). Returns "" otherwise.
std::string JWTAuth::refreshToken(const std::string& token) {
    try {
        const AuthResult result = validateToken(token);
        if (!result.success) {
            return "";
        }
        return generateToken(result.userId, result.username, result.role, result.permissions);
    } catch (const std::exception&) {
        return "";
    }
}

// Returns the token following a case-sensitive "Bearer " prefix in an
// Authorization header value. Returns "" if the prefix is missing or nothing
// follows it.
std::string JWTAuth::extractTokenFromHeader(const std::string& authHeader) {
    if (authHeader.empty()) {
        return "";
    }

    const std::string bearerPrefix = "Bearer ";
    if (authHeader.length() > bearerPrefix.length() &&
        authHeader.compare(0, bearerPrefix.length(), bearerPrefix) == 0) {
        return authHeader.substr(bearerPrefix.length());
    }

    return "";
}

// True if `requiredPermission` appears exactly (case-sensitive) in `permissions`.
bool JWTAuth::hasPermission(const std::vector<std::string>& permissions,
                            const std::string& requiredPermission) {
    return std::find(permissions.begin(), permissions.end(), requiredPermission) !=
           permissions.end();
}

// True if at least one of `requiredPermissions` is granted. An empty
// `requiredPermissions` yields false.
bool JWTAuth::hasAnyPermission(const std::vector<std::string>& permissions,
                               const std::vector<std::string>& requiredPermissions) {
    return std::any_of(requiredPermissions.begin(), requiredPermissions.end(),
                       [&](const std::string& permission) {
                           return hasPermission(permissions, permission);
                       });
}

// Returns the "exp" claim of `token`, or 0 if it cannot be read.
// NOTE: the signature is NOT verified here; use validateToken() for trust
// decisions.
int64_t JWTAuth::getTokenExpiration(const std::string& token) {
    try {
        return parsePayload(token).expiresAt;
    } catch (const std::exception&) {
        return 0;
    }
}

// True if `token` carries a positive "exp" that is now or in the past.
// An unparseable token (exp == 0) reports false; the signature is not checked.
bool JWTAuth::isTokenExpired(const std::string& token) {
    const int64_t expiration = getTokenExpiration(token);
    return expiration > 0 && getCurrentTimestamp() >= expiration;
}

// Sets the lifetime (in minutes) of tokens generated from now on.
void JWTAuth::setTokenExpiration(int minutes) {
    m_tokenExpirationMinutes = minutes;
}

// Returns the configured token lifetime in minutes.
int JWTAuth::getTokenExpiration() const {
    return m_tokenExpirationMinutes;
}

// Sets the issuer written into new tokens and required of validated ones.
void JWTAuth::setIssuer(const std::string& issuer) {
    m_issuer = issuer;
}

// Returns the configured issuer.
std::string JWTAuth::getIssuer() const {
    return m_issuer;
}

// Generates a random alphanumeric API key of `length` characters
// ("" on RNG failure or non-positive length).
std::string JWTAuth::generateApiKey(int length) {
    return generateRandomString(length);
}

// True if `apiKey` is 16..128 characters long and consists only of ASCII
// letters and digits.
bool JWTAuth::validateApiKeyFormat(const std::string& apiKey) {
    if (apiKey.length() < 16 || apiKey.length() > 128) {
        return false;
    }
    return std::all_of(apiKey.begin(), apiKey.end(), isAsciiAlnum);
}

// Encodes arbitrary bytes as unpadded base64url (RFC 4648 section 5).
std::string JWTAuth::base64UrlEncode(const std::string& input) {
    std::string result;
    result.reserve((input.size() * 4 + 2) / 3);

    // Unsigned accumulator: older bits are allowed to wrap away (well defined
    // for unsigned); only the low bits selected by `bits` are ever read.
    std::uint32_t accumulator = 0;
    int bits = -6;
    for (unsigned char c : input) {
        accumulator = (accumulator << 8) + c;
        bits += 8;
        while (bits >= 0) {
            result.push_back(kBase64UrlChars[(accumulator >> bits) & 0x3F]);
            bits -= 6;
        }
    }
    if (bits > -6) {
        // Left-align the 2 or 4 leftover bits into a final 6-bit digit.
        result.push_back(kBase64UrlChars[((accumulator << 8) >> (bits + 8)) & 0x3F]);
    }
    return result;
}

// Decodes base64url text. '=' padding characters are skipped wherever they
// appear. Returns "" if any other character is outside the alphabet.
// Leftover bits that do not form a full byte are discarded.
std::string JWTAuth::base64UrlDecode(const std::string& input) {
    std::string result;
    result.reserve(input.size() * 3 / 4);

    std::uint32_t accumulator = 0;
    int bits = -8;
    for (char c : input) {
        if (c == '=') continue;
        const int digit = base64UrlDigit(c);
        if (digit < 0) return "";
        accumulator = (accumulator << 6) + static_cast<std::uint32_t>(digit);
        bits += 6;
        if (bits >= 0) {
            result.push_back(static_cast<char>((accumulator >> bits) & 0xFF));
            bits -= 8;
        }
    }
    return result;
}

// Returns the base64url-encoded JOSE header {"alg":"HS256","typ":"JWT"}.
std::string JWTAuth::createHeader() const {
    const nlohmann::json header = {
        {"alg", "HS256"},
        {"typ", "JWT"}
    };
    return base64UrlEncode(header.dump());
}

// Serialises `claims` into the base64url-encoded JSON payload.
// "permissions" is emitted only when non-empty.
std::string JWTAuth::createPayload(const Claims& claims) const {
    nlohmann::json payload = {
        {"sub", claims.userId},
        {"username", claims.username},
        {"role", claims.role},
        {"iat", claims.issuedAt},
        {"exp", claims.expiresAt},
        {"iss", claims.issuer},
        {"aud", claims.audience}
    };

    if (!claims.permissions.empty()) {
        payload["permissions"] = claims.permissions;
    }

    return base64UrlEncode(payload.dump());
}

// Decodes the payload part of `token` into Claims WITHOUT verifying the
// signature. Missing keys fall back to ""/0. Any malformed input (wrong part
// count, bad base64, invalid JSON, wrongly typed value) yields value-
// initialised Claims, never a partially filled one.
JWTAuth::Claims JWTAuth::parsePayload(const std::string& token) const {
    try {
        const std::vector<std::string> parts = splitToken(token);
        if (parts.size() != 3) {
            return Claims{};
        }

        const nlohmann::json payload = nlohmann::json::parse(base64UrlDecode(parts[1]));

        Claims claims;
        claims.userId = payload.value("sub", "");
        claims.username = payload.value("username", "");
        claims.role = payload.value("role", "");
        // The default's type fixes value()'s return type; a plain `0` would
        // read the timestamps as int and truncate 64-bit values.
        claims.issuedAt = payload.value("iat", int64_t{0});
        claims.expiresAt = payload.value("exp", int64_t{0});
        claims.issuer = payload.value("iss", "");
        claims.audience = payload.value("aud", "");

        const auto permissionsIt = payload.find("permissions");
        if (permissionsIt != payload.end() && permissionsIt->is_array()) {
            for (const auto& perm : *permissionsIt) {
                claims.permissions.push_back(perm.get<std::string>());
            }
        }
        return claims;
    } catch (const std::exception&) {
        return Claims{};
    }
}

// Computes the base64url HMAC-SHA256 of "header.payload" under m_secret.
// Throws std::runtime_error if OpenSSL fails.
std::string JWTAuth::createSignature(const std::string& header,
                                     const std::string& payload) const {
    const std::string data = header + "." + payload;

    // Use a caller-owned output buffer: with a null `md` argument HMAC() writes
    // into a shared static array, which is not thread-safe.
    unsigned char digest[EVP_MAX_MD_SIZE];
    unsigned int digestLength = 0;
    const unsigned char* mac =
        HMAC(EVP_sha256(),
             m_secret.data(), static_cast<int>(m_secret.size()),
             reinterpret_cast<const unsigned char*>(data.data()), data.size(),
             digest, &digestLength);
    if (mac == nullptr || digestLength != SHA256_DIGEST_LENGTH) {
        throw std::runtime_error("HMAC-SHA256 computation failed");
    }

    return base64UrlEncode(std::string(reinterpret_cast<const char*>(digest), digestLength));
}

// True if `signature` equals the expected signature for header/payload.
// The comparison is constant-time so response timing does not reveal how
// much of a forged signature matched.
bool JWTAuth::verifySignature(const std::string& header, const std::string& payload,
                              const std::string& signature) const {
    const std::string expectedSignature = createSignature(header, payload);
    return expectedSignature.size() == signature.size() &&
           CRYPTO_memcmp(expectedSignature.data(), signature.data(),
                         expectedSignature.size()) == 0;
}

// Splits `token` on every '.', keeping empty segments (including a trailing
// one, so "h.p.s." yields 4 parts and is rejected). An empty token yields no
// parts.
std::vector<std::string> JWTAuth::splitToken(const std::string& token) {
    std::vector<std::string> parts;
    if (token.empty()) {
        return parts;
    }

    std::string::size_type start = 0;
    while (true) {
        const std::string::size_type dot = token.find('.', start);
        if (dot == std::string::npos) {
            parts.push_back(token.substr(start));
            break;
        }
        parts.push_back(token.substr(start, dot - start));
        start = dot + 1;
    }
    return parts;
}

// Current wall-clock time as whole seconds since the Unix epoch.
int64_t JWTAuth::getCurrentTimestamp() {
    return std::chrono::duration_cast<std::chrono::seconds>(
               std::chrono::system_clock::now().time_since_epoch())
        .count();
}

// Returns `length` characters drawn uniformly from [A-Za-z0-9] using the
// OpenSSL CSPRNG. Returns "" if `length` <= 0 or RAND_bytes fails.
std::string JWTAuth::generateRandomString(int length) {
    if (length <= 0) {
        return "";
    }

    const std::string chars = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789";
    const std::size_t wanted = static_cast<std::size_t>(length);
    // Bytes at or above the largest multiple of chars.size() that fits in a
    // byte (248 for 62 chars) are rejected. A plain `byte % 62` would favour
    // the first 256 % 62 = 8 characters.
    const std::size_t acceptLimit = 256 - (256 % chars.size());

    std::string result;
    result.reserve(wanted);

    unsigned char buffer[64];
    while (result.size() < wanted) {
        if (RAND_bytes(buffer, static_cast<int>(sizeof(buffer))) != 1) {
            return "";
        }
        for (unsigned char byte : buffer) {
            if (byte >= acceptLimit) continue;
            result += chars[byte % chars.size()];
            if (result.size() == wanted) break;
        }
    }
    return result;
}