Merge pull request #5601 from leoetlino/esformats-additions

ESFormats additions and fixes
This commit is contained in:
Leo Lam 2017-06-15 21:32:44 +02:00 committed by GitHub
commit 9bc1b652fe
10 changed files with 264 additions and 96 deletions

View File

@ -110,7 +110,7 @@ IPCCommandResult DI::IOCtlV(const IOCtlVRequest& request)
// Read TMD to the buffer
const IOS::ES::TMDReader tmd = DVDThread::GetTMD(partition);
const std::vector<u8> raw_tmd = tmd.GetRawTMD();
const std::vector<u8>& raw_tmd = tmd.GetBytes();
Memory::CopyToEmu(request.io_vectors[0].address, raw_tmd.data(), raw_tmd.size());
ES::DIVerify(tmd, DVDThread::GetTicket(partition));

View File

@ -610,7 +610,7 @@ s32 ES::DIVerify(const IOS::ES::TMDReader& tmd, const IOS::ES::TicketReader& tic
if (!File::Exists(tmd_path))
{
File::IOFile tmd_file(tmd_path, "wb");
const std::vector<u8>& tmd_bytes = tmd.GetRawTMD();
const std::vector<u8>& tmd_bytes = tmd.GetBytes();
if (!tmd_file.WriteBytes(tmd_bytes.data(), tmd_bytes.size()))
ERROR_LOG(IOS_ES, "DIVerify failed to write disc TMD to NAND.");
}

View File

@ -10,6 +10,7 @@
#include <cstddef>
#include <cstring>
#include <locale>
#include <map>
#include <optional>
#include <string>
#include <utility>
@ -59,31 +60,124 @@ bool Content::IsShared() const
return (type & 0x8000) != 0;
}
SignedBlobReader::SignedBlobReader(const std::vector<u8>& bytes) : m_bytes(bytes)
{
}
SignedBlobReader::SignedBlobReader(std::vector<u8>&& bytes) : m_bytes(std::move(bytes))
{
}
const std::vector<u8>& SignedBlobReader::GetBytes() const
{
return m_bytes;
}
void SignedBlobReader::SetBytes(const std::vector<u8>& bytes)
{
m_bytes = bytes;
}
void SignedBlobReader::SetBytes(std::vector<u8>&& bytes)
{
m_bytes = std::move(bytes);
}
bool SignedBlobReader::IsSignatureValid() const
{
// Too small for the certificate type.
if (m_bytes.size() < sizeof(Cert::type))
return false;
// Too small to contain the whole signature data.
const size_t signature_size = GetSignatureSize();
if (signature_size == 0 || m_bytes.size() < signature_size)
return false;
return true;
}
SignatureType SignedBlobReader::GetSignatureType() const
{
return static_cast<SignatureType>(Common::swap32(m_bytes.data()));
}
std::vector<u8> SignedBlobReader::GetSignatureData() const
{
switch (GetSignatureType())
{
case SignatureType::RSA4096:
{
const auto signature_begin = m_bytes.begin() + offsetof(SignatureRSA4096, sig);
return std::vector<u8>(signature_begin, signature_begin + sizeof(SignatureRSA4096::sig));
}
case SignatureType::RSA2048:
{
const auto signature_begin = m_bytes.begin() + offsetof(SignatureRSA2048, sig);
return std::vector<u8>(signature_begin, signature_begin + sizeof(SignatureRSA2048::sig));
}
default:
return {};
}
}
size_t SignedBlobReader::GetSignatureSize() const
{
switch (GetSignatureType())
{
case SignatureType::RSA4096:
return sizeof(SignatureRSA4096);
case SignatureType::RSA2048:
return sizeof(SignatureRSA2048);
default:
return 0;
}
}
std::string SignedBlobReader::GetIssuer() const
{
switch (GetSignatureType())
{
case SignatureType::RSA4096:
{
const char* issuer =
reinterpret_cast<const char*>(m_bytes.data() + offsetof(SignatureRSA4096, issuer));
return std::string(issuer, strnlen(issuer, sizeof(SignatureRSA4096::issuer)));
}
case SignatureType::RSA2048:
{
const char* issuer =
reinterpret_cast<const char*>(m_bytes.data() + offsetof(SignatureRSA2048, issuer));
return std::string(issuer, strnlen(issuer, sizeof(SignatureRSA2048::issuer)));
}
default:
return "";
}
}
void SignedBlobReader::DoState(PointerWrap& p)
{
p.Do(m_bytes);
}
bool IsValidTMDSize(size_t size)
{
return size <= 0x49e4;
}
TMDReader::TMDReader(const std::vector<u8>& bytes) : m_bytes(bytes)
TMDReader::TMDReader(const std::vector<u8>& bytes) : SignedBlobReader(bytes)
{
}
TMDReader::TMDReader(std::vector<u8>&& bytes) : m_bytes(std::move(bytes))
TMDReader::TMDReader(std::vector<u8>&& bytes) : SignedBlobReader(std::move(bytes))
{
}
void TMDReader::SetBytes(const std::vector<u8>& bytes)
{
m_bytes = bytes;
}
void TMDReader::SetBytes(std::vector<u8>&& bytes)
{
m_bytes = std::move(bytes);
}
bool TMDReader::IsValid() const
{
if (!IsSignatureValid())
return false;
if (m_bytes.size() < sizeof(TMDHeader))
{
// TMD is too small to contain its base fields.
@ -99,16 +193,6 @@ bool TMDReader::IsValid() const
return true;
}
const std::vector<u8>& TMDReader::GetRawTMD() const
{
return m_bytes;
}
std::vector<u8> TMDReader::GetRawHeader() const
{
return std::vector<u8>(m_bytes.begin(), m_bytes.begin() + sizeof(TMDHeader));
}
std::vector<u8> TMDReader::GetRawView() const
{
// Base fields
@ -231,37 +315,17 @@ bool TMDReader::FindContentById(u32 id, Content* content) const
return false;
}
void TMDReader::DoState(PointerWrap& p)
{
p.Do(m_bytes);
}
TicketReader::TicketReader(const std::vector<u8>& bytes) : m_bytes(bytes)
TicketReader::TicketReader(const std::vector<u8>& bytes) : SignedBlobReader(bytes)
{
}
TicketReader::TicketReader(std::vector<u8>&& bytes) : m_bytes(std::move(bytes))
TicketReader::TicketReader(std::vector<u8>&& bytes) : SignedBlobReader(std::move(bytes))
{
}
void TicketReader::SetBytes(const std::vector<u8>& bytes)
{
m_bytes = bytes;
}
void TicketReader::SetBytes(std::vector<u8>&& bytes)
{
m_bytes = std::move(bytes);
}
bool TicketReader::IsValid() const
{
return !m_bytes.empty() && m_bytes.size() % sizeof(Ticket) == 0;
}
void TicketReader::DoState(PointerWrap& p)
{
p.Do(m_bytes);
return IsSignatureValid() && !m_bytes.empty() && m_bytes.size() % sizeof(Ticket) == 0;
}
size_t TicketReader::GetNumberOfTickets() const
@ -269,11 +333,6 @@ size_t TicketReader::GetNumberOfTickets() const
return m_bytes.size() / sizeof(Ticket);
}
const std::vector<u8>& TicketReader::GetRawTicket() const
{
return m_bytes;
}
std::vector<u8> TicketReader::GetRawTicket(u64 ticket_id_to_find) const
{
for (size_t i = 0; i < GetNumberOfTickets(); ++i)
@ -304,13 +363,6 @@ std::vector<u8> TicketReader::GetRawTicketView(u32 ticket_num) const
return view;
}
std::string TicketReader::GetIssuer() const
{
const char* bytes =
reinterpret_cast<const char*>(m_bytes.data() + offsetof(Ticket, signature.issuer));
return std::string(bytes, strnlen(bytes, sizeof(Ticket::signature.issuer)));
}
u32 TicketReader::GetDeviceId() const
{
return Common::swap32(m_bytes.data() + offsetof(Ticket, device_id));
@ -481,10 +533,12 @@ bool SharedContentMap::WriteEntries() const
File::CreateFullPath(temp_path);
// Atomically write the new content map.
File::IOFile file(temp_path, "w+b");
if (!file.WriteArray(m_entries.data(), m_entries.size()))
return false;
File::CreateFullPath(m_file_path);
{
File::IOFile file(temp_path, "w+b");
if (!file.WriteArray(m_entries.data(), m_entries.size()))
return false;
File::CreateFullPath(m_file_path);
}
return File::RenameSync(temp_path, m_file_path);
}
@ -563,5 +617,89 @@ u32 UIDSys::GetOrInsertUIDForTitle(const u64 title_id)
return uid;
}
CertReader::CertReader(std::vector<u8>&& bytes) : SignedBlobReader(std::move(bytes))
{
if (!IsSignatureValid())
return;
switch (GetSignatureType())
{
case SignatureType::RSA4096:
if (m_bytes.size() < sizeof(CertRSA4096))
return;
m_bytes.resize(sizeof(CertRSA4096));
break;
case SignatureType::RSA2048:
if (m_bytes.size() < sizeof(CertRSA2048))
return;
m_bytes.resize(sizeof(CertRSA2048));
break;
default:
return;
}
m_is_valid = true;
}
bool CertReader::IsValid() const
{
return m_is_valid;
}
u32 CertReader::GetId() const
{
const size_t offset = GetSignatureSize() + offsetof(CertHeader, id);
return Common::swap32(m_bytes.data() + offset);
}
std::string CertReader::GetName() const
{
const char* name = reinterpret_cast<const char*>(m_bytes.data() + GetSignatureSize() +
offsetof(CertHeader, name));
return std::string(name, strnlen(name, sizeof(CertHeader::name)));
}
PublicKeyType CertReader::GetPublicKeyType() const
{
const size_t offset = GetSignatureSize() + offsetof(CertHeader, public_key_type);
return static_cast<PublicKeyType>(Common::swap32(m_bytes.data() + offset));
}
std::vector<u8> CertReader::GetPublicKey() const
{
static const std::map<SignatureType, std::pair<size_t, size_t>> type_to_key_info = {{
{SignatureType::RSA4096,
{offsetof(CertRSA4096, public_key),
sizeof(CertRSA4096::public_key) + sizeof(CertRSA4096::exponent)}},
{SignatureType::RSA2048,
{offsetof(CertRSA2048, public_key),
sizeof(CertRSA2048::public_key) + sizeof(CertRSA2048::exponent)}},
}};
const auto info = type_to_key_info.at(GetSignatureType());
const auto key_begin = m_bytes.begin() + info.first;
return std::vector<u8>(key_begin, key_begin + info.second);
}
std::map<std::string, CertReader> ParseCertChain(const std::vector<u8>& chain)
{
std::map<std::string, CertReader> certs;
size_t processed = 0;
while (processed != chain.size())
{
CertReader cert_reader{std::vector<u8>(chain.begin() + processed, chain.end())};
if (!cert_reader.IsValid())
return certs;
processed += cert_reader.GetBytes().size();
const std::string name = cert_reader.GetName();
certs.emplace(std::move(name), std::move(cert_reader));
}
return certs;
}
} // namespace ES
} // namespace IOS

