WIA: Reuse groups when writing

This is useful for the way Dolphin scrubs Wii discs.
The encrypted data is what gets zeroed out, but this
zeroed out data then gets decrypted before being stored,
and the resulting data does not compress well.
However, each block of decrypted scrubbed data is
identical given the same encryption key, and there's
nothing stopping us from making multiple group entries
point to the same offset in the file, so we only have
to store one copy of this data per partition.

For reference, wit zeroes out the decrypted data,
but Dolphin's WIA writer can't do this because it currently
doesn't know which parts of the disc are scrubbed.

This is also useful for things such as storing Datel discs
full of 0x55 blocks (repesenting unreadable blocks)
without compression enabled.
This commit is contained in:
JosJuice 2020-04-19 15:34:02 +02:00
parent 40e46aee57
commit e5b9e1ba1f
2 changed files with 171 additions and 84 deletions

View File

@ -730,7 +730,8 @@ bool WIAFileReader::PurgeCompressor::Start()
return true; return true;
} }
bool WIAFileReader::PurgeCompressor::AddPrecedingDataOnlyForPurgeHashing(const u8* data, size_t size) bool WIAFileReader::PurgeCompressor::AddPrecedingDataOnlyForPurgeHashing(const u8* data,
size_t size)
{ {
mbedtls_sha1_update_ret(&m_sha1_context, data, size); mbedtls_sha1_update_ret(&m_sha1_context, data, size);
return true; return true;
@ -1361,10 +1362,28 @@ WIAFileReader::ConversionResult WIAFileReader::SetUpDataEntriesForWriting(
return ConversionResult::Success; return ConversionResult::Success;
} }
bool WIAFileReader::TryReuseGroup(std::vector<GroupEntry>* group_entries, size_t* groups_written,
std::map<ReuseID, GroupEntry>* reusable_groups,
std::optional<ReuseID> reuse_id)
{
if (!reuse_id)
return false;
const auto it = reusable_groups->find(*reuse_id);
if (it != reusable_groups->end())
{
(*group_entries)[*groups_written] = it->second;
++*groups_written;
}
return it != reusable_groups->end();
}
WIAFileReader::ConversionResult WIAFileReader::CompressAndWriteGroup( WIAFileReader::ConversionResult WIAFileReader::CompressAndWriteGroup(
File::IOFile* file, u64* bytes_written, std::vector<GroupEntry>* group_entries, File::IOFile* file, u64* bytes_written, std::vector<GroupEntry>* group_entries,
size_t* groups_written, Compressor* compressor, bool compressed_exception_lists, size_t* groups_written, Compressor* compressor, bool compressed_exception_lists,
const std::vector<u8>& exception_lists, const std::vector<u8>& main_data) const std::vector<u8>& exception_lists, const std::vector<u8>& main_data,
std::map<ReuseID, GroupEntry>* reusable_groups, std::optional<ReuseID> reuse_id)
{ {
const auto all_zero = [](const std::vector<u8>& data) { const auto all_zero = [](const std::vector<u8>& data) {
return std::all_of(data.begin(), data.end(), [](u8 x) { return x == 0; }); return std::all_of(data.begin(), data.end(), [](u8 x) { return x == 0; });
@ -1377,6 +1396,9 @@ WIAFileReader::ConversionResult WIAFileReader::CompressAndWriteGroup(
return ConversionResult::Success; return ConversionResult::Success;
} }
if (TryReuseGroup(group_entries, groups_written, reusable_groups, reuse_id))
return ConversionResult::Success;
const u64 data_offset = *bytes_written; const u64 data_offset = *bytes_written;
if (compressor) if (compressor)
@ -1450,6 +1472,9 @@ WIAFileReader::ConversionResult WIAFileReader::CompressAndWriteGroup(
group_entry.data_size = Common::swap32(static_cast<u32>(*bytes_written - data_offset)); group_entry.data_size = Common::swap32(static_cast<u32>(*bytes_written - data_offset));
++*groups_written; ++*groups_written;
if (reuse_id)
reusable_groups->emplace(*reuse_id, group_entry);
if (!PadTo4(file, bytes_written)) if (!PadTo4(file, bytes_written))
return ConversionResult::WriteFailed; return ConversionResult::WriteFailed;
@ -1566,6 +1591,11 @@ WIAFileReader::ConvertToWIA(BlobReader* infile, const VolumeDisc* infile_volume,
return ConversionResult::ReadFailed; return ConversionResult::ReadFailed;
// We intentially do not increment bytes_read here, since these bytes will be read again // We intentially do not increment bytes_read here, since these bytes will be read again
const auto all_same = [](const std::vector<u8>& data) {
const u8 first_byte = data.front();
return std::all_of(data.begin(), data.end(), [first_byte](u8 x) { return x == first_byte; });
};
using WiiBlockData = std::array<u8, VolumeWii::BLOCK_DATA_SIZE>; using WiiBlockData = std::array<u8, VolumeWii::BLOCK_DATA_SIZE>;
std::vector<u8> buffer; std::vector<u8> buffer;
@ -1573,6 +1603,8 @@ WIAFileReader::ConvertToWIA(BlobReader* infile, const VolumeDisc* infile_volume,
std::vector<WiiBlockData> decryption_buffer; std::vector<WiiBlockData> decryption_buffer;
std::vector<VolumeWii::HashBlock> hash_buffer; std::vector<VolumeWii::HashBlock> hash_buffer;
std::map<ReuseID, GroupEntry> reusable_groups;
if (!partition_entries.empty()) if (!partition_entries.empty())
{ {
decryption_buffer.resize(VolumeWii::BLOCKS_PER_GROUP); decryption_buffer.resize(VolumeWii::BLOCKS_PER_GROUP);
@ -1617,98 +1649,119 @@ WIAFileReader::ConvertToWIA(BlobReader* infile, const VolumeDisc* infile_volume,
return ConversionResult::ReadFailed; return ConversionResult::ReadFailed;
bytes_read += bytes_to_read; bytes_read += bytes_to_read;
std::vector<std::vector<HashExceptionEntry>> exception_lists(exception_lists_per_chunk); const auto create_reuse_id = [&partition_entry, bytes_to_write](u8 value, bool decrypted) {
return ReuseID{&partition_entry.partition_key, bytes_to_write, decrypted, value};
};
for (u64 j = 0; j < groups; ++j) std::optional<ReuseID> reuse_id;
// Set this group as reusable if the encrypted data is all_same
if (all_same(buffer))
reuse_id = create_reuse_id(buffer.front(), false);
if (!TryReuseGroup(&group_entries, &groups_written, &reusable_groups, reuse_id))
{ {
const u64 offset_of_group = j * VolumeWii::GROUP_TOTAL_SIZE; std::vector<std::vector<HashExceptionEntry>> exception_lists(exception_lists_per_chunk);
const u64 write_offset_of_group = j * VolumeWii::GROUP_DATA_SIZE;
const u64 blocks_in_this_group = for (u64 j = 0; j < groups; ++j)
std::min<u64>(VolumeWii::BLOCKS_PER_GROUP, blocks - j * VolumeWii::BLOCKS_PER_GROUP);
for (u32 k = 0; k < VolumeWii::BLOCKS_PER_GROUP; ++k)
{ {
if (k < blocks_in_this_group) const u64 offset_of_group = j * VolumeWii::GROUP_TOTAL_SIZE;
const u64 write_offset_of_group = j * VolumeWii::GROUP_DATA_SIZE;
const u64 blocks_in_this_group = std::min<u64>(
VolumeWii::BLOCKS_PER_GROUP, blocks - j * VolumeWii::BLOCKS_PER_GROUP);
for (u32 k = 0; k < VolumeWii::BLOCKS_PER_GROUP; ++k)
{
if (k < blocks_in_this_group)
{
const u64 offset_of_block = offset_of_group + k * VolumeWii::BLOCK_TOTAL_SIZE;
VolumeWii::DecryptBlockData(buffer.data() + offset_of_block,
decryption_buffer[k].data(), &aes_context);
}
else
{
decryption_buffer[k].fill(0);
}
}
VolumeWii::HashGroup(decryption_buffer.data(), hash_buffer.data());
for (u64 k = 0; k < blocks_in_this_group; ++k)
{ {
const u64 offset_of_block = offset_of_group + k * VolumeWii::BLOCK_TOTAL_SIZE; const u64 offset_of_block = offset_of_group + k * VolumeWii::BLOCK_TOTAL_SIZE;
VolumeWii::DecryptBlockData(buffer.data() + offset_of_block, const u64 hash_offset_of_block = k * VolumeWii::BLOCK_HEADER_SIZE;
decryption_buffer[k].data(), &aes_context);
VolumeWii::HashBlock hashes;
VolumeWii::DecryptBlockHashes(buffer.data() + offset_of_block, &hashes, &aes_context);
const auto compare_hash = [&](size_t offset_in_block) {
ASSERT(offset_in_block + sizeof(SHA1) <= VolumeWii::BLOCK_HEADER_SIZE);
const u8* desired_hash = reinterpret_cast<u8*>(&hashes) + offset_in_block;
const u8* computed_hash = reinterpret_cast<u8*>(&hash_buffer[k]) + offset_in_block;
if (!std::equal(desired_hash, desired_hash + sizeof(SHA1), computed_hash))
{
const u64 hash_offset = hash_offset_of_block + offset_in_block;
ASSERT(hash_offset <= std::numeric_limits<u16>::max());
HashExceptionEntry& exception = exception_lists[j].emplace_back();
exception.offset = static_cast<u16>(Common::swap16(hash_offset));
std::memcpy(exception.hash.data(), desired_hash, sizeof(SHA1));
}
};
const auto compare_hashes = [&compare_hash](size_t offset, size_t size) {
for (size_t l = 0; l < size; l += sizeof(SHA1))
// The std::min is to ensure that we don't go beyond the end of HashBlock with
// padding_2, which is 32 bytes long (not divisible by sizeof(SHA1), which is 20).
compare_hash(offset + std::min(l, size - sizeof(SHA1)));
};
using HashBlock = VolumeWii::HashBlock;
compare_hashes(offsetof(HashBlock, h0), sizeof(HashBlock::h0));
compare_hashes(offsetof(HashBlock, padding_0), sizeof(HashBlock::padding_0));
compare_hashes(offsetof(HashBlock, h1), sizeof(HashBlock::h1));
compare_hashes(offsetof(HashBlock, padding_1), sizeof(HashBlock::padding_1));
compare_hashes(offsetof(HashBlock, h2), sizeof(HashBlock::h2));
compare_hashes(offsetof(HashBlock, padding_2), sizeof(HashBlock::padding_2));
} }
else
for (u64 k = 0; k < blocks_in_this_group; ++k)
{ {
decryption_buffer[k].fill(0); std::memcpy(buffer.data() + write_offset_of_group + k * VolumeWii::BLOCK_DATA_SIZE,
decryption_buffer[k].data(), VolumeWii::BLOCK_DATA_SIZE);
} }
} }
VolumeWii::HashGroup(decryption_buffer.data(), hash_buffer.data()); bool have_exceptions = false;
for (u64 k = 0; k < blocks_in_this_group; ++k) exceptions_buffer.clear();
for (const std::vector<HashExceptionEntry>& exception_list : exception_lists)
{ {
const u64 offset_of_block = offset_of_group + k * VolumeWii::BLOCK_TOTAL_SIZE; const u16 exceptions = Common::swap16(static_cast<u16>(exception_list.size()));
const u64 hash_offset_of_block = k * VolumeWii::BLOCK_HEADER_SIZE; PushBack(&exceptions_buffer, exceptions);
for (const HashExceptionEntry& exception : exception_list)
VolumeWii::HashBlock hashes; PushBack(&exceptions_buffer, exception);
VolumeWii::DecryptBlockHashes(buffer.data() + offset_of_block, &hashes, &aes_context); if (!exception_list.empty())
have_exceptions = true;
const auto compare_hash = [&](size_t offset_in_block) {
ASSERT(offset_in_block + sizeof(SHA1) <= VolumeWii::BLOCK_HEADER_SIZE);
const u8* desired_hash = reinterpret_cast<u8*>(&hashes) + offset_in_block;
const u8* computed_hash = reinterpret_cast<u8*>(&hash_buffer[k]) + offset_in_block;
if (!std::equal(desired_hash, desired_hash + sizeof(SHA1), computed_hash))
{
const u64 hash_offset = hash_offset_of_block + offset_in_block;
ASSERT(hash_offset <= std::numeric_limits<u16>::max());
HashExceptionEntry& exception = exception_lists[j].emplace_back();
exception.offset = static_cast<u16>(Common::swap16(hash_offset));
std::memcpy(exception.hash.data(), desired_hash, sizeof(SHA1));
}
};
const auto compare_hashes = [&compare_hash](size_t offset, size_t size) {
for (size_t l = 0; l < size; l += sizeof(SHA1))
// The std::min is to ensure that we don't go beyond the end of HashBlock with
// padding_2, which is 32 bytes long (not divisible by sizeof(SHA1), which is 20).
compare_hash(offset + std::min(l, size - sizeof(SHA1)));
};
using HashBlock = VolumeWii::HashBlock;
compare_hashes(offsetof(HashBlock, h0), sizeof(HashBlock::h0));
compare_hashes(offsetof(HashBlock, padding_0), sizeof(HashBlock::padding_0));
compare_hashes(offsetof(HashBlock, h1), sizeof(HashBlock::h1));
compare_hashes(offsetof(HashBlock, padding_1), sizeof(HashBlock::padding_1));
compare_hashes(offsetof(HashBlock, h2), sizeof(HashBlock::h2));
compare_hashes(offsetof(HashBlock, padding_2), sizeof(HashBlock::padding_2));
} }
for (u64 k = 0; k < blocks_in_this_group; ++k) buffer.resize(bytes_to_write);
{
std::memcpy(buffer.data() + write_offset_of_group + k * VolumeWii::BLOCK_DATA_SIZE, // Set this group as reusable if it lacks exceptions and the decrypted data is all_same
decryption_buffer[k].data(), VolumeWii::BLOCK_DATA_SIZE); if (!reuse_id && !have_exceptions && all_same(buffer))
} reuse_id = create_reuse_id(buffer.front(), true);
const ConversionResult write_result = CompressAndWriteGroup(
outfile, &bytes_written, &group_entries, &groups_written, compressor.get(),
compressed_exception_lists, exceptions_buffer, buffer, &reusable_groups, reuse_id);
if (write_result != ConversionResult::Success)
return write_result;
} }
exceptions_buffer.clear();
for (const std::vector<HashExceptionEntry>& exception_list : exception_lists)
{
const u16 exceptions = Common::swap16(static_cast<u16>(exception_list.size()));
PushBack(&exceptions_buffer, exceptions);
for (const HashExceptionEntry& exception : exception_list)
PushBack(&exceptions_buffer, exception);
}
buffer.resize(bytes_to_write);
const ConversionResult write_result = CompressAndWriteGroup(
outfile, &bytes_written, &group_entries, &groups_written, compressor.get(),
compressed_exception_lists, exceptions_buffer, buffer);
if (write_result != ConversionResult::Success)
return write_result;
if (!run_callback()) if (!run_callback())
return ConversionResult::Canceled; return ConversionResult::Canceled;
} }
@ -1742,9 +1795,13 @@ WIAFileReader::ConvertToWIA(BlobReader* infile, const VolumeDisc* infile_volume,
return ConversionResult::ReadFailed; return ConversionResult::ReadFailed;
bytes_read += bytes_to_read; bytes_read += bytes_to_read;
std::optional<ReuseID> reuse_id;
if (all_same(buffer))
reuse_id = ReuseID{nullptr, bytes_to_read, false, buffer.front()};
const ConversionResult write_result = CompressAndWriteGroup( const ConversionResult write_result = CompressAndWriteGroup(
outfile, &bytes_written, &group_entries, &groups_written, compressor.get(), outfile, &bytes_written, &group_entries, &groups_written, compressor.get(),
compressed_exception_lists, exceptions_buffer, buffer); compressed_exception_lists, exceptions_buffer, buffer, &reusable_groups, reuse_id);
if (write_result != ConversionResult::Success) if (write_result != ConversionResult::Success)
return write_result; return write_result;

View File

@ -8,6 +8,7 @@
#include <limits> #include <limits>
#include <map> #include <map>
#include <memory> #include <memory>
#include <optional>
#include <utility> #include <utility>
#include <bzlib.h> #include <bzlib.h>
@ -389,6 +390,33 @@ private:
static u32 LZMA2DictionarySize(u8 p); static u32 LZMA2DictionarySize(u8 p);
struct ReuseID
{
bool operator==(const ReuseID& other) const
{
return std::tie(partition_key, data_size, decrypted, value) ==
std::tie(other.partition_key, other.data_size, other.decrypted, other.value);
}
bool operator<(const ReuseID& other) const
{
return std::tie(partition_key, data_size, decrypted, value) <
std::tie(other.partition_key, other.data_size, other.decrypted, other.value);
}
bool operator>(const ReuseID& other) const
{
return std::tie(partition_key, data_size, decrypted, value) >
std::tie(other.partition_key, other.data_size, other.decrypted, other.value);
}
bool operator!=(const ReuseID& other) const { return !operator==(other); }
bool operator>=(const ReuseID& other) const { return !operator<(other); }
bool operator<=(const ReuseID& other) const { return !operator>(other); }
const WiiKey* partition_key;
u64 data_size;
bool decrypted;
u8 value;
};
static bool PadTo4(File::IOFile* file, u64* bytes_written); static bool PadTo4(File::IOFile* file, u64* bytes_written);
static void AddRawDataEntry(u64 offset, u64 size, int chunk_size, u32* total_groups, static void AddRawDataEntry(u64 offset, u64 size, int chunk_size, u32* total_groups,
std::vector<RawDataEntry>* raw_data_entries, std::vector<RawDataEntry>* raw_data_entries,
@ -402,12 +430,14 @@ private:
std::vector<PartitionEntry>* partition_entries, std::vector<PartitionEntry>* partition_entries,
std::vector<RawDataEntry>* raw_data_entries, std::vector<RawDataEntry>* raw_data_entries,
std::vector<DataEntry>* data_entries); std::vector<DataEntry>* data_entries);
static ConversionResult CompressAndWriteGroup(File::IOFile* file, u64* bytes_written, static bool TryReuseGroup(std::vector<GroupEntry>* group_entries, size_t* groups_written,
std::vector<GroupEntry>* group_entries, std::map<ReuseID, GroupEntry>* reusable_groups,
size_t* groups_written, Compressor* compressor, std::optional<ReuseID> reuse_id);
bool compressed_exception_lists, static ConversionResult CompressAndWriteGroup(
const std::vector<u8>& exception_lists, File::IOFile* file, u64* bytes_written, std::vector<GroupEntry>* group_entries,
const std::vector<u8>& main_data); size_t* groups_written, Compressor* compressor, bool compressed_exception_lists,
const std::vector<u8>& exception_lists, const std::vector<u8>& main_data,
std::map<ReuseID, GroupEntry>* reusable_groups, std::optional<ReuseID> reuse_id);
static ConversionResult CompressAndWrite(File::IOFile* file, u64* bytes_written, static ConversionResult CompressAndWrite(File::IOFile* file, u64* bytes_written,
Compressor* compressor, const u8* data, size_t size, Compressor* compressor, const u8* data, size_t size,
size_t* size_out); size_t* size_out);