Implement hw accelerated AES

This commit is contained in:
Shawn Hoffman 2022-07-27 01:51:19 -07:00
parent fb45ed3981
commit 46ad8b9d68
13 changed files with 488 additions and 93 deletions

View File

@ -136,6 +136,12 @@ add_library(common
WorkQueueThread.h
)
if(NOT MSVC AND _M_ARM_64)
set_source_files_properties(
Crypto/AES.cpp
PROPERTIES COMPILE_FLAGS "-march=armv8-a+crypto")
endif()
target_link_libraries(common
PUBLIC
${CMAKE_THREAD_LIBS_INIT}

View File

@ -1,35 +1,411 @@
// Copyright 2017 Dolphin Emulator Project
// SPDX-License-Identifier: GPL-2.0-or-later
#include "Common/Crypto/AES.h"
#include <array>
#include <memory>
#include <mbedtls/aes.h>
#include "Common/Assert.h"
#include "Common/BitUtils.h"
#include "Common/CPUDetect.h"
#include "Common/Crypto/AES.h"
#ifdef _MSC_VER
#include <intrin.h>
#else
#if defined(_M_X86_64)
#include <x86intrin.h>
#elif defined(_M_ARM_64)
#include <arm_acle.h>
#include <arm_neon.h>
#endif
#endif
#ifdef _MSC_VER
#define ATTRIBUTE_TARGET(x)
#else
#define ATTRIBUTE_TARGET(x) [[gnu::target(x)]]
#endif
namespace Common::AES
{
std::vector<u8> DecryptEncrypt(const u8* key, u8* iv, const u8* src, size_t size, Mode mode)
// For x64 and arm64, it's very unlikely a user's cpu does not support the accelerated version,
// fallback is just in case.
template <Mode AesMode>
class ContextGeneric final : public Context
{
mbedtls_aes_context aes_ctx;
std::vector<u8> buffer(size);
public:
ContextGeneric(const u8* key)
{
mbedtls_aes_init(&ctx);
if constexpr (AesMode == Mode::Encrypt)
ASSERT(!mbedtls_aes_setkey_enc(&ctx, key, 128));
else
ASSERT(!mbedtls_aes_setkey_dec(&ctx, key, 128));
}
if (mode == Mode::Encrypt)
mbedtls_aes_setkey_enc(&aes_ctx, key, 128);
else
mbedtls_aes_setkey_dec(&aes_ctx, key, 128);
virtual bool Crypt(const u8* iv, u8* iv_out, const u8* buf_in, u8* buf_out,
size_t len) const override
{
std::array<u8, BLOCK_SIZE> iv_tmp{};
if (iv)
std::memcpy(&iv_tmp[0], iv, BLOCK_SIZE);
mbedtls_aes_crypt_cbc(&aes_ctx, mode == Mode::Encrypt ? MBEDTLS_AES_ENCRYPT : MBEDTLS_AES_DECRYPT,
size, iv, src, buffer.data());
constexpr int mode = (AesMode == Mode::Encrypt) ? MBEDTLS_AES_ENCRYPT : MBEDTLS_AES_DECRYPT;
if (mbedtls_aes_crypt_cbc(const_cast<mbedtls_aes_context*>(&ctx), mode, len, &iv_tmp[0], buf_in,
buf_out))
return false;
return buffer;
if (iv_out)
std::memcpy(iv_out, &iv_tmp[0], BLOCK_SIZE);
return true;
}
private:
mbedtls_aes_context ctx{};
};
#if defined(_M_X86_64)
// Note that (for instructions with same data width) the actual instructions emitted vary depending
// on compiler and flags. The naming is somewhat confusing, because VAES cpuid flag was added after
// VAES(VEX.128):
// clang-format off
// instructions | cpuid flag | #define
// AES(128) | AES | -
// VAES(VEX.128) | AES & AVX | __AVX__
// VAES(VEX.256) | VAES | -
// VAES(EVEX.128) | VAES & AVX512VL | __AVX512VL__
// VAES(EVEX.256) | VAES & AVX512VL | __AVX512VL__
// VAES(EVEX.512) | VAES & AVX512F | __AVX512F__
// clang-format on
template <Mode AesMode>
class ContextAESNI final : public Context
{
static inline __m128i Aes128KeygenAssistFinish(__m128i key, __m128i kga)
{
__m128i tmp = _mm_shuffle_epi32(kga, _MM_SHUFFLE(3, 3, 3, 3));
tmp = _mm_xor_si128(tmp, key);
key = _mm_slli_si128(key, 4);
tmp = _mm_xor_si128(tmp, key);
key = _mm_slli_si128(key, 4);
tmp = _mm_xor_si128(tmp, key);
key = _mm_slli_si128(key, 4);
tmp = _mm_xor_si128(tmp, key);
return tmp;
}
template <size_t RoundIdx>
ATTRIBUTE_TARGET("aes")
inline constexpr void StoreRoundKey(__m128i rk)
{
if constexpr (AesMode == Mode::Encrypt)
round_keys[RoundIdx] = rk;
else
{
constexpr size_t idx = NUM_ROUND_KEYS - RoundIdx - 1;
if constexpr (idx == 0 || idx == NUM_ROUND_KEYS - 1)
round_keys[idx] = rk;
else
round_keys[idx] = _mm_aesimc_si128(rk);
}
}
template <size_t RoundIdx, int Rcon>
ATTRIBUTE_TARGET("aes")
inline constexpr __m128i Aes128Keygen(__m128i rk)
{
rk = Aes128KeygenAssistFinish(rk, _mm_aeskeygenassist_si128(rk, Rcon));
StoreRoundKey<RoundIdx>(rk);
return rk;
}
public:
ContextAESNI(const u8* key)
{
__m128i rk = _mm_loadu_si128((const __m128i*)key);
StoreRoundKey<0>(rk);
rk = Aes128Keygen<1, 0x01>(rk);
rk = Aes128Keygen<2, 0x02>(rk);
rk = Aes128Keygen<3, 0x04>(rk);
rk = Aes128Keygen<4, 0x08>(rk);
rk = Aes128Keygen<5, 0x10>(rk);
rk = Aes128Keygen<6, 0x20>(rk);
rk = Aes128Keygen<7, 0x40>(rk);
rk = Aes128Keygen<8, 0x80>(rk);
rk = Aes128Keygen<9, 0x1b>(rk);
Aes128Keygen<10, 0x36>(rk);
}
ATTRIBUTE_TARGET("aes")
inline void CryptBlock(__m128i* iv, const u8* buf_in, u8* buf_out) const
{
__m128i block = _mm_loadu_si128((const __m128i*)buf_in);
if constexpr (AesMode == Mode::Encrypt)
{
block = _mm_xor_si128(_mm_xor_si128(block, *iv), round_keys[0]);
for (size_t i = 1; i < Nr; ++i)
block = _mm_aesenc_si128(block, round_keys[i]);
block = _mm_aesenclast_si128(block, round_keys[Nr]);
*iv = block;
}
else
{
__m128i iv_next = block;
block = _mm_xor_si128(block, round_keys[0]);
for (size_t i = 1; i < Nr; ++i)
block = _mm_aesdec_si128(block, round_keys[i]);
block = _mm_aesdeclast_si128(block, round_keys[Nr]);
block = _mm_xor_si128(block, *iv);
*iv = iv_next;
}
_mm_storeu_si128((__m128i*)buf_out, block);
}
// Takes advantage of instruction pipelining to parallelize.
template <size_t NumBlocks>
ATTRIBUTE_TARGET("aes")
inline void DecryptPipelined(__m128i* iv, const u8* buf_in, u8* buf_out) const
{
constexpr size_t Depth = NumBlocks;
__m128i block[Depth];
for (size_t d = 0; d < Depth; d++)
block[d] = _mm_loadu_si128(&((const __m128i*)buf_in)[d]);
__m128i iv_next[1 + Depth];
iv_next[0] = *iv;
for (size_t d = 0; d < Depth; d++)
iv_next[1 + d] = block[d];
for (size_t d = 0; d < Depth; d++)
block[d] = _mm_xor_si128(block[d], round_keys[0]);
// The main speedup is here
for (size_t i = 1; i < Nr; ++i)
for (size_t d = 0; d < Depth; d++)
block[d] = _mm_aesdec_si128(block[d], round_keys[i]);
for (size_t d = 0; d < Depth; d++)
block[d] = _mm_aesdeclast_si128(block[d], round_keys[Nr]);
for (size_t d = 0; d < Depth; d++)
block[d] = _mm_xor_si128(block[d], iv_next[d]);
*iv = iv_next[1 + Depth - 1];
for (size_t d = 0; d < Depth; d++)
_mm_storeu_si128(&((__m128i*)buf_out)[d], block[d]);
}
virtual bool Crypt(const u8* iv, u8* iv_out, const u8* buf_in, u8* buf_out,
size_t len) const override
{
if (len % BLOCK_SIZE)
return false;
__m128i iv_block = iv ? _mm_loadu_si128((const __m128i*)iv) : _mm_setzero_si128();
if constexpr (AesMode == Mode::Decrypt)
{
// On amd zen2...(benchmark, not real-world):
// With AES(128) instructions, BLOCK_DEPTH results in following speedup vs. non-pipelined: 4:
// 18%, 8: 22%, 9: 26%, 10-15: 31%. 16: 8% (register exhaustion). With VAES(VEX.128), 10 gives
// 36% speedup vs. its corresponding baseline. VAES(VEX.128) is ~4% faster than AES(128). The
// result is similar on zen3.
// Zen3 in general is 20% faster than zen2 in aes, and VAES(VEX.256) is 35% faster than
// zen3/VAES(VEX.128).
// It seems like VAES(VEX.256) should be faster?
// TODO Choose value at runtime based on some criteria?
constexpr size_t BLOCK_DEPTH = 10;
constexpr size_t CHUNK_LEN = BLOCK_DEPTH * BLOCK_SIZE;
while (len >= CHUNK_LEN)
{
DecryptPipelined<BLOCK_DEPTH>(&iv_block, buf_in, buf_out);
buf_in += CHUNK_LEN;
buf_out += CHUNK_LEN;
len -= CHUNK_LEN;
}
}
len /= BLOCK_SIZE;
while (len--)
{
CryptBlock(&iv_block, buf_in, buf_out);
buf_in += BLOCK_SIZE;
buf_out += BLOCK_SIZE;
}
if (iv_out)
_mm_storeu_si128((__m128i*)iv_out, iv_block);
return true;
}
private:
std::array<__m128i, NUM_ROUND_KEYS> round_keys;
};
#endif
#if defined(_M_ARM_64)
template <Mode AesMode>
class ContextNeon final : public Context
{
public:
template <size_t RoundIdx>
inline constexpr void StoreRoundKey(const u32* rk)
{
const uint8x16_t rk_block = vld1q_u32(rk);
if constexpr (AesMode == Mode::Encrypt)
round_keys[RoundIdx] = rk_block;
else
{
constexpr size_t idx = NUM_ROUND_KEYS - RoundIdx - 1;
if constexpr (idx == 0 || idx == NUM_ROUND_KEYS - 1)
round_keys[idx] = rk_block;
else
round_keys[idx] = vaesimcq_u8(rk_block);
}
}
ContextNeon(const u8* key)
{
constexpr u8 rcon[]{0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x1b, 0x36};
std::array<u32, Nb * NUM_ROUND_KEYS> rk{};
// This uses a nice trick I've seen in wolfssl (not sure original author),
// which uses vaeseq_u8 to assist keygen.
// vaeseq_u8: op1 = SubBytes(ShiftRows(AddRoundKey(op1, op2)))
// given RotWord == ShiftRows for row 1 (rol(x,8))
// Probably not super fast (moves to/from vector regs constantly), but it is nice and simple.
std::memcpy(&rk[0], key, KEY_SIZE);
StoreRoundKey<0>(&rk[0]);
for (size_t i = 0; i < rk.size() - Nk; i += Nk)
{
const uint8x16_t enc = vaeseq_u8(vreinterpretq_u8_u32(vmovq_n_u32(rk[i + 3])), vmovq_n_u8(0));
const u32 temp = vgetq_lane_u32(vreinterpretq_u32_u8(enc), 0);
rk[i + 4] = rk[i + 0] ^ Common::RotateRight(temp, 8) ^ rcon[i / Nk];
rk[i + 5] = rk[i + 4] ^ rk[i + 1];
rk[i + 6] = rk[i + 5] ^ rk[i + 2];
rk[i + 7] = rk[i + 6] ^ rk[i + 3];
// clang-format off
// Not great
const size_t rki = 1 + i / Nk;
switch (rki)
{
case 1: StoreRoundKey< 1>(&rk[Nk * rki]); break;
case 2: StoreRoundKey< 2>(&rk[Nk * rki]); break;
case 3: StoreRoundKey< 3>(&rk[Nk * rki]); break;
case 4: StoreRoundKey< 4>(&rk[Nk * rki]); break;
case 5: StoreRoundKey< 5>(&rk[Nk * rki]); break;
case 6: StoreRoundKey< 6>(&rk[Nk * rki]); break;
case 7: StoreRoundKey< 7>(&rk[Nk * rki]); break;
case 8: StoreRoundKey< 8>(&rk[Nk * rki]); break;
case 9: StoreRoundKey< 9>(&rk[Nk * rki]); break;
case 10: StoreRoundKey<10>(&rk[Nk * rki]); break;
}
// clang-format on
}
}
inline void CryptBlock(uint8x16_t* iv, const u8* buf_in, u8* buf_out) const
{
uint8x16_t block = vld1q_u8(buf_in);
if constexpr (AesMode == Mode::Encrypt)
{
block = veorq_u8(block, *iv);
for (size_t i = 0; i < Nr - 1; ++i)
block = vaesmcq_u8(vaeseq_u8(block, round_keys[i]));
block = vaeseq_u8(block, round_keys[Nr - 1]);
block = veorq_u8(block, round_keys[Nr]);
*iv = block;
}
else
{
uint8x16_t iv_next = block;
for (size_t i = 0; i < Nr - 1; ++i)
block = vaesimcq_u8(vaesdq_u8(block, round_keys[i]));
block = vaesdq_u8(block, round_keys[Nr - 1]);
block = veorq_u8(block, round_keys[Nr]);
block = veorq_u8(block, *iv);
*iv = iv_next;
}
vst1q_u8(buf_out, block);
}
virtual bool Crypt(const u8* iv, u8* iv_out, const u8* buf_in, u8* buf_out,
size_t len) const override
{
if (len % BLOCK_SIZE)
return false;
uint8x16_t iv_block = iv ? vld1q_u8(iv) : vmovq_n_u8(0);
len /= BLOCK_SIZE;
while (len--)
{
CryptBlock(&iv_block, buf_in, buf_out);
buf_in += BLOCK_SIZE;
buf_out += BLOCK_SIZE;
}
if (iv_out)
vst1q_u8(iv_out, iv_block);
return true;
}
private:
std::array<uint8x16_t, NUM_ROUND_KEYS> round_keys;
};
#endif
template <Mode AesMode>
std::unique_ptr<Context> CreateContext(const u8* key)
{
if (cpu_info.bAES)
{
#if defined(_M_X86_64)
#if defined(__AVX__)
// If compiler enables AVX, the intrinsics will generate VAES(VEX.128) instructions.
// In the future we may want to compile the code twice and explicitly override the compiler
// flags. There doesn't seem to be much performance difference between AES(128) and
// VAES(VEX.128) at the moment, though.
if (cpu_info.bAVX)
#endif
return std::make_unique<ContextAESNI<AesMode>>(key);
#elif defined(_M_ARM_64)
return std::make_unique<ContextNeon<AesMode>>(key);
#endif
}
return std::make_unique<ContextGeneric<AesMode>>(key);
}
std::vector<u8> Decrypt(const u8* key, u8* iv, const u8* src, size_t size)
std::unique_ptr<Context> CreateContextEncrypt(const u8* key)
{
return DecryptEncrypt(key, iv, src, size, Mode::Decrypt);
return CreateContext<Mode::Encrypt>(key);
}
std::vector<u8> Encrypt(const u8* key, u8* iv, const u8* src, size_t size)
std::unique_ptr<Context> CreateContextDecrypt(const u8* key)
{
return DecryptEncrypt(key, iv, src, size, Mode::Encrypt);
return CreateContext<Mode::Decrypt>(key);
}
} // namespace Common::AES

View File

@ -3,11 +3,12 @@
#pragma once
#include <cstddef>
#include <vector>
#include <memory>
#include "Common/CommonTypes.h"
// Dolphin only uses/implements AES-128-CBC.
namespace Common::AES
{
enum class Mode
@ -15,9 +16,34 @@ enum class Mode
Decrypt,
Encrypt,
};
std::vector<u8> DecryptEncrypt(const u8* key, u8* iv, const u8* src, size_t size, Mode mode);
// Convenience functions
std::vector<u8> Decrypt(const u8* key, u8* iv, const u8* src, size_t size);
std::vector<u8> Encrypt(const u8* key, u8* iv, const u8* src, size_t size);
class Context
{
protected:
static constexpr size_t Nk = 4;
static constexpr size_t Nb = 4;
static constexpr size_t Nr = 10;
static constexpr size_t WORD_SIZE = sizeof(u32);
static constexpr size_t NUM_ROUND_KEYS = Nr + 1;
public:
static constexpr size_t KEY_SIZE = Nk * WORD_SIZE;
static constexpr size_t BLOCK_SIZE = Nb * WORD_SIZE;
Context() = default;
virtual ~Context() = default;
virtual bool Crypt(const u8* iv, u8* iv_out, const u8* buf_in, u8* buf_out, size_t len) const = 0;
bool Crypt(const u8* iv, const u8* buf_in, u8* buf_out, size_t len) const
{
return Crypt(iv, nullptr, buf_in, buf_out, len);
}
bool CryptIvZero(const u8* buf_in, u8* buf_out, size_t len) const
{
return Crypt(nullptr, nullptr, buf_in, buf_out, len);
}
};
std::unique_ptr<Context> CreateContextEncrypt(const u8* key);
std::unique_ptr<Context> CreateContextDecrypt(const u8* key);
} // namespace Common::AES