View File

@ -16,6 +16,7 @@
#include "Common/CommonTypes.h"
#include "Common/NandPaths.h"
#include "Core/IOS/Device.h"
#include "Core/IOS/IOSC.h"
#include "DiscIO/Enums.h"
@ -140,23 +141,45 @@ struct Ticket
static_assert(sizeof(Ticket) == 0x2A4, "Ticket has the wrong size");
#pragma pack(pop)
class SignedBlobReader
{
public:
SignedBlobReader() = default;
explicit SignedBlobReader(const std::vector<u8>& bytes);
explicit SignedBlobReader(std::vector<u8>&& bytes);
const std::vector<u8>& GetBytes() const;
void SetBytes(const std::vector<u8>& bytes);
void SetBytes(std::vector<u8>&& bytes);
// Only checks whether the signature data could be parsed. The signature is not verified.
bool IsSignatureValid() const;
SignatureType GetSignatureType() const;
std::vector<u8> GetSignatureData() const;
size_t GetSignatureSize() const;
// Returns the whole issuer chain.
// Example: Root-CA00000001 if the blob was signed by CA00000001, which is signed by the Root.
std::string GetIssuer() const;
void DoState(PointerWrap& p);
protected:
std::vector<u8> m_bytes;
};
bool IsValidTMDSize(size_t size);
class TMDReader final
class TMDReader final : public SignedBlobReader
{
public:
TMDReader() = default;
explicit TMDReader(const std::vector<u8>& bytes);
explicit TMDReader(std::vector<u8>&& bytes);
void SetBytes(const std::vector<u8>& bytes);
void SetBytes(std::vector<u8>&& bytes);
bool IsValid() const;
// Returns the TMD or parts of it without any kind of parsing. Intended for use by ES.
const std::vector<u8>& GetRawTMD() const;
std::vector<u8> GetRawHeader() const;
// Returns parts of the TMD without any kind of parsing. Intended for use by ES.
std::vector<u8> GetRawView() const;
u16 GetBootIndex() const;
@ -176,27 +199,17 @@ public:
bool GetContent(u16 index, Content* content) const;
std::vector<Content> GetContents() const;
bool FindContentById(u32 id, Content* content) const;
void DoState(PointerWrap& p);
private:
std::vector<u8> m_bytes;
};
class TicketReader final
class TicketReader final : public SignedBlobReader
{
public:
TicketReader() = default;
explicit TicketReader(const std::vector<u8>& bytes);
explicit TicketReader(std::vector<u8>&& bytes);
void SetBytes(const std::vector<u8>& bytes);
void SetBytes(std::vector<u8>&& bytes);
bool IsValid() const;
void DoState(PointerWrap& p);
const std::vector<u8>& GetRawTicket() const;
std::vector<u8> GetRawTicket(u64 ticket_id) const;
size_t GetNumberOfTickets() const;
@ -206,7 +219,6 @@ public:
// more than just one ticket and generate ticket views for them, so we implement it too.
std::vector<u8> GetRawTicketView(u32 ticket_num) const;
std::string GetIssuer() const;
u32 GetDeviceId() const;
u64 GetTitleId() const;
std::vector<u8> GetTitleKey() const;
@ -217,9 +229,6 @@ public:
// Decrypts the title key field for a "personalised" ticket -- one that is device-specific
// and has a title key that must be decrypted first.
s32 Unpersonalise();
private:
std::vector<u8> m_bytes;
};
class SharedContentMap final
@ -256,5 +265,26 @@ private:
std::string m_file_path;
std::map<u32, u64> m_entries;
};
class CertReader final : public SignedBlobReader
{
public:
explicit CertReader(std::vector<u8>&& bytes);
bool IsValid() const;
u32 GetId() const;
// Returns the certificate name. Examples: XS00000003, CA00000001
std::string GetName() const;
PublicKeyType GetPublicKeyType() const;
// Returns the public key bytes + any other data associated with it.
// For RSA public keys, this includes 4 bytes for the exponent at the end.
std::vector<u8> GetPublicKey() const;
private:
bool m_is_valid = false;
};
std::map<std::string, CertReader> ParseCertChain(const std::vector<u8>& chain);
} // namespace ES
} // namespace IOS

