Get all crypto tests passing

This commit is contained in:
Ted John 2020-02-05 02:09:19 +00:00
parent e2a541bff4
commit 57a758b9c0
2 changed files with 233 additions and 157 deletions

View File

@ -1,28 +1,29 @@
#pragma region Copyright (c) 2018 OpenRCT2 Developers
#pragma region Copyright(c) 2018 OpenRCT2 Developers
/*****************************************************************************
* OpenRCT2, an open source clone of Roller Coaster Tycoon 2.
*
* OpenRCT2 is the work of many authors, a full list can be found in contributors.md
* For more information, visit https://github.com/OpenRCT2/OpenRCT2
*
* OpenRCT2 is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* A full copy of the GNU General Public License can be found in licence.txt
*****************************************************************************/
* OpenRCT2, an open source clone of Roller Coaster Tycoon 2.
*
* OpenRCT2 is the work of many authors, a full list can be found in contributors.md
* For more information, visit https://github.com/OpenRCT2/OpenRCT2
*
* OpenRCT2 is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* A full copy of the GNU General Public License can be found in licence.txt
*****************************************************************************/
#pragma endregion
#ifdef __USE_CNG__
#include "Crypt.h"
#include "../platform/Platform2.h"
#include "IStream.hpp"
#include <stdexcept>
#include <sstream>
#include <string>
#include <tuple>
# include "../platform/Platform2.h"
# include "Crypt.h"
# include "IStream.hpp"
# include <sstream>
# include <stdexcept>
# include <string>
# include <tuple>
// clang-format off
// CNG: Cryptography API: Next Generation (CNG)
@ -45,7 +46,7 @@ static void CngThrowOnBadStatus(const std::string_view& name, NTSTATUS status)
}
}
static void ThrowBadAllocOnNull(const void * ptr)
static void ThrowBadAllocOnNull(const void* ptr)
{
if (ptr == nullptr)
{
@ -53,18 +54,17 @@ static void ThrowBadAllocOnNull(const void * ptr)
}
}
template<typename TBase>
class CngHashAlgorithm final : public TBase
template<typename TBase> class CngHashAlgorithm final : public TBase
{
private:
const wchar_t * _algName;
const wchar_t* _algName;
BCRYPT_ALG_HANDLE _hAlg{};
BCRYPT_HASH_HANDLE _hHash{};
PBYTE _pbHashObject{};
bool _reusable{};
public:
CngHashAlgorithm(const wchar_t * algName)
CngHashAlgorithm(const wchar_t* algName)
{
// BCRYPT_HASH_REUSABLE_FLAG only available from Windows 8
_algName = algName;
@ -77,7 +77,7 @@ public:
Dispose();
}
TBase * Clear() override
TBase* Clear() override
{
if (_reusable)
{
@ -92,7 +92,7 @@ public:
return this;
}
TBase * Update(const void * data, size_t dataLen) override
TBase* Update(const void* data, size_t dataLen) override
{
auto status = BCryptHashData(_hHash, (PBYTE)data, (ULONG)dataLen, 0);
CngThrowOnBadStatus("BCryptHashData", status);
@ -144,16 +144,14 @@ class DerReader
private:
ivstream<uint8_t> _stream;
template<typename T>
T Read(std::istream& stream)
template<typename T> T Read(std::istream& stream)
{
T value;
stream.read((char*)&value, sizeof(T));
return value;
}
template<typename T>
std::vector<T> Read(std::istream& stream, size_t count)
template<typename T> std::vector<T> Read(std::istream& stream, size_t count)
{
std::vector<T> values(count);
stream.read((char*)values.data(), sizeof(T) * count);
@ -230,17 +228,17 @@ public:
auto len = ReadLength(_stream);
auto result = Read<uint8_t>(_stream, len);
// auto v = result[0];
// auto neg = (v > 127);
// auto pad = neg ? 255 : 0;
// for (size_t i = 0; i < result.size(); i++)
// {
// if (result[i] != pad)
// {
// result.erase(result.begin(), result.begin() + i);
// break;
// }
// }
auto v = result[0];
auto neg = (v > 127);
auto pad = neg ? 255 : 0;
for (size_t i = 0; i < result.size(); i++)
{
if (result[i] != pad)
{
result.erase(result.begin(), result.begin() + i);
break;
}
}
return result;
}
};
@ -251,36 +249,55 @@ private:
std::vector<uint8_t> _buffer;
public:
void WriteSequenceHeader()
void WriteSequenceHeader(size_t len)
{
_buffer.push_back(0x30);
_buffer.push_back(0x81);
_buffer.push_back(0x89);
WriteCompressedNumber(len);
}
void WriteInteger(const std::vector<uint8_t>& data)
{
_buffer.push_back(0x02);
if (data.size() < 128)
size_t dataLen = data.size();
if (dataLen > 0 && data[0] > 127)
{
_buffer.push_back((uint8_t)data.size());
// Prepend a zero to number so it isn't treated as negative
WriteCompressedNumber(dataLen + 1);
_buffer.push_back(0);
_buffer.insert(_buffer.end(), data.begin(), data.end());
}
else if (data.size() <= std::numeric_limits<uint8_t>().max())
else
{
WriteCompressedNumber(dataLen);
_buffer.insert(_buffer.end(), data.begin(), data.end());
}
}
void WriteCompressedNumber(uint64_t value)
{
if (value < 128)
{
_buffer.push_back((uint8_t)value);
}
else if (value <= std::numeric_limits<uint8_t>().max())
{
_buffer.push_back(0b10000001);
_buffer.push_back((uint8_t)data.size());
_buffer.push_back((uint8_t)value);
}
else if (data.size() <= std::numeric_limits<uint16_t>().max())
else if (value <= std::numeric_limits<uint16_t>().max())
{
_buffer.push_back(0b10000010);
_buffer.push_back((data.size() >> 8) & 0xFF);
_buffer.push_back(data.size() & 0xFF);
_buffer.push_back((value >> 8) & 0xFF);
_buffer.push_back(value & 0xFF);
}
_buffer.insert(_buffer.end(), data.begin(), data.end());
}
std::vector<uint8_t>&& Complete()
{
auto oldBuffer = std::move(_buffer);
WriteSequenceHeader(oldBuffer.size());
_buffer.insert(_buffer.end(), oldBuffer.begin(), oldBuffer.end());
return std::move(_buffer);
}
};
@ -298,14 +315,108 @@ private:
std::vector<uint8_t> Exponent2;
std::vector<uint8_t> Coefficient;
std::vector<uint8_t> PrivateExponent;
ULONG GetMagic() const
{
ULONG magic = BCRYPT_RSAPUBLIC_MAGIC;
if (!Prime1.empty() || !Prime2.empty())
magic = BCRYPT_RSAPRIVATE_MAGIC;
if (!Exponent1.empty())
magic = BCRYPT_RSAFULLPRIVATE_MAGIC;
return magic;
}
size_t GetTotalSize()
{
return Modulus.size() + Exponent.size() + Prime1.size() + Prime2.size() + Exponent1.size() + Exponent2.size()
+ Coefficient.size() + PrivateExponent.size();
}
static RsaKeyParams FromBlob(const std::vector<uint8_t>& blob)
{
RsaKeyParams result;
const auto& header = *((BCRYPT_RSAKEY_BLOB*)blob.data());
size_t offset = sizeof(BCRYPT_RSAKEY_BLOB);
result.Exponent = ReadBytes(blob, offset, header.cbPublicExp);
result.Modulus = ReadBytes(blob, offset, header.cbModulus);
if (header.Magic == BCRYPT_RSAPRIVATE_MAGIC || header.Magic == BCRYPT_RSAFULLPRIVATE_MAGIC)
{
result.Prime1 = ReadBytes(blob, offset, header.cbPrime1);
result.Prime2 = ReadBytes(blob, offset, header.cbPrime2);
}
if (header.Magic == BCRYPT_RSAFULLPRIVATE_MAGIC)
{
result.Exponent1 = ReadBytes(blob, offset, header.cbPrime1);
result.Exponent2 = ReadBytes(blob, offset, header.cbPrime2);
result.Coefficient = ReadBytes(blob, offset, header.cbPrime1);
result.PrivateExponent = ReadBytes(blob, offset, header.cbModulus);
}
return result;
}
std::vector<uint8_t> ToBlob() const
{
auto magic = GetMagic();
std::vector<uint8_t> blob(sizeof(BCRYPT_RSAKEY_BLOB));
auto& header = *((BCRYPT_RSAKEY_BLOB*)blob.data());
header.Magic = magic;
header.BitLength = (ULONG)(Modulus.size() * 8);
header.cbPublicExp = (ULONG)Exponent.size();
header.cbModulus = (ULONG)Modulus.size();
header.cbPrime1 = (ULONG)Prime1.size();
header.cbPrime2 = (ULONG)Prime2.size();
WriteBytes(blob, Exponent);
WriteBytes(blob, Modulus);
if (magic == BCRYPT_RSAPRIVATE_MAGIC || magic == BCRYPT_RSAFULLPRIVATE_MAGIC)
{
WriteBytes(blob, Prime1);
WriteBytes(blob, Prime2);
}
if (magic == BCRYPT_RSAFULLPRIVATE_MAGIC)
{
WriteBytes(blob, Exponent1);
WriteBytes(blob, Exponent2);
WriteBytes(blob, Coefficient);
WriteBytes(blob, PrivateExponent);
}
return blob;
}
private:
static std::vector<uint8_t> ReadBytes(const std::vector<uint8_t>& src, size_t& offset, size_t length)
{
std::vector<uint8_t> result;
result.insert(result.end(), src.begin() + offset, src.begin() + offset + length);
offset += length;
return result;
}
static void WriteBytes(std::vector<uint8_t>& dst, const std::vector<uint8_t>& src)
{
dst.insert(dst.end(), src.begin(), src.end());
}
};
public:
NCRYPT_KEY_HANDLE GetKeyHandle() const { return _hKey; }
BCRYPT_KEY_HANDLE GetKeyHandle() const
{
return _hKey;
}
CngRsaKey()
{
auto status = BCryptOpenAlgorithmProvider(&_hAlg, BCRYPT_RSA_ALGORITHM, NULL, 0);
CngThrowOnBadStatus("BCryptOpenAlgorithmProvider", status);
}
~CngRsaKey()
{
NCryptFreeObject(_hKey);
BCryptDestroyKey(_hKey);
BCryptCloseAlgorithmProvider(_hAlg, 0);
_hKey = {};
_hAlg = {};
}
void SetPrivate(const std::string_view& pem) override
@ -317,10 +428,13 @@ public:
derReader.ReadInteger();
params.Modulus = derReader.ReadInteger();
params.Exponent = derReader.ReadInteger();
derReader.ReadInteger();
params.PrivateExponent = derReader.ReadInteger();
params.Prime1 = derReader.ReadInteger();
params.Prime2 = derReader.ReadInteger();
_hKey = ImportKey(params);
params.Exponent1 = derReader.ReadInteger();
params.Exponent2 = derReader.ReadInteger();
params.Coefficient = derReader.ReadInteger();
ImportKey(params);
}
void SetPublic(const std::string_view& pem) override
@ -331,15 +445,14 @@ public:
derReader.ReadSequenceHeader();
params.Modulus = derReader.ReadInteger();
params.Exponent = derReader.ReadInteger();
_hKey = ImportKey(params);
ImportKey(params);
}
std::string GetPrivate() override
{
auto params = ExportKey(false);
auto params = ExportKey();
DerWriter derWriter;
derWriter.WriteSequenceHeader();
derWriter.WriteInteger({});
derWriter.WriteInteger({ 0 });
derWriter.WriteInteger(params.Modulus);
derWriter.WriteInteger(params.Exponent);
derWriter.WriteInteger(params.PrivateExponent);
@ -360,9 +473,8 @@ public:
std::string GetPublic() override
{
auto params = ExportKey(true);
auto params = ExportKey();
DerWriter derWriter;
derWriter.WriteSequenceHeader();
derWriter.WriteInteger(params.Modulus);
derWriter.WriteInteger(params.Exponent);
auto derBytes = derWriter.Complete();
@ -377,6 +489,20 @@ public:
void Generate() override
{
Reset();
try
{
auto status = BCryptGenerateKeyPair(_hAlg, &_hKey, 1024, 0);
CngThrowOnBadStatus("BCryptGenerateKeyPair", status);
status = BCryptFinalizeKeyPair(_hKey, 0);
CngThrowOnBadStatus("BCryptFinalizeKeyPair", status);
_keyBlobType = BCRYPT_RSAFULLPRIVATE_BLOB;
}
catch (const std::exception&)
{
Reset();
throw;
}
}
private:
@ -385,9 +511,38 @@ private:
static constexpr std::string_view SZ_PRIVATE_BEGIN_TOKEN = "-----BEGIN RSA PRIVATE KEY-----";
static constexpr std::string_view SZ_PRIVATE_END_TOKEN = "-----END RSA PRIVATE KEY-----";
NCRYPT_KEY_HANDLE _hKey{};
BCRYPT_KEY_HANDLE _hKey{};
BCRYPT_KEY_HANDLE _hAlg{};
LPCWSTR _keyBlobType{};
static std::vector<uint8_t> ReadPEM(const std::string_view& pem, const std::string_view& beginToken, const std::string_view& endToken)
void Reset()
{
BCryptDestroyKey(_hKey);
_hKey = {};
}
void ImportKey(const RsaKeyParams& params)
{
Reset();
auto blob = params.ToBlob();
_keyBlobType = params.GetMagic() == BCRYPT_RSAFULLPRIVATE_MAGIC ? BCRYPT_RSAFULLPRIVATE_BLOB : BCRYPT_RSAPUBLIC_BLOB;
auto status = BCryptImportKeyPair(_hAlg, NULL, _keyBlobType, &_hKey, blob.data(), (ULONG)blob.size(), 0);
CngThrowOnBadStatus("BCryptImportKeyPair", status);
}
RsaKeyParams ExportKey()
{
ULONG cbOutput{};
auto status = BCryptExportKey(_hKey, NULL, _keyBlobType, NULL, 0, &cbOutput, 0);
CngThrowOnBadStatus("BCryptExportKey", status);
std::vector<uint8_t> blob(cbOutput);
status = BCryptExportKey(_hKey, NULL, _keyBlobType, blob.data(), cbOutput, &cbOutput, 0);
CngThrowOnBadStatus("BCryptExportKey", status);
return RsaKeyParams::FromBlob(blob);
}
static std::vector<uint8_t> ReadPEM(
const std::string_view& pem, const std::string_view& beginToken, const std::string_view& endToken)
{
auto beginPos = pem.find(beginToken);
auto endPos = pem.find(endToken);
@ -455,92 +610,12 @@ private:
}
return result;
}
static NCRYPT_KEY_HANDLE ImportKey(const RsaKeyParams& params)
{
bool isPublic = params.Prime1.size() == 0;
auto blobType = isPublic ? BCRYPT_RSAPUBLIC_BLOB : BCRYPT_RSAPRIVATE_BLOB;
BCRYPT_RSAKEY_BLOB header{};
header.Magic = isPublic ? BCRYPT_RSAPUBLIC_MAGIC : BCRYPT_RSAPRIVATE_MAGIC;
header.BitLength = (ULONG)(params.Modulus.size() * 8);
header.cbPublicExp = (ULONG)params.Exponent.size();
header.cbModulus = (ULONG)params.Modulus.size();
header.cbPrime1 = (ULONG)params.Prime1.size();
header.cbPrime2 = (ULONG)params.Prime2.size();
std::vector<uint8_t> blob;
blob.insert(blob.end(), (uint8_t*)&header, (uint8_t*)(&header + 1));
blob.insert(blob.end(), params.Exponent.begin(), params.Exponent.end());
blob.insert(blob.end(), params.Modulus.begin(), params.Modulus.end());
blob.insert(blob.end(), params.Prime1.begin(), params.Prime1.end());
blob.insert(blob.end(), params.Prime2.begin(), params.Prime2.end());
NCRYPT_PROV_HANDLE hProv{};
NCRYPT_KEY_HANDLE hKey{};
auto status = NCryptOpenStorageProvider(&hProv, MS_KEY_STORAGE_PROVIDER, 0);
CngThrowOnBadStatus("NCryptOpenStorageProvider", status);
status = NCryptImportKey(hProv, NULL, blobType, NULL, &hKey, (PBYTE)blob.data(), (DWORD)blob.size(), 0);
NCryptFreeObject(hProv);
CngThrowOnBadStatus("NCryptImportKey", status);
return hKey;
}
RsaKeyParams ExportKey(bool onlyPublic)
{
auto blobType = onlyPublic ? BCRYPT_RSAPUBLIC_BLOB : BCRYPT_RSAPRIVATE_BLOB;
std::vector<uint8_t> output;
NCRYPT_PROV_HANDLE hProv{};
try
{
auto status = NCryptOpenStorageProvider(&hProv, MS_KEY_STORAGE_PROVIDER, 0);
CngThrowOnBadStatus("NCryptOpenStorageProvider", status);
DWORD cbOutput{};
status = NCryptExportKey(_hKey, NULL, blobType, NULL, NULL, 0, &cbOutput, 0);
CngThrowOnBadStatus("NCryptExportKey", status);
output = std::vector<uint8_t>(cbOutput);
status = NCryptExportKey(_hKey, NULL, blobType, NULL, output.data(), cbOutput, &cbOutput, 0);
CngThrowOnBadStatus("NCryptExportKey", status);
NCryptFreeObject(hProv);
}
catch (const std::exception&)
{
NCryptFreeObject(hProv);
throw;
}
size_t offset{};
RsaKeyParams params;
const auto& header = *((BCRYPT_RSAKEY_BLOB*)output.data());
ReadBytes(output, offset, sizeof(BCRYPT_RSAKEY_BLOB));
params.Exponent = ReadBytes(output, offset, header.cbPublicExp);
params.Modulus = ReadBytes(output, offset, header.cbModulus);
params.Prime1 = ReadBytes(output, offset, header.cbPrime1);
params.Prime2 = ReadBytes(output, offset, header.cbPrime2);
if (!onlyPublic)
{
params.Exponent1 = ReadBytes(output, offset, header.cbPrime1);
params.Exponent2 = ReadBytes(output, offset, header.cbPrime2);
params.Coefficient = ReadBytes(output, offset, header.cbPrime1);
params.PrivateExponent = ReadBytes(output, offset, header.cbModulus);
}
return params;
}
static std::vector<uint8_t> ReadBytes(std::vector<uint8_t>& src, size_t& offset, size_t length)
{
std::vector<uint8_t> result;
result.insert(result.end(), src.begin() + offset, src.begin() + offset + length);
offset += length;
return result;
}
};
class CngRsaAlgorithm final : public RsaAlgorithm
{
public:
std::vector<uint8_t> SignData(const RsaKey& key, const void * data, size_t dataLen) override
std::vector<uint8_t> SignData(const RsaKey& key, const void* data, size_t dataLen) override
{
auto hKey = static_cast<const CngRsaKey&>(key).GetKeyHandle();
auto [cbHash, pbHash] = HashData(data, dataLen);
@ -548,11 +623,12 @@ public:
try
{
BCRYPT_PKCS1_PADDING_INFO paddingInfo{ BCRYPT_SHA256_ALGORITHM };
auto status = NCryptSignHash(hKey, &paddingInfo, pbHash, cbHash, NULL, 0, &cbSignature, BCRYPT_PAD_PKCS1);
auto status = BCryptSignHash(hKey, &paddingInfo, pbHash, cbHash, NULL, 0, &cbSignature, BCRYPT_PAD_PKCS1);
CngThrowOnBadStatus("NCryptSignHash", status);
pbSignature = (PBYTE)HeapAlloc(GetProcessHeap(), 0, cbSignature);
ThrowBadAllocOnNull(pbSignature);
status = NCryptSignHash(hKey, &paddingInfo, pbHash, cbHash, pbSignature, cbSignature, &cbSignature, BCRYPT_PAD_PKCS1);
status = BCryptSignHash(
hKey, &paddingInfo, pbHash, cbHash, pbSignature, cbSignature, &cbSignature, BCRYPT_PAD_PKCS1);
CngThrowOnBadStatus("NCryptSignHash", status);
auto result = std::vector<uint8_t>(pbSignature, pbSignature + cbSignature);
@ -567,26 +643,26 @@ public:
}
}
bool VerifyData(const RsaKey& key, const void * data, size_t dataLen, const void * sig, size_t sigLen) override
bool VerifyData(const RsaKey& key, const void* data, size_t dataLen, const void* sig, size_t sigLen) override
{
auto hKey = static_cast<const CngRsaKey&>(key).GetKeyHandle();
auto [cbHash, pbHash] = HashData(data, dataLen);
auto [cbSignature, pbSignature] = ToHeap(sig, sigLen);
BCRYPT_PKCS1_PADDING_INFO paddingInfo { BCRYPT_SHA256_ALGORITHM };
auto status = NCryptVerifySignature(hKey, &paddingInfo, pbHash, cbHash, pbSignature, cbSignature, BCRYPT_PAD_PKCS1);
BCRYPT_PKCS1_PADDING_INFO paddingInfo{ BCRYPT_SHA256_ALGORITHM };
auto status = BCryptVerifySignature(hKey, &paddingInfo, pbHash, cbHash, pbSignature, cbSignature, BCRYPT_PAD_PKCS1);
HeapFree(GetProcessHeap(), 0, pbSignature);
return status == ERROR_SUCCESS;
}
private:
static std::tuple<DWORD, PBYTE> HashData(const void * data, size_t dataLen)
static std::tuple<DWORD, PBYTE> HashData(const void* data, size_t dataLen)
{
auto hash = Crypt::SHA256(data, dataLen);
return ToHeap(hash.data(), hash.size());
}
static std::tuple<DWORD, PBYTE> ToHeap(const void * data, size_t dataLen)
static std::tuple<DWORD, PBYTE> ToHeap(const void* data, size_t dataLen)
{
auto cbHash = (DWORD)dataLen;
auto pbHash = (PBYTE)HeapAlloc(GetProcessHeap(), 0, dataLen);
@ -617,6 +693,6 @@ namespace Crypt
{
return std::make_unique<CngRsaKey>();
}
}
} // namespace Crypt
#endif

View File

@ -1,5 +1,5 @@
-----BEGIN RSA PUBLIC KEY-----
MIGJAoGA5EY8kelJxFRiuNK7xjlIJBbVQ559FSBAsN6ZxgP8OswUFacthMxCbJZB
olK3nOVpxC5xt8NX613FtAE714TVL7DFFWdImlrjRH1h7thNGwZSHBHIHQzBEHGz
FcqSxY0OOyHdS6Opb5OrYeHPiiFGvcrl8SQxLwKDOzZOaI2CPXMCAwEAAQ==
MIGJAoGBAORGPJHpScRUYrjSu8Y5SCQW1UOefRUgQLDemcYD/DrMFBWnLYTMQmyW
QaJSt5zlacQucbfDV+tdxbQBO9eE1S+wxRVnSJpa40R9Ye7YTRsGUhwRyB0MwRBx
sxXKksWNDjsh3UujqW+Tq2Hhz4ohRr3K5fEkMS8Cgzs2TmiNgj1zAgMBAAE=
-----END RSA PUBLIC KEY-----