View File

@ -760,11 +760,11 @@ ReturnCode ESDevice::ExportContentData(Context& context, u32 content_fd, u8* dat
buffer.resize(Common::AlignUp(buffer.size(), 32));
std::vector<u8> output(buffer.size());
const ReturnCode decrypt_ret = m_ios.GetIOSC().Encrypt(
const ReturnCode encrypt_ret = m_ios.GetIOSC().Encrypt(
context.title_import_export.key_handle, context.title_import_export.content.iv.data(),
buffer.data(), buffer.size(), output.data(), PID_ES);
if (decrypt_ret != IPC_SUCCESS)
return decrypt_ret;
if (encrypt_ret != IPC_SUCCESS)
return encrypt_ret;
std::copy(output.cbegin(), output.cend(), data);
return IPC_SUCCESS;

View File

@ -271,10 +271,16 @@ ReturnCode IOSC::DecryptEncrypt(Common::AES::Mode mode, Handle key_handle, u8* i
if (entry->data.size() != AES128_KEY_SIZE)
return IOSC_FAIL_INTERNAL;
const std::vector<u8> data =
Common::AES::DecryptEncrypt(entry->data.data(), iv, input, size, mode);
auto key = entry->data.data();
// TODO? store enc + dec ctxs in the KeyEntry so they only need to be created once.
// This doesn't seem like a hot path, though.
std::unique_ptr<Common::AES::Context> ctx;
if (mode == Common::AES::Mode::Encrypt)
ctx = Common::AES::CreateContextEncrypt(key);
else
ctx = Common::AES::CreateContextDecrypt(key);
std::memcpy(output, data.data(), data.size());
ctx->Crypt(iv, iv, input, output, size);
return IPC_SUCCESS;
}

View File

@ -3,7 +3,6 @@
#include "Core/IOS/WFS/WFSI.h"
#include <mbedtls/aes.h>
#include <stack>
#include <string>
#include <utility>
@ -12,6 +11,7 @@
#include <fmt/format.h>
#include "Common/CommonTypes.h"
#include "Common/Crypto/AES.h"
#include "Common/FileUtil.h"
#include "Common/IOFile.h"
#include "Common/Logging/Log.h"
@ -167,8 +167,7 @@ std::optional<IPCReply> WFSIDevice::IOCtl(const IOCtlRequest& request)
break;
}
memcpy(m_aes_key, ticket.GetTitleKey(m_ios.GetIOSC()).data(), sizeof(m_aes_key));
mbedtls_aes_setkey_dec(&m_aes_ctx, m_aes_key, 128);
m_aes_ctx = Common::AES::CreateContextDecrypt(ticket.GetTitleKey(m_ios.GetIOSC()).data());
SetImportTitleIdAndGroupId(m_tmd.GetTitleId(), m_tmd.GetGroupId());
@ -224,8 +223,8 @@ std::optional<IPCReply> WFSIDevice::IOCtl(const IOCtlRequest& request)
input_size, input_ptr, content_id);
std::vector<u8> decrypted(input_size);
mbedtls_aes_crypt_cbc(&m_aes_ctx, MBEDTLS_AES_DECRYPT, input_size, m_aes_iv,
Memory::GetPointer(input_ptr), decrypted.data());
m_aes_ctx->Crypt(m_aes_iv, m_aes_iv, Memory::GetPointer(input_ptr), decrypted.data(),
input_size);
m_arc_unpacker.AddBytes(decrypted);
break;

