mirror of https://github.com/OpenRCT2/OpenRCT2.git
Start implementing RSA for OpenSSL
This commit is contained in:
parent
865bfb7b1b
commit
fe7e8a17de
|
@ -32,28 +32,31 @@
|
|||
#include <bcrypt.h>
|
||||
#define NT_SUCCESS(Status) (((NTSTATUS)(Status)) >= 0)
|
||||
|
||||
class CNGSha1Algorithm final : public Sha1Algorithm
|
||||
template<typename TBase>
|
||||
class CngHashAlgorithm final : public TBase
|
||||
{
|
||||
private:
|
||||
const char * _algName;
|
||||
BCRYPT_ALG_HANDLE _hAlg{};
|
||||
BCRYPT_HASH_HANDLE _hHash{};
|
||||
PBYTE _pbHashObject{};
|
||||
bool _reusable{};
|
||||
|
||||
public:
|
||||
CNGSha1Algorithm()
|
||||
CngHashAlgorithm(const char * algName)
|
||||
{
|
||||
// BCRYPT_HASH_REUSABLE_FLAG only available from Windows 8
|
||||
_algName = algName;
|
||||
_reusable = Platform::IsOSVersionAtLeast(6, 2, 0);
|
||||
Initialise();
|
||||
}
|
||||
|
||||
~CNGSha1Algorithm()
|
||||
~CngHashAlgorithm()
|
||||
{
|
||||
Dispose();
|
||||
}
|
||||
|
||||
void Clear() override
|
||||
HashAlgorithm * Clear() override
|
||||
{
|
||||
if (_reusable)
|
||||
{
|
||||
|
@ -65,15 +68,17 @@ public:
|
|||
Dispose();
|
||||
Initialise();
|
||||
}
|
||||
return this;
|
||||
}
|
||||
|
||||
void Update(const void * data, size_t dataLen) override
|
||||
HashAlgorithm * Update(const void * data, size_t dataLen) override
|
||||
{
|
||||
auto status = BCryptHashData(_hHash, (PBYTE)data, (ULONG)dataLen, 0);
|
||||
if (!NT_SUCCESS(status))
|
||||
{
|
||||
throw std::runtime_error("BCryptHashData failed: " + std::to_string(status));
|
||||
}
|
||||
return this;
|
||||
}
|
||||
|
||||
Result Finish() override
|
||||
|
@ -91,7 +96,7 @@ private:
|
|||
void Initialise()
|
||||
{
|
||||
auto flags = _reusable ? BCRYPT_HASH_REUSABLE_FLAG : 0;
|
||||
auto status = BCryptOpenAlgorithmProvider(&_hAlg, BCRYPT_SHA1_ALGORITHM, nullptr, flags);
|
||||
auto status = BCryptOpenAlgorithmProvider(&_hAlg, TAlg, nullptr, flags);
|
||||
if (!NT_SUCCESS(status))
|
||||
{
|
||||
throw std::runtime_error("BCryptOpenAlgorithmProvider failed: " + std::to_string(status));
|
||||
|
@ -135,14 +140,12 @@ namespace Hash
|
|||
{
|
||||
std::unique_ptr<Sha1Algorithm> CreateSHA1()
|
||||
{
|
||||
return std::make_unique<CNGSha1Algorithm>();
|
||||
return std::make_unique<CngHashAlgorithm<Sha1Algorithm>>(BCRYPT_SHA1_ALGORITHM);
|
||||
}
|
||||
|
||||
Sha1Algorithm::Result SHA1(const void * data, size_t dataLen)
|
||||
std::unique_ptr<Sha256Algorithm> CreateSHA256()
|
||||
{
|
||||
CNGSha1Algorithm sha1;
|
||||
sha1.Update(data, dataLen);
|
||||
return sha1.Finish();
|
||||
return std::make_unique<CngHashAlgorithm<Sha256Algorithm>>(BCRYPT_SHA256_ALGORITHM);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -23,17 +23,29 @@
|
|||
#include "Crypt.h"
|
||||
#include <stdexcept>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <openssl/evp.h>
|
||||
|
||||
class OpenSSLSha1Algorithm final : public Sha1Algorithm
|
||||
static void OpenSSLThrowOnBadStatus(const std::string_view& name, int status)
|
||||
{
|
||||
if (status != 1)
|
||||
{
|
||||
throw std::runtime_error(std::string(name) + " failed: " + std::to_string(status));
|
||||
}
|
||||
}
|
||||
|
||||
template<typename TBase>
|
||||
class OpenSSLHashAlgorithm final : public TBase
|
||||
{
|
||||
private:
|
||||
const EVP_MD * _type;
|
||||
EVP_MD_CTX * _ctx{};
|
||||
bool _initialised{};
|
||||
|
||||
public:
|
||||
OpenSSLSha1Algorithm()
|
||||
OpenSSLHashAlgorithm(const EVP_MD * type)
|
||||
{
|
||||
_type = type;
|
||||
_ctx = EVP_MD_CTX_create();
|
||||
if (_ctx == nullptr)
|
||||
{
|
||||
|
@ -41,21 +53,22 @@ public:
|
|||
}
|
||||
}
|
||||
|
||||
~OpenSSLSha1Algorithm()
|
||||
~OpenSSLHashAlgorithm()
|
||||
{
|
||||
EVP_MD_CTX_destroy(_ctx);
|
||||
}
|
||||
|
||||
void Clear() override
|
||||
TBase * Clear() override
|
||||
{
|
||||
if (EVP_DigestInit_ex(_ctx, EVP_sha1(), nullptr) <= 0)
|
||||
if (EVP_DigestInit_ex(_ctx, _type, nullptr) <= 0)
|
||||
{
|
||||
throw std::runtime_error("EVP_DigestInit_ex failed");
|
||||
}
|
||||
_initialised = true;
|
||||
return this;
|
||||
}
|
||||
|
||||
void Update(const void * data, size_t dataLen) override
|
||||
TBase * Update(const void * data, size_t dataLen) override
|
||||
{
|
||||
// Auto initialise
|
||||
if (!_initialised)
|
||||
|
@ -67,9 +80,10 @@ public:
|
|||
{
|
||||
throw std::runtime_error("EVP_DigestUpdate failed");
|
||||
}
|
||||
return this;
|
||||
}
|
||||
|
||||
Result Finish() override
|
||||
typename TBase::Result Finish() override
|
||||
{
|
||||
if (!_initialised)
|
||||
{
|
||||
|
@ -77,7 +91,7 @@ public:
|
|||
}
|
||||
_initialised = false;
|
||||
|
||||
Result result;
|
||||
typename TBase::Result result;
|
||||
unsigned int digestSize{};
|
||||
if (EVP_DigestFinal(_ctx, result.data(), &digestSize) <= 0)
|
||||
{
|
||||
|
@ -92,18 +106,112 @@ public:
|
|||
}
|
||||
};
|
||||
|
||||
class OpenSSLRsaKey final : public RsaKey
|
||||
{
|
||||
public:
|
||||
EVP_PKEY * const EvpKey{};
|
||||
|
||||
void SetPrivate(const std::string_view& pem) override { }
|
||||
void SetPublic(const std::string_view& pem) override { }
|
||||
std::string GetPrivate() override { return ""; }
|
||||
std::string GetPublic() override { return ""; }
|
||||
};
|
||||
|
||||
class OpenSSLRsaAlgorithm final : public RsaAlgorithm
|
||||
{
|
||||
public:
|
||||
std::vector<uint8_t> SignData(const RsaKey& key, const void * data, size_t dataLen) override
|
||||
{
|
||||
auto evpKey = static_cast<const OpenSSLRsaKey&>(key).EvpKey;
|
||||
EVP_MD_CTX * mdctx{};
|
||||
try
|
||||
{
|
||||
mdctx = EVP_MD_CTX_create();
|
||||
if (mdctx == nullptr)
|
||||
{
|
||||
throw std::runtime_error("EVP_MD_CTX_create failed");
|
||||
}
|
||||
|
||||
auto status = EVP_DigestSignInit(mdctx, nullptr, EVP_sha256(), nullptr, evpKey);
|
||||
OpenSSLThrowOnBadStatus("EVP_DigestSignInit failed", status);
|
||||
|
||||
status = EVP_DigestSignUpdate(mdctx, data, dataLen);
|
||||
OpenSSLThrowOnBadStatus("EVP_DigestSignUpdate failed", status);
|
||||
|
||||
// Get required length of signature
|
||||
size_t sigLen{};
|
||||
status = EVP_DigestSignFinal(mdctx, nullptr, &sigLen);
|
||||
OpenSSLThrowOnBadStatus("EVP_DigestSignFinal failed", status);
|
||||
|
||||
// Get signature
|
||||
std::vector<uint8_t> signature(sigLen);
|
||||
status = EVP_DigestSignFinal(mdctx, signature.data(), &sigLen);
|
||||
OpenSSLThrowOnBadStatus("EVP_DigestSignFinal failed", status);
|
||||
|
||||
EVP_MD_CTX_destroy(mdctx);
|
||||
return signature;
|
||||
}
|
||||
catch (const std::exception&)
|
||||
{
|
||||
EVP_MD_CTX_destroy(mdctx);
|
||||
throw;
|
||||
}
|
||||
}
|
||||
|
||||
bool VerifyData(const RsaKey& key, const void * data, size_t dataLen, const void * sig, size_t sigLen) override
|
||||
{
|
||||
auto evpKey = static_cast<const OpenSSLRsaKey&>(key).EvpKey;
|
||||
EVP_MD_CTX * mdctx{};
|
||||
try
|
||||
{
|
||||
mdctx = EVP_MD_CTX_create();
|
||||
if (mdctx == nullptr)
|
||||
{
|
||||
throw std::runtime_error("EVP_MD_CTX_create failed");
|
||||
}
|
||||
|
||||
auto status = EVP_DigestVerifyInit(mdctx, nullptr, EVP_sha256(), nullptr, evpKey);
|
||||
OpenSSLThrowOnBadStatus("EVP_DigestVerifyInit", status);
|
||||
|
||||
status = EVP_DigestVerifyUpdate(mdctx, data, dataLen);
|
||||
OpenSSLThrowOnBadStatus("EVP_DigestVerifyUpdate", status);
|
||||
|
||||
status = EVP_DigestVerifyFinal(mdctx, (uint8_t*)sig, sigLen);
|
||||
if (status != 0 && status != 1)
|
||||
{
|
||||
OpenSSLThrowOnBadStatus("EVP_DigestVerifyUpdate", status);
|
||||
}
|
||||
EVP_MD_CTX_destroy(mdctx);
|
||||
return status == 0;
|
||||
}
|
||||
catch (const std::exception&)
|
||||
{
|
||||
EVP_MD_CTX_destroy(mdctx);
|
||||
throw;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
namespace Hash
|
||||
{
|
||||
std::unique_ptr<Sha1Algorithm> CreateSHA1()
|
||||
{
|
||||
return std::make_unique<OpenSSLSha1Algorithm>();
|
||||
return std::make_unique<OpenSSLHashAlgorithm<Sha1Algorithm>>(EVP_sha1());
|
||||
}
|
||||
|
||||
Sha1Algorithm::Result SHA1(const void * data, size_t dataLen)
|
||||
std::unique_ptr<Sha256Algorithm> CreateSHA256()
|
||||
{
|
||||
OpenSSLSha1Algorithm sha1;
|
||||
sha1.Update(data, dataLen);
|
||||
return sha1.Finish();
|
||||
return std::make_unique<OpenSSLHashAlgorithm<Sha256Algorithm>>(EVP_sha256());
|
||||
}
|
||||
|
||||
std::unique_ptr<RsaAlgorithm> CreateRSA()
|
||||
{
|
||||
return std::make_unique<OpenSSLRsaAlgorithm>();
|
||||
}
|
||||
|
||||
std::unique_ptr<RsaKey> CreateRSAKey()
|
||||
{
|
||||
return std::make_unique<OpenSSLRsaKey>();
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
|
||||
#include <array>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
template<size_t TLength>
|
||||
class HashAlgorithm
|
||||
|
@ -26,15 +27,43 @@ public:
|
|||
typedef std::array<uint8_t, TLength> Result;
|
||||
|
||||
virtual ~HashAlgorithm() = default;
|
||||
virtual void Clear() = 0;
|
||||
virtual void Update(const void * data, size_t dataLen) = 0;
|
||||
virtual HashAlgorithm * Clear() = 0;
|
||||
virtual HashAlgorithm * Update(const void * data, size_t dataLen) = 0;
|
||||
virtual Result Finish() = 0;
|
||||
};
|
||||
|
||||
class RsaKey
|
||||
{
|
||||
public:
|
||||
virtual ~RsaKey() = default;
|
||||
virtual void SetPrivate(const std::string_view& pem) = 0;
|
||||
virtual void SetPublic(const std::string_view& pem) = 0;
|
||||
virtual std::string GetPrivate() = 0;
|
||||
virtual std::string GetPublic() = 0;
|
||||
};
|
||||
|
||||
class RsaAlgorithm
|
||||
{
|
||||
public:
|
||||
virtual ~RsaAlgorithm() = default;
|
||||
virtual std::vector<uint8_t> SignData(const RsaKey& key, const void * data, size_t dataLen) = 0;
|
||||
virtual bool VerifyData(const RsaKey& key, const void * data, size_t dataLen, const void * sig, size_t sigLen) = 0;
|
||||
};
|
||||
|
||||
using Sha1Algorithm = HashAlgorithm<20>;
|
||||
using Sha256Algorithm = HashAlgorithm<32>;
|
||||
|
||||
namespace Hash
|
||||
{
|
||||
std::unique_ptr<Sha1Algorithm> CreateSHA1();
|
||||
Sha1Algorithm::Result SHA1(const void * data, size_t dataLen);
|
||||
std::unique_ptr<Sha256Algorithm> CreateSHA256();
|
||||
std::unique_ptr<RsaAlgorithm> CreateRSA();
|
||||
std::unique_ptr<RsaKey> CreateRSAKey();
|
||||
|
||||
Sha1Algorithm::Result SHA1(const void * data, size_t dataLen)
|
||||
{
|
||||
return CreateSHA1()
|
||||
->Update(data, dataLen)
|
||||
->Finish();
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue