From 46ad8b9d68b7eff720a6f33e6029ddd36ab8edaa Mon Sep 17 00:00:00 2001 From: Shawn Hoffman Date: Wed, 27 Jul 2022 01:51:19 -0700 Subject: [PATCH] Implement hw accelerated AES --- Source/Core/Common/CMakeLists.txt | 6 + Source/Core/Common/Crypto/AES.cpp | 406 +++++++++++++++++++- Source/Core/Common/Crypto/AES.h | 38 +- Source/Core/Core/IOS/ES/TitleManagement.cpp | 6 +- Source/Core/Core/IOS/IOSC.cpp | 12 +- Source/Core/Core/IOS/WFS/WFSI.cpp | 9 +- Source/Core/Core/IOS/WFS/WFSI.h | 6 +- Source/Core/DiscIO/NANDImporter.cpp | 11 +- Source/Core/DiscIO/NANDImporter.h | 3 +- Source/Core/DiscIO/VolumeWad.cpp | 10 +- Source/Core/DiscIO/VolumeWii.cpp | 55 +-- Source/Core/DiscIO/VolumeWii.h | 12 +- Source/Core/DiscIO/WIABlob.cpp | 7 +- 13 files changed, 488 insertions(+), 93 deletions(-) diff --git a/Source/Core/Common/CMakeLists.txt b/Source/Core/Common/CMakeLists.txt index 92b47f0f66..f9e5109d17 100644 --- a/Source/Core/Common/CMakeLists.txt +++ b/Source/Core/Common/CMakeLists.txt @@ -136,6 +136,12 @@ add_library(common WorkQueueThread.h ) +if(NOT MSVC AND _M_ARM_64) + set_source_files_properties( + Crypto/AES.cpp + PROPERTIES COMPILE_FLAGS "-march=armv8-a+crypto") +endif() + target_link_libraries(common PUBLIC ${CMAKE_THREAD_LIBS_INIT} diff --git a/Source/Core/Common/Crypto/AES.cpp b/Source/Core/Common/Crypto/AES.cpp index 5412969c7a..7c67766ed4 100644 --- a/Source/Core/Common/Crypto/AES.cpp +++ b/Source/Core/Common/Crypto/AES.cpp @@ -1,35 +1,411 @@ // Copyright 2017 Dolphin Emulator Project // SPDX-License-Identifier: GPL-2.0-or-later -#include "Common/Crypto/AES.h" +#include +#include #include +#include "Common/Assert.h" +#include "Common/BitUtils.h" +#include "Common/CPUDetect.h" +#include "Common/Crypto/AES.h" + +#ifdef _MSC_VER +#include +#else +#if defined(_M_X86_64) +#include +#elif defined(_M_ARM_64) +#include +#include +#endif +#endif + +#ifdef _MSC_VER +#define ATTRIBUTE_TARGET(x) +#else +#define ATTRIBUTE_TARGET(x) [[gnu::target(x)]] +#endif + namespace Common::AES { -std::vector DecryptEncrypt(const u8* key, u8* iv, const u8* src, size_t size, Mode mode) +// For x64 and arm64, it's very unlikely a user's cpu does not support the accelerated version, +// fallback is just in case. +template +class ContextGeneric final : public Context { - mbedtls_aes_context aes_ctx; - std::vector buffer(size); +public: + ContextGeneric(const u8* key) + { + mbedtls_aes_init(&ctx); + if constexpr (AesMode == Mode::Encrypt) + ASSERT(!mbedtls_aes_setkey_enc(&ctx, key, 128)); + else + ASSERT(!mbedtls_aes_setkey_dec(&ctx, key, 128)); + } - if (mode == Mode::Encrypt) - mbedtls_aes_setkey_enc(&aes_ctx, key, 128); - else - mbedtls_aes_setkey_dec(&aes_ctx, key, 128); + virtual bool Crypt(const u8* iv, u8* iv_out, const u8* buf_in, u8* buf_out, + size_t len) const override + { + std::array iv_tmp{}; + if (iv) + std::memcpy(&iv_tmp[0], iv, BLOCK_SIZE); - mbedtls_aes_crypt_cbc(&aes_ctx, mode == Mode::Encrypt ? MBEDTLS_AES_ENCRYPT : MBEDTLS_AES_DECRYPT, - size, iv, src, buffer.data()); + constexpr int mode = (AesMode == Mode::Encrypt) ? MBEDTLS_AES_ENCRYPT : MBEDTLS_AES_DECRYPT; + if (mbedtls_aes_crypt_cbc(const_cast(&ctx), mode, len, &iv_tmp[0], buf_in, + buf_out)) + return false; - return buffer; + if (iv_out) + std::memcpy(iv_out, &iv_tmp[0], BLOCK_SIZE); + return true; + } + +private: + mbedtls_aes_context ctx{}; +}; + +#if defined(_M_X86_64) + +// Note that (for instructions with same data width) the actual instructions emitted vary depending +// on compiler and flags. The naming is somewhat confusing, because VAES cpuid flag was added after +// VAES(VEX.128): +// clang-format off +// instructions | cpuid flag | #define +// AES(128) | AES | - +// VAES(VEX.128) | AES & AVX | __AVX__ +// VAES(VEX.256) | VAES | - +// VAES(EVEX.128) | VAES & AVX512VL | __AVX512VL__ +// VAES(EVEX.256) | VAES & AVX512VL | __AVX512VL__ +// VAES(EVEX.512) | VAES & AVX512F | __AVX512F__ +// clang-format on +template +class ContextAESNI final : public Context +{ + static inline __m128i Aes128KeygenAssistFinish(__m128i key, __m128i kga) + { + __m128i tmp = _mm_shuffle_epi32(kga, _MM_SHUFFLE(3, 3, 3, 3)); + tmp = _mm_xor_si128(tmp, key); + + key = _mm_slli_si128(key, 4); + tmp = _mm_xor_si128(tmp, key); + key = _mm_slli_si128(key, 4); + tmp = _mm_xor_si128(tmp, key); + key = _mm_slli_si128(key, 4); + tmp = _mm_xor_si128(tmp, key); + return tmp; + } + + template + ATTRIBUTE_TARGET("aes") + inline constexpr void StoreRoundKey(__m128i rk) + { + if constexpr (AesMode == Mode::Encrypt) + round_keys[RoundIdx] = rk; + else + { + constexpr size_t idx = NUM_ROUND_KEYS - RoundIdx - 1; + if constexpr (idx == 0 || idx == NUM_ROUND_KEYS - 1) + round_keys[idx] = rk; + else + round_keys[idx] = _mm_aesimc_si128(rk); + } + } + + template + ATTRIBUTE_TARGET("aes") + inline constexpr __m128i Aes128Keygen(__m128i rk) + { + rk = Aes128KeygenAssistFinish(rk, _mm_aeskeygenassist_si128(rk, Rcon)); + StoreRoundKey(rk); + return rk; + } + +public: + ContextAESNI(const u8* key) + { + __m128i rk = _mm_loadu_si128((const __m128i*)key); + StoreRoundKey<0>(rk); + rk = Aes128Keygen<1, 0x01>(rk); + rk = Aes128Keygen<2, 0x02>(rk); + rk = Aes128Keygen<3, 0x04>(rk); + rk = Aes128Keygen<4, 0x08>(rk); + rk = Aes128Keygen<5, 0x10>(rk); + rk = Aes128Keygen<6, 0x20>(rk); + rk = Aes128Keygen<7, 0x40>(rk); + rk = Aes128Keygen<8, 0x80>(rk); + rk = Aes128Keygen<9, 0x1b>(rk); + Aes128Keygen<10, 0x36>(rk); + } + + ATTRIBUTE_TARGET("aes") + inline void CryptBlock(__m128i* iv, const u8* buf_in, u8* buf_out) const + { + __m128i block = _mm_loadu_si128((const __m128i*)buf_in); + + if constexpr (AesMode == Mode::Encrypt) + { + block = _mm_xor_si128(_mm_xor_si128(block, *iv), round_keys[0]); + + for (size_t i = 1; i < Nr; ++i) + block = _mm_aesenc_si128(block, round_keys[i]); + block = _mm_aesenclast_si128(block, round_keys[Nr]); + + *iv = block; + } + else + { + __m128i iv_next = block; + + block = _mm_xor_si128(block, round_keys[0]); + + for (size_t i = 1; i < Nr; ++i) + block = _mm_aesdec_si128(block, round_keys[i]); + block = _mm_aesdeclast_si128(block, round_keys[Nr]); + + block = _mm_xor_si128(block, *iv); + *iv = iv_next; + } + + _mm_storeu_si128((__m128i*)buf_out, block); + } + + // Takes advantage of instruction pipelining to parallelize. + template + ATTRIBUTE_TARGET("aes") + inline void DecryptPipelined(__m128i* iv, const u8* buf_in, u8* buf_out) const + { + constexpr size_t Depth = NumBlocks; + + __m128i block[Depth]; + for (size_t d = 0; d < Depth; d++) + block[d] = _mm_loadu_si128(&((const __m128i*)buf_in)[d]); + + __m128i iv_next[1 + Depth]; + iv_next[0] = *iv; + for (size_t d = 0; d < Depth; d++) + iv_next[1 + d] = block[d]; + + for (size_t d = 0; d < Depth; d++) + block[d] = _mm_xor_si128(block[d], round_keys[0]); + + // The main speedup is here + for (size_t i = 1; i < Nr; ++i) + for (size_t d = 0; d < Depth; d++) + block[d] = _mm_aesdec_si128(block[d], round_keys[i]); + for (size_t d = 0; d < Depth; d++) + block[d] = _mm_aesdeclast_si128(block[d], round_keys[Nr]); + + for (size_t d = 0; d < Depth; d++) + block[d] = _mm_xor_si128(block[d], iv_next[d]); + *iv = iv_next[1 + Depth - 1]; + + for (size_t d = 0; d < Depth; d++) + _mm_storeu_si128(&((__m128i*)buf_out)[d], block[d]); + } + + virtual bool Crypt(const u8* iv, u8* iv_out, const u8* buf_in, u8* buf_out, + size_t len) const override + { + if (len % BLOCK_SIZE) + return false; + + __m128i iv_block = iv ? _mm_loadu_si128((const __m128i*)iv) : _mm_setzero_si128(); + + if constexpr (AesMode == Mode::Decrypt) + { + // On amd zen2...(benchmark, not real-world): + // With AES(128) instructions, BLOCK_DEPTH results in following speedup vs. non-pipelined: 4: + // 18%, 8: 22%, 9: 26%, 10-15: 31%. 16: 8% (register exhaustion). With VAES(VEX.128), 10 gives + // 36% speedup vs. its corresponding baseline. VAES(VEX.128) is ~4% faster than AES(128). The + // result is similar on zen3. + // Zen3 in general is 20% faster than zen2 in aes, and VAES(VEX.256) is 35% faster than + // zen3/VAES(VEX.128). + // It seems like VAES(VEX.256) should be faster? + // TODO Choose value at runtime based on some criteria? + constexpr size_t BLOCK_DEPTH = 10; + constexpr size_t CHUNK_LEN = BLOCK_DEPTH * BLOCK_SIZE; + while (len >= CHUNK_LEN) + { + DecryptPipelined(&iv_block, buf_in, buf_out); + buf_in += CHUNK_LEN; + buf_out += CHUNK_LEN; + len -= CHUNK_LEN; + } + } + + len /= BLOCK_SIZE; + while (len--) + { + CryptBlock(&iv_block, buf_in, buf_out); + buf_in += BLOCK_SIZE; + buf_out += BLOCK_SIZE; + } + + if (iv_out) + _mm_storeu_si128((__m128i*)iv_out, iv_block); + + return true; + } + +private: + std::array<__m128i, NUM_ROUND_KEYS> round_keys; +}; + +#endif + +#if defined(_M_ARM_64) + +template +class ContextNeon final : public Context +{ +public: + template + inline constexpr void StoreRoundKey(const u32* rk) + { + const uint8x16_t rk_block = vld1q_u32(rk); + if constexpr (AesMode == Mode::Encrypt) + round_keys[RoundIdx] = rk_block; + else + { + constexpr size_t idx = NUM_ROUND_KEYS - RoundIdx - 1; + if constexpr (idx == 0 || idx == NUM_ROUND_KEYS - 1) + round_keys[idx] = rk_block; + else + round_keys[idx] = vaesimcq_u8(rk_block); + } + } + + ContextNeon(const u8* key) + { + constexpr u8 rcon[]{0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x1b, 0x36}; + std::array rk{}; + + // This uses a nice trick I've seen in wolfssl (not sure original author), + // which uses vaeseq_u8 to assist keygen. + // vaeseq_u8: op1 = SubBytes(ShiftRows(AddRoundKey(op1, op2))) + // given RotWord == ShiftRows for row 1 (rol(x,8)) + // Probably not super fast (moves to/from vector regs constantly), but it is nice and simple. + + std::memcpy(&rk[0], key, KEY_SIZE); + StoreRoundKey<0>(&rk[0]); + for (size_t i = 0; i < rk.size() - Nk; i += Nk) + { + const uint8x16_t enc = vaeseq_u8(vreinterpretq_u8_u32(vmovq_n_u32(rk[i + 3])), vmovq_n_u8(0)); + const u32 temp = vgetq_lane_u32(vreinterpretq_u32_u8(enc), 0); + rk[i + 4] = rk[i + 0] ^ Common::RotateRight(temp, 8) ^ rcon[i / Nk]; + rk[i + 5] = rk[i + 4] ^ rk[i + 1]; + rk[i + 6] = rk[i + 5] ^ rk[i + 2]; + rk[i + 7] = rk[i + 6] ^ rk[i + 3]; + // clang-format off + // Not great + const size_t rki = 1 + i / Nk; + switch (rki) + { + case 1: StoreRoundKey< 1>(&rk[Nk * rki]); break; + case 2: StoreRoundKey< 2>(&rk[Nk * rki]); break; + case 3: StoreRoundKey< 3>(&rk[Nk * rki]); break; + case 4: StoreRoundKey< 4>(&rk[Nk * rki]); break; + case 5: StoreRoundKey< 5>(&rk[Nk * rki]); break; + case 6: StoreRoundKey< 6>(&rk[Nk * rki]); break; + case 7: StoreRoundKey< 7>(&rk[Nk * rki]); break; + case 8: StoreRoundKey< 8>(&rk[Nk * rki]); break; + case 9: StoreRoundKey< 9>(&rk[Nk * rki]); break; + case 10: StoreRoundKey<10>(&rk[Nk * rki]); break; + } + // clang-format on + } + } + + inline void CryptBlock(uint8x16_t* iv, const u8* buf_in, u8* buf_out) const + { + uint8x16_t block = vld1q_u8(buf_in); + + if constexpr (AesMode == Mode::Encrypt) + { + block = veorq_u8(block, *iv); + + for (size_t i = 0; i < Nr - 1; ++i) + block = vaesmcq_u8(vaeseq_u8(block, round_keys[i])); + block = vaeseq_u8(block, round_keys[Nr - 1]); + block = veorq_u8(block, round_keys[Nr]); + + *iv = block; + } + else + { + uint8x16_t iv_next = block; + + for (size_t i = 0; i < Nr - 1; ++i) + block = vaesimcq_u8(vaesdq_u8(block, round_keys[i])); + block = vaesdq_u8(block, round_keys[Nr - 1]); + block = veorq_u8(block, round_keys[Nr]); + + block = veorq_u8(block, *iv); + *iv = iv_next; + } + + vst1q_u8(buf_out, block); + } + + virtual bool Crypt(const u8* iv, u8* iv_out, const u8* buf_in, u8* buf_out, + size_t len) const override + { + if (len % BLOCK_SIZE) + return false; + + uint8x16_t iv_block = iv ? vld1q_u8(iv) : vmovq_n_u8(0); + + len /= BLOCK_SIZE; + while (len--) + { + CryptBlock(&iv_block, buf_in, buf_out); + buf_in += BLOCK_SIZE; + buf_out += BLOCK_SIZE; + } + + if (iv_out) + vst1q_u8(iv_out, iv_block); + + return true; + } + +private: + std::array round_keys; +}; + +#endif + +template +std::unique_ptr CreateContext(const u8* key) +{ + if (cpu_info.bAES) + { +#if defined(_M_X86_64) +#if defined(__AVX__) + // If compiler enables AVX, the intrinsics will generate VAES(VEX.128) instructions. + // In the future we may want to compile the code twice and explicitly override the compiler + // flags. There doesn't seem to be much performance difference between AES(128) and + // VAES(VEX.128) at the moment, though. + if (cpu_info.bAVX) +#endif + return std::make_unique>(key); +#elif defined(_M_ARM_64) + return std::make_unique>(key); +#endif + } + return std::make_unique>(key); } -std::vector Decrypt(const u8* key, u8* iv, const u8* src, size_t size) +std::unique_ptr CreateContextEncrypt(const u8* key) { - return DecryptEncrypt(key, iv, src, size, Mode::Decrypt); + return CreateContext(key); } -std::vector Encrypt(const u8* key, u8* iv, const u8* src, size_t size) +std::unique_ptr CreateContextDecrypt(const u8* key) { - return DecryptEncrypt(key, iv, src, size, Mode::Encrypt); + return CreateContext(key); } + } // namespace Common::AES diff --git a/Source/Core/Common/Crypto/AES.h b/Source/Core/Common/Crypto/AES.h index 71ea3601d7..338b71bc85 100644 --- a/Source/Core/Common/Crypto/AES.h +++ b/Source/Core/Common/Crypto/AES.h @@ -3,11 +3,12 @@ #pragma once -#include -#include +#include #include "Common/CommonTypes.h" +// Dolphin only uses/implements AES-128-CBC. + namespace Common::AES { enum class Mode @@ -15,9 +16,34 @@ enum class Mode Decrypt, Encrypt, }; -std::vector DecryptEncrypt(const u8* key, u8* iv, const u8* src, size_t size, Mode mode); -// Convenience functions -std::vector Decrypt(const u8* key, u8* iv, const u8* src, size_t size); -std::vector Encrypt(const u8* key, u8* iv, const u8* src, size_t size); +class Context +{ +protected: + static constexpr size_t Nk = 4; + static constexpr size_t Nb = 4; + static constexpr size_t Nr = 10; + static constexpr size_t WORD_SIZE = sizeof(u32); + static constexpr size_t NUM_ROUND_KEYS = Nr + 1; + +public: + static constexpr size_t KEY_SIZE = Nk * WORD_SIZE; + static constexpr size_t BLOCK_SIZE = Nb * WORD_SIZE; + + Context() = default; + virtual ~Context() = default; + virtual bool Crypt(const u8* iv, u8* iv_out, const u8* buf_in, u8* buf_out, size_t len) const = 0; + bool Crypt(const u8* iv, const u8* buf_in, u8* buf_out, size_t len) const + { + return Crypt(iv, nullptr, buf_in, buf_out, len); + } + bool CryptIvZero(const u8* buf_in, u8* buf_out, size_t len) const + { + return Crypt(nullptr, nullptr, buf_in, buf_out, len); + } +}; + +std::unique_ptr CreateContextEncrypt(const u8* key); +std::unique_ptr CreateContextDecrypt(const u8* key); + } // namespace Common::AES diff --git a/Source/Core/Core/IOS/ES/TitleManagement.cpp b/Source/Core/Core/IOS/ES/TitleManagement.cpp index 216e54855d..df4ff8fd05 100644 --- a/Source/Core/Core/IOS/ES/TitleManagement.cpp +++ b/Source/Core/Core/IOS/ES/TitleManagement.cpp @@ -760,11 +760,11 @@ ReturnCode ESDevice::ExportContentData(Context& context, u32 content_fd, u8* dat buffer.resize(Common::AlignUp(buffer.size(), 32)); std::vector output(buffer.size()); - const ReturnCode decrypt_ret = m_ios.GetIOSC().Encrypt( + const ReturnCode encrypt_ret = m_ios.GetIOSC().Encrypt( context.title_import_export.key_handle, context.title_import_export.content.iv.data(), buffer.data(), buffer.size(), output.data(), PID_ES); - if (decrypt_ret != IPC_SUCCESS) - return decrypt_ret; + if (encrypt_ret != IPC_SUCCESS) + return encrypt_ret; std::copy(output.cbegin(), output.cend(), data); return IPC_SUCCESS; diff --git a/Source/Core/Core/IOS/IOSC.cpp b/Source/Core/Core/IOS/IOSC.cpp index 8ec6af4269..0e659ebb29 100644 --- a/Source/Core/Core/IOS/IOSC.cpp +++ b/Source/Core/Core/IOS/IOSC.cpp @@ -271,10 +271,16 @@ ReturnCode IOSC::DecryptEncrypt(Common::AES::Mode mode, Handle key_handle, u8* i if (entry->data.size() != AES128_KEY_SIZE) return IOSC_FAIL_INTERNAL; - const std::vector data = - Common::AES::DecryptEncrypt(entry->data.data(), iv, input, size, mode); + auto key = entry->data.data(); + // TODO? store enc + dec ctxs in the KeyEntry so they only need to be created once. + // This doesn't seem like a hot path, though. + std::unique_ptr ctx; + if (mode == Common::AES::Mode::Encrypt) + ctx = Common::AES::CreateContextEncrypt(key); + else + ctx = Common::AES::CreateContextDecrypt(key); - std::memcpy(output, data.data(), data.size()); + ctx->Crypt(iv, iv, input, output, size); return IPC_SUCCESS; } diff --git a/Source/Core/Core/IOS/WFS/WFSI.cpp b/Source/Core/Core/IOS/WFS/WFSI.cpp index 9f9c3b291b..641d41e052 100644 --- a/Source/Core/Core/IOS/WFS/WFSI.cpp +++ b/Source/Core/Core/IOS/WFS/WFSI.cpp @@ -3,7 +3,6 @@ #include "Core/IOS/WFS/WFSI.h" -#include #include #include #include @@ -12,6 +11,7 @@ #include #include "Common/CommonTypes.h" +#include "Common/Crypto/AES.h" #include "Common/FileUtil.h" #include "Common/IOFile.h" #include "Common/Logging/Log.h" @@ -167,8 +167,7 @@ std::optional WFSIDevice::IOCtl(const IOCtlRequest& request) break; } - memcpy(m_aes_key, ticket.GetTitleKey(m_ios.GetIOSC()).data(), sizeof(m_aes_key)); - mbedtls_aes_setkey_dec(&m_aes_ctx, m_aes_key, 128); + m_aes_ctx = Common::AES::CreateContextDecrypt(ticket.GetTitleKey(m_ios.GetIOSC()).data()); SetImportTitleIdAndGroupId(m_tmd.GetTitleId(), m_tmd.GetGroupId()); @@ -224,8 +223,8 @@ std::optional WFSIDevice::IOCtl(const IOCtlRequest& request) input_size, input_ptr, content_id); std::vector decrypted(input_size); - mbedtls_aes_crypt_cbc(&m_aes_ctx, MBEDTLS_AES_DECRYPT, input_size, m_aes_iv, - Memory::GetPointer(input_ptr), decrypted.data()); + m_aes_ctx->Crypt(m_aes_iv, m_aes_iv, Memory::GetPointer(input_ptr), decrypted.data(), + input_size); m_arc_unpacker.AddBytes(decrypted); break; diff --git a/Source/Core/Core/IOS/WFS/WFSI.h b/Source/Core/Core/IOS/WFS/WFSI.h index c83709d386..d1c2fdd919 100644 --- a/Source/Core/Core/IOS/WFS/WFSI.h +++ b/Source/Core/Core/IOS/WFS/WFSI.h @@ -4,12 +4,12 @@ #pragma once #include +#include #include #include -#include - #include "Common/CommonTypes.h" +#include "Common/Crypto/AES.h" #include "Core/IOS/Device.h" #include "Core/IOS/ES/Formats.h" #include "Core/IOS/IOS.h" @@ -50,7 +50,7 @@ private: std::string m_device_name; - mbedtls_aes_context m_aes_ctx{}; + std::unique_ptr m_aes_ctx{}; u8 m_aes_key[0x10] = {}; u8 m_aes_iv[0x10] = {}; diff --git a/Source/Core/DiscIO/NANDImporter.cpp b/Source/Core/DiscIO/NANDImporter.cpp index 4520930ca3..6c17303467 100644 --- a/Source/Core/DiscIO/NANDImporter.cpp +++ b/Source/Core/DiscIO/NANDImporter.cpp @@ -171,14 +171,13 @@ std::vector NANDImporter::GetEntryData(const NANDFSTEntry& entry) std::vector data{}; data.reserve(remaining_bytes); + auto block = std::make_unique(NAND_FAT_BLOCK_SIZE); while (remaining_bytes > 0) { - std::array iv{}; - std::vector block = Common::AES::Decrypt( - m_aes_key.data(), iv.data(), &m_nand[NAND_FAT_BLOCK_SIZE * sub], NAND_FAT_BLOCK_SIZE); + m_aes_ctx->CryptIvZero(&m_nand[NAND_FAT_BLOCK_SIZE * sub], block.get(), NAND_FAT_BLOCK_SIZE); - size_t size = std::min(remaining_bytes, block.size()); - data.insert(data.end(), block.begin(), block.begin() + size); + size_t size = std::min(remaining_bytes, NAND_FAT_BLOCK_SIZE); + data.insert(data.end(), block.get(), block.get() + size); remaining_bytes -= size; sub = m_superblock->fat[sub]; @@ -260,7 +259,7 @@ void NANDImporter::ExportKeys() { constexpr size_t NAND_AES_KEY_OFFSET = 0x158; - std::copy_n(&m_nand_keys[NAND_AES_KEY_OFFSET], m_aes_key.size(), m_aes_key.begin()); + m_aes_ctx = Common::AES::CreateContextDecrypt(&m_nand_keys[NAND_AES_KEY_OFFSET]); const std::string file_path = m_nand_root + "/keys.bin"; File::IOFile file(file_path, "wb"); diff --git a/Source/Core/DiscIO/NANDImporter.h b/Source/Core/DiscIO/NANDImporter.h index 269a6c5483..227ce5126f 100644 --- a/Source/Core/DiscIO/NANDImporter.h +++ b/Source/Core/DiscIO/NANDImporter.h @@ -12,6 +12,7 @@ #include #include "Common/CommonTypes.h" +#include "Common/Crypto/AES.h" #include "Common/Swap.h" namespace DiscIO @@ -74,7 +75,7 @@ private: std::string m_nand_root; std::vector m_nand; std::vector m_nand_keys; - std::array m_aes_key; + std::unique_ptr m_aes_ctx; std::unique_ptr m_superblock; std::function m_update_callback; }; diff --git a/Source/Core/DiscIO/VolumeWad.cpp b/Source/Core/DiscIO/VolumeWad.cpp index b485899c6e..edfd4c76ee 100644 --- a/Source/Core/DiscIO/VolumeWad.cpp +++ b/Source/Core/DiscIO/VolumeWad.cpp @@ -13,11 +13,10 @@ #include #include -#include - #include "Common/Align.h" #include "Common/Assert.h" #include "Common/CommonTypes.h" +#include "Common/Crypto/AES.h" #include "Common/Crypto/SHA1.h" #include "Common/Logging/Log.h" #include "Common/MsgHandler.h" @@ -159,17 +158,14 @@ bool VolumeWAD::CheckContentIntegrity(const IOS::ES::Content& content, if (encrypted_data.size() != Common::AlignUp(content.size, 0x40)) return false; - mbedtls_aes_context context; - const std::array key = ticket.GetTitleKey(); - mbedtls_aes_setkey_dec(&context, key.data(), 128); + auto context = Common::AES::CreateContextDecrypt(ticket.GetTitleKey().data()); std::array iv{}; iv[0] = static_cast(content.index >> 8); iv[1] = static_cast(content.index & 0xFF); std::vector decrypted_data(encrypted_data.size()); - mbedtls_aes_crypt_cbc(&context, MBEDTLS_AES_DECRYPT, decrypted_data.size(), iv.data(), - encrypted_data.data(), decrypted_data.data()); + context->Crypt(iv.data(), encrypted_data.data(), decrypted_data.data(), decrypted_data.size()); return Common::SHA1::CalculateDigest(decrypted_data.data(), content.size) == content.sha1; } diff --git a/Source/Core/DiscIO/VolumeWii.cpp b/Source/Core/DiscIO/VolumeWii.cpp index ca421cd591..da7254e845 100644 --- a/Source/Core/DiscIO/VolumeWii.cpp +++ b/Source/Core/DiscIO/VolumeWii.cpp @@ -16,11 +16,10 @@ #include #include -#include - #include "Common/Align.h" #include "Common/Assert.h" #include "Common/CommonTypes.h" +#include "Common/Crypto/AES.h" #include "Common/Crypto/SHA1.h" #include "Common/Logging/Log.h" #include "Common/Swap.h" @@ -128,14 +127,11 @@ VolumeWii::VolumeWii(std::unique_ptr reader) return h3_table; }; - auto get_key = [this, partition]() -> std::unique_ptr { + auto get_key = [this, partition]() -> std::unique_ptr { const IOS::ES::TicketReader& ticket = *m_partitions[partition].ticket; if (!ticket.IsValid()) return nullptr; - const std::array key = ticket.GetTitleKey(); - std::unique_ptr aes_context = std::make_unique(); - mbedtls_aes_setkey_dec(aes_context.get(), key.data(), 128); - return aes_context; + return Common::AES::CreateContextDecrypt(ticket.GetTitleKey().data()); }; auto get_file_system = [this, partition]() -> std::unique_ptr { @@ -148,7 +144,7 @@ VolumeWii::VolumeWii(std::unique_ptr reader) }; m_partitions.emplace( - partition, PartitionDetails{Common::Lazy>(get_key), + partition, PartitionDetails{Common::Lazy>(get_key), Common::Lazy(get_ticket), Common::Lazy(get_tmd), Common::Lazy>(get_cert_chain), @@ -183,11 +179,11 @@ bool VolumeWii::Read(u64 offset, u64 length, u8* buffer, const Partition& partit buffer); } - mbedtls_aes_context* aes_context = partition_details.key->get(); + auto aes_context = partition_details.key->get(); if (!aes_context) return false; - std::vector read_buffer(BLOCK_TOTAL_SIZE); + auto read_buffer = std::make_unique(BLOCK_TOTAL_SIZE); while (length > 0) { // Calculate offsets @@ -198,11 +194,11 @@ bool VolumeWii::Read(u64 offset, u64 length, u8* buffer, const Partition& partit if (m_last_decrypted_block != block_offset_on_disc) { // Read the current block - if (!m_reader->Read(block_offset_on_disc, BLOCK_TOTAL_SIZE, read_buffer.data())) + if (!m_reader->Read(block_offset_on_disc, BLOCK_TOTAL_SIZE, read_buffer.get())) return false; // Decrypt the block's data - DecryptBlockData(read_buffer.data(), m_last_decrypted_block_data, aes_context); + DecryptBlockData(read_buffer.get(), m_last_decrypted_block_data, aes_context); m_last_decrypted_block = block_offset_on_disc; } @@ -421,19 +417,19 @@ bool VolumeWii::CheckBlockIntegrity(u64 block_index, const u8* encrypted_data, partition_details.h3_table->size()) return false; - mbedtls_aes_context* aes_context = partition_details.key->get(); + auto aes_context = partition_details.key->get(); if (!aes_context) return false; HashBlock hashes; DecryptBlockHashes(encrypted_data, &hashes, aes_context); - u8 cluster_data[BLOCK_DATA_SIZE]; - DecryptBlockData(encrypted_data, cluster_data, aes_context); + auto cluster_data = std::make_unique(BLOCK_DATA_SIZE); + DecryptBlockData(encrypted_data, cluster_data.get(), aes_context); for (u32 hash_index = 0; hash_index < 31; ++hash_index) { - if (Common::SHA1::CalculateDigest(cluster_data + hash_index * 0x400, 0x400) != + if (Common::SHA1::CalculateDigest(&cluster_data[hash_index * 0x400], 0x400) != hashes.h0[hash_index]) return false; } @@ -577,8 +573,7 @@ bool VolumeWii::EncryptGroup( std::vector> encryption_futures(threads); - mbedtls_aes_context aes_context; - mbedtls_aes_setkey_enc(&aes_context, key.data(), 128); + auto aes_context = Common::AES::CreateContextEncrypt(key.data()); for (size_t i = 0; i < threads; ++i) { @@ -589,13 +584,11 @@ bool VolumeWii::EncryptGroup( { u8* out_ptr = out->data() + j * BLOCK_TOTAL_SIZE; - u8 iv[16] = {}; - mbedtls_aes_crypt_cbc(&aes_context, MBEDTLS_AES_ENCRYPT, BLOCK_HEADER_SIZE, iv, - reinterpret_cast(&unencrypted_hashes[j]), out_ptr); + aes_context->CryptIvZero(reinterpret_cast(&unencrypted_hashes[j]), out_ptr, + BLOCK_HEADER_SIZE); - std::memcpy(iv, out_ptr + 0x3D0, sizeof(iv)); - mbedtls_aes_crypt_cbc(&aes_context, MBEDTLS_AES_ENCRYPT, BLOCK_DATA_SIZE, iv, - unencrypted_data[j].data(), out_ptr + BLOCK_HEADER_SIZE); + aes_context->Crypt(out_ptr + 0x3D0, unencrypted_data[j].data(), + out_ptr + BLOCK_HEADER_SIZE, BLOCK_DATA_SIZE); } }, i * BLOCKS_PER_GROUP / threads, (i + 1) * BLOCKS_PER_GROUP / threads); @@ -607,20 +600,14 @@ bool VolumeWii::EncryptGroup( return true; } -void VolumeWii::DecryptBlockHashes(const u8* in, HashBlock* out, mbedtls_aes_context* aes_context) +void VolumeWii::DecryptBlockHashes(const u8* in, HashBlock* out, Common::AES::Context* aes_context) { - std::array iv; - iv.fill(0); - mbedtls_aes_crypt_cbc(aes_context, MBEDTLS_AES_DECRYPT, sizeof(HashBlock), iv.data(), in, - reinterpret_cast(out)); + aes_context->CryptIvZero(in, reinterpret_cast(out), sizeof(HashBlock)); } -void VolumeWii::DecryptBlockData(const u8* in, u8* out, mbedtls_aes_context* aes_context) +void VolumeWii::DecryptBlockData(const u8* in, u8* out, Common::AES::Context* aes_context) { - std::array iv; - std::copy(&in[0x3d0], &in[0x3e0], iv.data()); - mbedtls_aes_crypt_cbc(aes_context, MBEDTLS_AES_DECRYPT, BLOCK_DATA_SIZE, iv.data(), - &in[BLOCK_HEADER_SIZE], out); + aes_context->Crypt(&in[0x3d0], &in[sizeof(HashBlock)], out, BLOCK_DATA_SIZE); } } // namespace DiscIO diff --git a/Source/Core/DiscIO/VolumeWii.h b/Source/Core/DiscIO/VolumeWii.h index 5e0a86a4b9..d4837a26b6 100644 --- a/Source/Core/DiscIO/VolumeWii.h +++ b/Source/Core/DiscIO/VolumeWii.h @@ -11,8 +11,6 @@ #include #include -#include - #include "Common/CommonTypes.h" #include "Common/Crypto/SHA1.h" #include "Common/Lazy.h" @@ -21,6 +19,8 @@ #include "DiscIO/Volume.h" #include "DiscIO/VolumeDisc.h" +#include "Common/Crypto/AES.h" + namespace DiscIO { class BlobReader; @@ -34,7 +34,7 @@ enum class Platform; class VolumeWii : public VolumeDisc { public: - static constexpr size_t AES_KEY_SIZE = 16; + static constexpr size_t AES_KEY_SIZE = Common::AES::Context::KEY_SIZE; static constexpr u32 BLOCKS_PER_GROUP = 0x40; @@ -106,8 +106,8 @@ public: const std::function& hash_exception_callback = {}); - static void DecryptBlockHashes(const u8* in, HashBlock* out, mbedtls_aes_context* aes_context); - static void DecryptBlockData(const u8* in, u8* out, mbedtls_aes_context* aes_context); + static void DecryptBlockHashes(const u8* in, HashBlock* out, Common::AES::Context* aes_context); + static void DecryptBlockData(const u8* in, u8* out, Common::AES::Context* aes_context); protected: u32 GetOffsetShift() const override { return 2; } @@ -115,7 +115,7 @@ protected: private: struct PartitionDetails { - Common::Lazy> key; + Common::Lazy> key; Common::Lazy ticket; Common::Lazy tmd; Common::Lazy> cert_chain; diff --git a/Source/Core/DiscIO/WIABlob.cpp b/Source/Core/DiscIO/WIABlob.cpp index 1ee9dd7308..35b6407d66 100644 --- a/Source/Core/DiscIO/WIABlob.cpp +++ b/Source/Core/DiscIO/WIABlob.cpp @@ -1318,8 +1318,7 @@ WIARVZFileReader::ProcessAndCompress(CompressThreadState* state, CompressPa { const PartitionEntry& partition_entry = partition_entries[parameters.data_entry->index]; - mbedtls_aes_context aes_context; - mbedtls_aes_setkey_dec(&aes_context, partition_entry.partition_key.data(), 128); + auto aes_context = Common::AES::CreateContextDecrypt(partition_entry.partition_key.data()); const u64 groups = Common::AlignUp(parameters.data.size(), VolumeWii::GROUP_TOTAL_SIZE) / VolumeWii::GROUP_TOTAL_SIZE; @@ -1388,7 +1387,7 @@ WIARVZFileReader::ProcessAndCompress(CompressThreadState* state, CompressPa { const u64 offset_of_block = offset_of_group + j * VolumeWii::BLOCK_TOTAL_SIZE; VolumeWii::DecryptBlockData(parameters.data.data() + offset_of_block, - state->decryption_buffer[j].data(), &aes_context); + state->decryption_buffer[j].data(), aes_context.get()); } else { @@ -1413,7 +1412,7 @@ WIARVZFileReader::ProcessAndCompress(CompressThreadState* state, CompressPa VolumeWii::HashBlock hashes; VolumeWii::DecryptBlockHashes(parameters.data.data() + offset_of_block, &hashes, - &aes_context); + aes_context.get()); const auto compare_hash = [&](size_t offset_in_block) { ASSERT(offset_in_block + Common::SHA1::DIGEST_LEN <= VolumeWii::BLOCK_HEADER_SIZE);