View File

@ -254,7 +254,7 @@ bool ES::WriteImportTMD(const IOS::ES::TMDReader& tmd)
{
File::IOFile file(tmd_path, "wb");
if (!file.WriteBytes(tmd.GetRawTMD().data(), tmd.GetRawTMD().size()))
if (!file.WriteBytes(tmd.GetBytes().data(), tmd.GetBytes().size()))
return false;
}

View File

@ -150,7 +150,7 @@ IPCCommandResult ES::GetStoredTMDSize(const IOCtlVRequest& request)
if (!tmd.IsValid())
return GetDefaultReply(FS_ENOENT);
const u32 tmd_size = static_cast<u32>(tmd.GetRawTMD().size());
const u32 tmd_size = static_cast<u32>(tmd.GetBytes().size());
Memory::Write_U32(tmd_size, request.io_vectors[0].address);
INFO_LOG(IOS_ES, "GetStoredTMDSize: %u bytes for %016" PRIx64, tmd_size, title_id);
@ -171,7 +171,7 @@ IPCCommandResult ES::GetStoredTMD(const IOCtlVRequest& request)
// TODO: actually use this param in when writing to the outbuffer :/
const u32 MaxCount = Memory::Read_U32(request.in_vectors[1].address);
const std::vector<u8> raw_tmd = tmd.GetRawTMD();
const std::vector<u8>& raw_tmd = tmd.GetBytes();
if (raw_tmd.size() != request.io_vectors[0].size)
return GetDefaultReply(ES_EINVAL);