View File

@ -4,12 +4,12 @@
#pragma once
#include <functional>
#include <memory>
#include <string>
#include <vector>
#include <mbedtls/aes.h>
#include "Common/CommonTypes.h"
#include "Common/Crypto/AES.h"
#include "Core/IOS/Device.h"
#include "Core/IOS/ES/Formats.h"
#include "Core/IOS/IOS.h"
@ -50,7 +50,7 @@ private:
std::string m_device_name;
mbedtls_aes_context m_aes_ctx{};
std::unique_ptr<Common::AES::Context> m_aes_ctx{};
u8 m_aes_key[0x10] = {};
u8 m_aes_iv[0x10] = {};

View File

@ -171,14 +171,13 @@ std::vector<u8> NANDImporter::GetEntryData(const NANDFSTEntry& entry)
std::vector<u8> data{};
data.reserve(remaining_bytes);
auto block = std::make_unique<u8[]>(NAND_FAT_BLOCK_SIZE);
while (remaining_bytes > 0)
{
std::array<u8, 16> iv{};
std::vector<u8> block = Common::AES::Decrypt(
m_aes_key.data(), iv.data(), &m_nand[NAND_FAT_BLOCK_SIZE * sub], NAND_FAT_BLOCK_SIZE);
m_aes_ctx->CryptIvZero(&m_nand[NAND_FAT_BLOCK_SIZE * sub], block.get(), NAND_FAT_BLOCK_SIZE);
size_t size = std::min(remaining_bytes, block.size());
data.insert(data.end(), block.begin(), block.begin() + size);
size_t size = std::min(remaining_bytes, NAND_FAT_BLOCK_SIZE);
data.insert(data.end(), block.get(), block.get() + size);
remaining_bytes -= size;
sub = m_superblock->fat[sub];
@ -260,7 +259,7 @@ void NANDImporter::ExportKeys()
{
constexpr size_t NAND_AES_KEY_OFFSET = 0x158;
std::copy_n(&m_nand_keys[NAND_AES_KEY_OFFSET], m_aes_key.size(), m_aes_key.begin());
m_aes_ctx = Common::AES::CreateContextDecrypt(&m_nand_keys[NAND_AES_KEY_OFFSET]);
const std::string file_path = m_nand_root + "/keys.bin";
File::IOFile file(file_path, "wb");

View File

@ -12,6 +12,7 @@
#include <fmt/format.h>
#include "Common/CommonTypes.h"
#include "Common/Crypto/AES.h"
#include "Common/Swap.h"
namespace DiscIO
@ -74,7 +75,7 @@ private:
std::string m_nand_root;
std::vector<u8> m_nand;
std::vector<u8> m_nand_keys;
std::array<u8, 16> m_aes_key;
std::unique_ptr<Common::AES::Context> m_aes_ctx;
std::unique_ptr<NANDSuperblock> m_superblock;
std::function<void()> m_update_callback;
};

View File

@ -13,11 +13,10 @@
#include <utility>
#include <vector>
#include <mbedtls/aes.h>
#include "Common/Align.h"
#include "Common/Assert.h"
#include "Common/CommonTypes.h"
#include "Common/Crypto/AES.h"
#include "Common/Crypto/SHA1.h"
#include "Common/Logging/Log.h"
#include "Common/MsgHandler.h"
@ -159,17 +158,14 @@ bool VolumeWAD::CheckContentIntegrity(const IOS::ES::Content& content,
if (encrypted_data.size() != Common::AlignUp(content.size, 0x40))
return false;
mbedtls_aes_context context;
const std::array<u8, 16> key = ticket.GetTitleKey();
mbedtls_aes_setkey_dec(&context, key.data(), 128);
auto context = Common::AES::CreateContextDecrypt(ticket.GetTitleKey().data());
std::array<u8, 16> iv{};
iv[0] = static_cast<u8>(content.index >> 8);
iv[1] = static_cast<u8>(content.index & 0xFF);
std::vector<u8> decrypted_data(encrypted_data.size());
mbedtls_aes_crypt_cbc(&context, MBEDTLS_AES_DECRYPT, decrypted_data.size(), iv.data(),
encrypted_data.data(), decrypted_data.data());
context->Crypt(iv.data(), encrypted_data.data(), decrypted_data.data(), decrypted_data.size());
return Common::SHA1::CalculateDigest(decrypted_data.data(), content.size) == content.sha1;
}

View File

@ -16,11 +16,10 @@
#include <utility>
#include <vector>
#include <mbedtls/aes.h>
#include "Common/Align.h"
#include "Common/Assert.h"
#include "Common/CommonTypes.h"
#include "Common/Crypto/AES.h"
#include "Common/Crypto/SHA1.h"
#include "Common/Logging/Log.h"
#include "Common/Swap.h"
@ -128,14 +127,11 @@ VolumeWii::VolumeWii(std::unique_ptr<BlobReader> reader)
return h3_table;
};
auto get_key = [this, partition]() -> std::unique_ptr<mbedtls_aes_context> {
auto get_key = [this, partition]() -> std::unique_ptr<Common::AES::Context> {
const IOS::ES::TicketReader& ticket = *m_partitions[partition].ticket;
if (!ticket.IsValid())
return nullptr;
const std::array<u8, AES_KEY_SIZE> key = ticket.GetTitleKey();
std::unique_ptr<mbedtls_aes_context> aes_context = std::make_unique<mbedtls_aes_context>();
mbedtls_aes_setkey_dec(aes_context.get(), key.data(), 128);
return aes_context;
return Common::AES::CreateContextDecrypt(ticket.GetTitleKey().data());
};
auto get_file_system = [this, partition]() -> std::unique_ptr<FileSystem> {
@ -148,7 +144,7 @@ VolumeWii::VolumeWii(std::unique_ptr<BlobReader> reader)
};
m_partitions.emplace(
partition, PartitionDetails{Common::Lazy<std::unique_ptr<mbedtls_aes_context>>(get_key),
partition, PartitionDetails{Common::Lazy<std::unique_ptr<Common::AES::Context>>(get_key),
Common::Lazy<IOS::ES::TicketReader>(get_ticket),
Common::Lazy<IOS::ES::TMDReader>(get_tmd),
Common::Lazy<std::vector<u8>>(get_cert_chain),
@ -183,11 +179,11 @@ bool VolumeWii::Read(u64 offset, u64 length, u8* buffer, const Partition& partit
buffer);
}
mbedtls_aes_context* aes_context = partition_details.key->get();
auto aes_context = partition_details.key->get();
if (!aes_context)
return false;
std::vector<u8> read_buffer(BLOCK_TOTAL_SIZE);
auto read_buffer = std::make_unique<u8[]>(BLOCK_TOTAL_SIZE);
while (length > 0)
{
// Calculate offsets
@ -198,11 +194,11 @@ bool VolumeWii::Read(u64 offset, u64 length, u8* buffer, const Partition& partit
if (m_last_decrypted_block != block_offset_on_disc)
{
// Read the current block
if (!m_reader->Read(block_offset_on_disc, BLOCK_TOTAL_SIZE, read_buffer.data()))
if (!m_reader->Read(block_offset_on_disc, BLOCK_TOTAL_SIZE, read_buffer.get()))
return false;
// Decrypt the block's data
DecryptBlockData(read_buffer.data(), m_last_decrypted_block_data, aes_context);
DecryptBlockData(read_buffer.get(), m_last_decrypted_block_data, aes_context);
m_last_decrypted_block = block_offset_on_disc;
}
@ -421,19 +417,19 @@ bool VolumeWii::CheckBlockIntegrity(u64 block_index, const u8* encrypted_data,
partition_details.h3_table->size())
return false;
mbedtls_aes_context* aes_context = partition_details.key->get();
auto aes_context = partition_details.key->get();
if (!aes_context)
return false;
HashBlock hashes;
DecryptBlockHashes(encrypted_data, &hashes, aes_context);
u8 cluster_data[BLOCK_DATA_SIZE];
DecryptBlockData(encrypted_data, cluster_data, aes_context);
auto cluster_data = std::make_unique<u8[]>(BLOCK_DATA_SIZE);
DecryptBlockData(encrypted_data, cluster_data.get(), aes_context);
for (u32 hash_index = 0; hash_index < 31; ++hash_index)
{
if (Common::SHA1::CalculateDigest(cluster_data + hash_index * 0x400, 0x400) !=
if (Common::SHA1::CalculateDigest(&cluster_data[hash_index * 0x400], 0x400) !=
hashes.h0[hash_index])
return false;
}
@ -577,8 +573,7 @@ bool VolumeWii::EncryptGroup(
std::vector<std::future<void>> encryption_futures(threads);
mbedtls_aes_context aes_context;
mbedtls_aes_setkey_enc(&aes_context, key.data(), 128);
auto aes_context = Common::AES::CreateContextEncrypt(key.data());
for (size_t i = 0; i < threads; ++i)
{
@ -589,13 +584,11 @@ bool VolumeWii::EncryptGroup(
{
u8* out_ptr = out->data() + j * BLOCK_TOTAL_SIZE;
u8 iv[16] = {};
mbedtls_aes_crypt_cbc(&aes_context, MBEDTLS_AES_ENCRYPT, BLOCK_HEADER_SIZE, iv,
reinterpret_cast<u8*>(&unencrypted_hashes[j]), out_ptr);
aes_context->CryptIvZero(reinterpret_cast<u8*>(&unencrypted_hashes[j]), out_ptr,
BLOCK_HEADER_SIZE);
std::memcpy(iv, out_ptr + 0x3D0, sizeof(iv));
mbedtls_aes_crypt_cbc(&aes_context, MBEDTLS_AES_ENCRYPT, BLOCK_DATA_SIZE, iv,
unencrypted_data[j].data(), out_ptr + BLOCK_HEADER_SIZE);
aes_context->Crypt(out_ptr + 0x3D0, unencrypted_data[j].data(),
out_ptr + BLOCK_HEADER_SIZE, BLOCK_DATA_SIZE);
}
},
i * BLOCKS_PER_GROUP / threads, (i + 1) * BLOCKS_PER_GROUP / threads);
@ -607,20 +600,14 @@ bool VolumeWii::EncryptGroup(
return true;
}
void VolumeWii::DecryptBlockHashes(const u8* in, HashBlock* out, mbedtls_aes_context* aes_context)
void VolumeWii::DecryptBlockHashes(const u8* in, HashBlock* out, Common::AES::Context* aes_context)
{
std::array<u8, 16> iv;
iv.fill(0);
mbedtls_aes_crypt_cbc(aes_context, MBEDTLS_AES_DECRYPT, sizeof(HashBlock), iv.data(), in,
reinterpret_cast<u8*>(out));
aes_context->CryptIvZero(in, reinterpret_cast<u8*>(out), sizeof(HashBlock));
}
void VolumeWii::DecryptBlockData(const u8* in, u8* out, mbedtls_aes_context* aes_context)
void VolumeWii::DecryptBlockData(const u8* in, u8* out, Common::AES::Context* aes_context)
{
std::array<u8, 16> iv;
std::copy(&in[0x3d0], &in[0x3e0], iv.data());
mbedtls_aes_crypt_cbc(aes_context, MBEDTLS_AES_DECRYPT, BLOCK_DATA_SIZE, iv.data(),
&in[BLOCK_HEADER_SIZE], out);
aes_context->Crypt(&in[0x3d0], &in[sizeof(HashBlock)], out, BLOCK_DATA_SIZE);
}
} // namespace DiscIO

View File

@ -11,8 +11,6 @@
#include <string>
#include <vector>
#include <mbedtls/aes.h>
#include "Common/CommonTypes.h"
#include "Common/Crypto/SHA1.h"
#include "Common/Lazy.h"
@ -21,6 +19,8 @@
#include "DiscIO/Volume.h"
#include "DiscIO/VolumeDisc.h"
#include "Common/Crypto/AES.h"
namespace DiscIO
{
class BlobReader;
@ -34,7 +34,7 @@ enum class Platform;
class VolumeWii : public VolumeDisc
{
public:
static constexpr size_t AES_KEY_SIZE = 16;
static constexpr size_t AES_KEY_SIZE = Common::AES::Context::KEY_SIZE;
static constexpr u32 BLOCKS_PER_GROUP = 0x40;
@ -106,8 +106,8 @@ public:
const std::function<void(HashBlock hash_blocks[BLOCKS_PER_GROUP])>&
hash_exception_callback = {});
static void DecryptBlockHashes(const u8* in, HashBlock* out, mbedtls_aes_context* aes_context);
static void DecryptBlockData(const u8* in, u8* out, mbedtls_aes_context* aes_context);
static void DecryptBlockHashes(const u8* in, HashBlock* out, Common::AES::Context* aes_context);
static void DecryptBlockData(const u8* in, u8* out, Common::AES::Context* aes_context);
protected:
u32 GetOffsetShift() const override { return 2; }
@ -115,7 +115,7 @@ protected:
private:
struct PartitionDetails
{
Common::Lazy<std::unique_ptr<mbedtls_aes_context>> key;
Common::Lazy<std::unique_ptr<Common::AES::Context>> key;
Common::Lazy<IOS::ES::TicketReader> ticket;
Common::Lazy<IOS::ES::TMDReader> tmd;
Common::Lazy<std::vector<u8>> cert_chain;

View File

@ -1318,8 +1318,7 @@ WIARVZFileReader<RVZ>::ProcessAndCompress(CompressThreadState* state, CompressPa
{
const PartitionEntry& partition_entry = partition_entries[parameters.data_entry->index];
mbedtls_aes_context aes_context;
mbedtls_aes_setkey_dec(&aes_context, partition_entry.partition_key.data(), 128);
auto aes_context = Common::AES::CreateContextDecrypt(partition_entry.partition_key.data());
const u64 groups = Common::AlignUp(parameters.data.size(), VolumeWii::GROUP_TOTAL_SIZE) /
VolumeWii::GROUP_TOTAL_SIZE;
@ -1388,7 +1387,7 @@ WIARVZFileReader<RVZ>::ProcessAndCompress(CompressThreadState* state, CompressPa
{
const u64 offset_of_block = offset_of_group + j * VolumeWii::BLOCK_TOTAL_SIZE;
VolumeWii::DecryptBlockData(parameters.data.data() + offset_of_block,
state->decryption_buffer[j].data(), &aes_context);
state->decryption_buffer[j].data(), aes_context.get());
}
else
{
@ -1413,7 +1412,7 @@ WIARVZFileReader<RVZ>::ProcessAndCompress(CompressThreadState* state, CompressPa
VolumeWii::HashBlock hashes;
VolumeWii::DecryptBlockHashes(parameters.data.data() + offset_of_block, &hashes,
&aes_context);
aes_context.get());
const auto compare_hash = [&](size_t offset_in_block) {
ASSERT(offset_in_block + Common::SHA1::DIGEST_LEN <= VolumeWii::BLOCK_HEADER_SIZE);