Start implementing RSA for OpenSSL

This commit is contained in:
Ted John 2018-05-25 00:39:27 +01:00
parent 865bfb7b1b
commit fe7e8a17de
3 changed files with 167 additions and 27 deletions

View File

@ -32,28 +32,31 @@
#include <bcrypt.h>
#define NT_SUCCESS(Status) (((NTSTATUS)(Status)) >= 0)
class CNGSha1Algorithm final : public Sha1Algorithm
template<typename TBase>
class CngHashAlgorithm final : public TBase
{
private:
const char * _algName;
BCRYPT_ALG_HANDLE _hAlg{};
BCRYPT_HASH_HANDLE _hHash{};
PBYTE _pbHashObject{};
bool _reusable{};
public:
CNGSha1Algorithm()
CngHashAlgorithm(const char * algName)
{
// BCRYPT_HASH_REUSABLE_FLAG only available from Windows 8
_algName = algName;
_reusable = Platform::IsOSVersionAtLeast(6, 2, 0);
Initialise();
}
~CNGSha1Algorithm()
~CngHashAlgorithm()
{
Dispose();
}
void Clear() override
HashAlgorithm * Clear() override
{
if (_reusable)
{
@ -65,15 +68,17 @@ public:
Dispose();
Initialise();
}
return this;
}
void Update(const void * data, size_t dataLen) override
HashAlgorithm * Update(const void * data, size_t dataLen) override
{
auto status = BCryptHashData(_hHash, (PBYTE)data, (ULONG)dataLen, 0);
if (!NT_SUCCESS(status))
{
throw std::runtime_error("BCryptHashData failed: " + std::to_string(status));
}
return this;
}
Result Finish() override
@ -91,7 +96,7 @@ private:
void Initialise()
{
auto flags = _reusable ? BCRYPT_HASH_REUSABLE_FLAG : 0;
auto status = BCryptOpenAlgorithmProvider(&_hAlg, BCRYPT_SHA1_ALGORITHM, nullptr, flags);
auto status = BCryptOpenAlgorithmProvider(&_hAlg, TAlg, nullptr, flags);
if (!NT_SUCCESS(status))
{
throw std::runtime_error("BCryptOpenAlgorithmProvider failed: " + std::to_string(status));
@ -135,14 +140,12 @@ namespace Hash
{
std::unique_ptr<Sha1Algorithm> CreateSHA1()
{
return std::make_unique<CNGSha1Algorithm>();
return std::make_unique<CngHashAlgorithm<Sha1Algorithm>>(BCRYPT_SHA1_ALGORITHM);
}
Sha1Algorithm::Result SHA1(const void * data, size_t dataLen)
std::unique_ptr<Sha256Algorithm> CreateSHA256()
{
CNGSha1Algorithm sha1;
sha1.Update(data, dataLen);
return sha1.Finish();
return std::make_unique<CngHashAlgorithm<Sha256Algorithm>>(BCRYPT_SHA256_ALGORITHM);
}
}

View File

@ -23,17 +23,29 @@
#include "Crypt.h"
#include <stdexcept>
#include <string>
#include <vector>
#include <openssl/evp.h>
class OpenSSLSha1Algorithm final : public Sha1Algorithm
static void OpenSSLThrowOnBadStatus(const std::string_view& name, int status)
{
if (status != 1)
{
throw std::runtime_error(std::string(name) + " failed: " + std::to_string(status));
}
}
template<typename TBase>
class OpenSSLHashAlgorithm final : public TBase
{
private:
const EVP_MD * _type;
EVP_MD_CTX * _ctx{};
bool _initialised{};
public:
OpenSSLSha1Algorithm()
OpenSSLHashAlgorithm(const EVP_MD * type)
{
_type = type;
_ctx = EVP_MD_CTX_create();
if (_ctx == nullptr)
{
@ -41,21 +53,22 @@ public:
}
}
~OpenSSLSha1Algorithm()
~OpenSSLHashAlgorithm()
{
EVP_MD_CTX_destroy(_ctx);
}
void Clear() override
TBase * Clear() override
{
if (EVP_DigestInit_ex(_ctx, EVP_sha1(), nullptr) <= 0)
if (EVP_DigestInit_ex(_ctx, _type, nullptr) <= 0)
{
throw std::runtime_error("EVP_DigestInit_ex failed");
}
_initialised = true;
return this;
}
void Update(const void * data, size_t dataLen) override
TBase * Update(const void * data, size_t dataLen) override
{
// Auto initialise
if (!_initialised)
@ -67,9 +80,10 @@ public:
{
throw std::runtime_error("EVP_DigestUpdate failed");
}
return this;
}
Result Finish() override
typename TBase::Result Finish() override
{
if (!_initialised)
{
@ -77,7 +91,7 @@ public:
}
_initialised = false;
Result result;
typename TBase::Result result;
unsigned int digestSize{};
if (EVP_DigestFinal(_ctx, result.data(), &digestSize) <= 0)
{
@ -92,18 +106,112 @@ public:
}
};
class OpenSSLRsaKey final : public RsaKey
{
public:
EVP_PKEY * const EvpKey{};
void SetPrivate(const std::string_view& pem) override { }
void SetPublic(const std::string_view& pem) override { }
std::string GetPrivate() override { return ""; }
std::string GetPublic() override { return ""; }
};
class OpenSSLRsaAlgorithm final : public RsaAlgorithm
{
public:
std::vector<uint8_t> SignData(const RsaKey& key, const void * data, size_t dataLen) override
{
auto evpKey = static_cast<const OpenSSLRsaKey&>(key).EvpKey;
EVP_MD_CTX * mdctx{};
try
{
mdctx = EVP_MD_CTX_create();
if (mdctx == nullptr)
{
throw std::runtime_error("EVP_MD_CTX_create failed");
}
auto status = EVP_DigestSignInit(mdctx, nullptr, EVP_sha256(), nullptr, evpKey);
OpenSSLThrowOnBadStatus("EVP_DigestSignInit failed", status);
status = EVP_DigestSignUpdate(mdctx, data, dataLen);
OpenSSLThrowOnBadStatus("EVP_DigestSignUpdate failed", status);
// Get required length of signature
size_t sigLen{};
status = EVP_DigestSignFinal(mdctx, nullptr, &sigLen);
OpenSSLThrowOnBadStatus("EVP_DigestSignFinal failed", status);
// Get signature
std::vector<uint8_t> signature(sigLen);
status = EVP_DigestSignFinal(mdctx, signature.data(), &sigLen);
OpenSSLThrowOnBadStatus("EVP_DigestSignFinal failed", status);
EVP_MD_CTX_destroy(mdctx);
return signature;
}
catch (const std::exception&)
{
EVP_MD_CTX_destroy(mdctx);
throw;
}
}
bool VerifyData(const RsaKey& key, const void * data, size_t dataLen, const void * sig, size_t sigLen) override
{
auto evpKey = static_cast<const OpenSSLRsaKey&>(key).EvpKey;
EVP_MD_CTX * mdctx{};
try
{
mdctx = EVP_MD_CTX_create();
if (mdctx == nullptr)
{
throw std::runtime_error("EVP_MD_CTX_create failed");
}
auto status = EVP_DigestVerifyInit(mdctx, nullptr, EVP_sha256(), nullptr, evpKey);
OpenSSLThrowOnBadStatus("EVP_DigestVerifyInit", status);
status = EVP_DigestVerifyUpdate(mdctx, data, dataLen);
OpenSSLThrowOnBadStatus("EVP_DigestVerifyUpdate", status);
status = EVP_DigestVerifyFinal(mdctx, (uint8_t*)sig, sigLen);
if (status != 0 && status != 1)
{
OpenSSLThrowOnBadStatus("EVP_DigestVerifyUpdate", status);
}
EVP_MD_CTX_destroy(mdctx);
return status == 0;
}
catch (const std::exception&)
{
EVP_MD_CTX_destroy(mdctx);
throw;
}
}
};
namespace Hash
{
std::unique_ptr<Sha1Algorithm> CreateSHA1()
{
return std::make_unique<OpenSSLSha1Algorithm>();
return std::make_unique<OpenSSLHashAlgorithm<Sha1Algorithm>>(EVP_sha1());
}
Sha1Algorithm::Result SHA1(const void * data, size_t dataLen)
std::unique_ptr<Sha256Algorithm> CreateSHA256()
{
OpenSSLSha1Algorithm sha1;
sha1.Update(data, dataLen);
return sha1.Finish();
return std::make_unique<OpenSSLHashAlgorithm<Sha256Algorithm>>(EVP_sha256());
}
std::unique_ptr<RsaAlgorithm> CreateRSA()
{
return std::make_unique<OpenSSLRsaAlgorithm>();
}
std::unique_ptr<RsaKey> CreateRSAKey()
{
return std::make_unique<OpenSSLRsaKey>();
}
}

View File

@ -18,6 +18,7 @@
#include <array>
#include <memory>
#include <vector>
template<size_t TLength>
class HashAlgorithm
@ -26,15 +27,43 @@ public:
typedef std::array<uint8_t, TLength> Result;
virtual ~HashAlgorithm() = default;
virtual void Clear() = 0;
virtual void Update(const void * data, size_t dataLen) = 0;
virtual HashAlgorithm * Clear() = 0;
virtual HashAlgorithm * Update(const void * data, size_t dataLen) = 0;
virtual Result Finish() = 0;
};
class RsaKey
{
public:
virtual ~RsaKey() = default;
virtual void SetPrivate(const std::string_view& pem) = 0;
virtual void SetPublic(const std::string_view& pem) = 0;
virtual std::string GetPrivate() = 0;
virtual std::string GetPublic() = 0;
};
class RsaAlgorithm
{
public:
virtual ~RsaAlgorithm() = default;
virtual std::vector<uint8_t> SignData(const RsaKey& key, const void * data, size_t dataLen) = 0;
virtual bool VerifyData(const RsaKey& key, const void * data, size_t dataLen, const void * sig, size_t sigLen) = 0;
};
using Sha1Algorithm = HashAlgorithm<20>;
using Sha256Algorithm = HashAlgorithm<32>;
namespace Hash
{
std::unique_ptr<Sha1Algorithm> CreateSHA1();
Sha1Algorithm::Result SHA1(const void * data, size_t dataLen);
std::unique_ptr<Sha256Algorithm> CreateSHA256();
std::unique_ptr<RsaAlgorithm> CreateRSA();
std::unique_ptr<RsaKey> CreateRSAKey();
Sha1Algorithm::Result SHA1(const void * data, size_t dataLen)
{
return CreateSHA1()
->Update(data, dataLen)
->Finish();
}
}