View File

@ -40,7 +40,7 @@ static ReturnCode WriteTicket(const IOS::ES::TicketReader& ticket)
if (!ticket_file)
return ES_EIO;
const std::vector<u8>& raw_ticket = ticket.GetRawTicket();
const std::vector<u8>& raw_ticket = ticket.GetBytes();
return ticket_file.WriteBytes(raw_ticket.data(), raw_ticket.size()) ? IPC_SUCCESS : ES_EIO;
}
@ -394,7 +394,7 @@ ReturnCode ES::DeleteTicket(const u8* ticket_view)
const u64 ticket_id = Common::swap64(ticket_view + offsetof(IOS::ES::TicketView, ticket_id));
ticket.DeleteTicket(ticket_id);
const std::vector<u8>& new_ticket = ticket.GetRawTicket();
const std::vector<u8>& new_ticket = ticket.GetBytes();
const std::string ticket_path = Common::GetTicketFileName(title_id, Common::FROM_SESSION_ROOT);
{
File::IOFile ticket_file(ticket_path, "wb");
@ -505,7 +505,7 @@ ReturnCode ES::ExportTitleInit(Context& context, u64 title_id, u8* tmd_bytes, u3
context.title_export.title_key = ticket.GetTitleKey();
const auto& raw_tmd = context.title_export.tmd.GetRawTMD();
const std::vector<u8>& raw_tmd = context.title_export.tmd.GetBytes();
if (tmd_size != raw_tmd.size())
return ES_EINVAL;

View File

@ -376,7 +376,7 @@ IPCCommandResult ES::DIGetTMDSize(const IOCtlVRequest& request)
if (!GetTitleContext().active)
return GetDefaultReply(ES_EINVAL);
Memory::Write_U32(static_cast<u32>(GetTitleContext().tmd.GetRawTMD().size()),
Memory::Write_U32(static_cast<u32>(GetTitleContext().tmd.GetBytes().size()),
request.io_vectors[0].address);
return GetDefaultReply(IPC_SUCCESS);
}
@ -393,7 +393,7 @@ IPCCommandResult ES::DIGetTMD(const IOCtlVRequest& request)
if (!GetTitleContext().active)
return GetDefaultReply(ES_EINVAL);
const std::vector<u8>& tmd_bytes = GetTitleContext().tmd.GetRawTMD();
const std::vector<u8>& tmd_bytes = GetTitleContext().tmd.GetBytes();
if (static_cast<u32>(tmd_bytes.size()) > tmd_size)
return GetDefaultReply(ES_EINVAL);

View File

@ -27,8 +27,8 @@ bool InstallWAD(const std::string& wad_path)
const auto es = ios.GetES();
IOS::HLE::Device::ES::Context context;
if (es->ImportTicket(wad.GetTicket().GetRawTicket()) < 0 ||
es->ImportTitleInit(context, tmd.GetRawTMD()) < 0)
if (es->ImportTicket(wad.GetTicket().GetBytes()) < 0 ||
es->ImportTitleInit(context, tmd.GetBytes()) < 0)
{
PanicAlertT("WAD installation failed: Could not initialise title import.");
return false;

View File

@ -76,7 +76,7 @@ void TMDReaderTest::TestGeneralInfo()
void TMDReaderTest::TestRawTMDAndView()
{
const std::vector<u8>& dolphin_tmd_bytes = m_tmd.GetRawTMD();
const std::vector<u8>& dolphin_tmd_bytes = m_tmd.GetBytes();
// Separate check because gtest prints neither the size nor the full buffer.
EXPECT_EQ(m_raw_tmd.size(), dolphin_tmd_bytes.size());
EXPECT_EQ(m_raw_tmd, dolphin_tmd_bytes);