WIA: Implement caching and partial decompression

This commit is contained in:
JosJuice 2020-01-05 22:26:28 +01:00
parent b59ef81a7e
commit 01a77ae8a1
2 changed files with 386 additions and 283 deletions

View File

@ -7,6 +7,7 @@
#include <algorithm> #include <algorithm>
#include <array> #include <array>
#include <cstring> #include <cstring>
#include <limits>
#include <memory> #include <memory>
#include <utility> #include <utility>
@ -122,13 +123,12 @@ bool WIAFileReader::Initialize(const std::string& path)
const u32 number_of_raw_data_entries = Common::swap32(m_header_2.number_of_raw_data_entries); const u32 number_of_raw_data_entries = Common::swap32(m_header_2.number_of_raw_data_entries);
m_raw_data_entries.resize(number_of_raw_data_entries); m_raw_data_entries.resize(number_of_raw_data_entries);
if (!ReadCompressedData(number_of_raw_data_entries * sizeof(RawDataEntry), Chunk& raw_data_entries =
Common::swap64(m_header_2.raw_data_entries_offset), ReadCompressedData(Common::swap64(m_header_2.raw_data_entries_offset),
Common::swap32(m_header_2.raw_data_entries_size), Common::swap32(m_header_2.raw_data_entries_size),
reinterpret_cast<u8*>(m_raw_data_entries.data()), false)) number_of_raw_data_entries * sizeof(RawDataEntry), false);
{ if (!raw_data_entries.ReadAll(&m_raw_data_entries))
return false; return false;
}
std::sort(m_raw_data_entries.begin(), m_raw_data_entries.end(), std::sort(m_raw_data_entries.begin(), m_raw_data_entries.end(),
[](const RawDataEntry& a, const RawDataEntry& b) { [](const RawDataEntry& a, const RawDataEntry& b) {
@ -137,13 +137,11 @@ bool WIAFileReader::Initialize(const std::string& path)
const u32 number_of_group_entries = Common::swap32(m_header_2.number_of_group_entries); const u32 number_of_group_entries = Common::swap32(m_header_2.number_of_group_entries);
m_group_entries.resize(number_of_group_entries); m_group_entries.resize(number_of_group_entries);
if (!ReadCompressedData(number_of_group_entries * sizeof(GroupEntry), Chunk& group_entries = ReadCompressedData(Common::swap64(m_header_2.group_entries_offset),
Common::swap64(m_header_2.group_entries_offset), Common::swap32(m_header_2.group_entries_size),
Common::swap32(m_header_2.group_entries_size), number_of_group_entries * sizeof(GroupEntry), false);
reinterpret_cast<u8*>(m_group_entries.data()), false)) if (!group_entries.ReadAll(&m_group_entries))
{
return false; return false;
}
return true; return true;
} }
@ -260,9 +258,11 @@ bool WIAFileReader::ReadFromGroups(u64* offset, u64* size, u8** out_ptr, u64 chu
else else
{ {
const u64 group_offset_in_file = static_cast<u64>(Common::swap32(group.data_offset)) << 2; const u64 group_offset_in_file = static_cast<u64>(Common::swap32(group.data_offset)) << 2;
if (!ReadCompressedData(chunk_size, group_offset_in_file, group_data_size, offset_in_group, Chunk& chunk =
bytes_to_read, *out_ptr, exception_list)) ReadCompressedData(group_offset_in_file, group_data_size, chunk_size, exception_list);
if (!chunk.Read(offset_in_group, bytes_to_read, *out_ptr))
{ {
m_cached_chunk_offset = std::numeric_limits<u64>::max(); // Invalidate the cache
return false; return false;
} }
} }
@ -275,183 +275,40 @@ bool WIAFileReader::ReadFromGroups(u64* offset, u64* size, u8** out_ptr, u64 chu
return true; return true;
} }
bool WIAFileReader::ReadCompressedData(u32 decompressed_data_size, u64 data_offset, u64 data_size, WIAFileReader::Chunk& WIAFileReader::ReadCompressedData(u64 offset_in_file, u64 compressed_size,
u8* out_ptr, bool exception_list) u64 decompressed_size, bool exception_list)
{ {
if (offset_in_file == m_cached_chunk_offset)
return m_cached_chunk;
std::unique_ptr<Decompressor> decompressor;
switch (m_compression_type) switch (m_compression_type)
{ {
case CompressionType::None: case CompressionType::None:
{ decompressor = std::make_unique<NoneDecompressor>();
return ReadCompressedData(decompressed_data_size, data_offset, data_size, 0, break;
decompressed_data_size, out_ptr, exception_list);
}
case CompressionType::Purge: case CompressionType::Purge:
{ decompressor = std::make_unique<PurgeDecompressor>(decompressed_size);
if (!m_file.Seek(data_offset, SEEK_SET)) break;
return false;
if (exception_list)
{
const std::optional<u64> exception_size = ReadExceptionListFromFile();
if (!exception_size)
return false;
data_size -= *exception_size;
}
const u64 hash_offset = data_size - sizeof(SHA1);
u32 offset_in_data = 0;
u32 offset_in_decompressed_data = 0;
while (offset_in_data < hash_offset)
{
PurgeSegment purge_segment;
if (!m_file.ReadArray(&purge_segment, 1))
return false;
const u32 segment_offset = Common::swap32(purge_segment.offset);
const u32 segment_size = Common::swap32(purge_segment.size);
if (segment_offset < offset_in_decompressed_data)
return false;
const u32 blank_bytes = segment_offset - offset_in_decompressed_data;
std::memset(out_ptr, 0, blank_bytes);
out_ptr += blank_bytes;
if (segment_size != 0 && !m_file.ReadBytes(out_ptr, segment_size))
return false;
out_ptr += segment_size;
offset_in_data += sizeof(PurgeSegment) + segment_size;
offset_in_decompressed_data = segment_offset + segment_size;
}
if (offset_in_data != hash_offset || offset_in_decompressed_data > decompressed_data_size)
return false;
std::memset(out_ptr, 0, decompressed_data_size - offset_in_decompressed_data);
SHA1 expected_hash;
if (!m_file.ReadArray(&expected_hash, 1))
return false;
// TODO: Check hash
return true;
}
case CompressionType::Bzip2: case CompressionType::Bzip2:
decompressor = std::make_unique<Bzip2Decompressor>();
break;
case CompressionType::LZMA: case CompressionType::LZMA:
decompressor = std::make_unique<LZMADecompressor>(false, m_header_2.compressor_data,
m_header_2.compressor_data_size);
break;
case CompressionType::LZMA2: case CompressionType::LZMA2:
{ decompressor = std::make_unique<LZMADecompressor>(true, m_header_2.compressor_data,
std::vector<u8> compressed_data(data_size); m_header_2.compressor_data_size);
if (!m_file.Seek(data_offset, SEEK_SET) || !m_file.ReadBytes(compressed_data.data(), data_size)) break;
return false;
std::unique_ptr<Decompressor> decompressor;
switch (m_compression_type)
{
case CompressionType::Bzip2:
decompressor = std::make_unique<Bzip2Decompressor>();
break;
case CompressionType::LZMA:
decompressor = std::make_unique<LZMADecompressor>(false, m_header_2.compressor_data,
m_header_2.compressor_data_size);
break;
case CompressionType::LZMA2:
decompressor = std::make_unique<LZMADecompressor>(true, m_header_2.compressor_data,
m_header_2.compressor_data_size);
break;
}
if (!decompressor->Start(compressed_data.data(), compressed_data.size()))
return false;
if (exception_list)
{
u16 exceptions;
if (decompressor->Read(reinterpret_cast<u8*>(&exceptions), sizeof(exceptions)) !=
sizeof(exceptions))
{
return false;
}
std::vector<HashExceptionEntry> exception_entries(Common::swap16(exceptions));
u8* exceptions_data = reinterpret_cast<u8*>(exception_entries.data());
const size_t exceptions_size = exception_entries.size() * sizeof(HashExceptionEntry);
if (decompressor->Read(exceptions_data, exceptions_size) != exceptions_size)
return false;
// TODO: Actually handle the exceptions
}
if (decompressor->Read(out_ptr, decompressed_data_size) != decompressed_data_size)
return false;
if (!decompressor->DoneReading())
return false;
return true;
}
} }
return false; const bool compressed_exception_list = m_compression_type > CompressionType::Purge;
}
bool WIAFileReader::ReadCompressedData(u32 decompressed_data_size, u64 data_offset, u64 data_size, m_cached_chunk = Chunk(&m_file, offset_in_file, compressed_size, decompressed_size,
u64 offset_in_data, u64 size_in_data, u8* out_ptr, exception_list, compressed_exception_list, std::move(decompressor));
bool exception_list) m_cached_chunk_offset = offset_in_file;
{ return m_cached_chunk;
if (m_compression_type == CompressionType::None)
{
if (!m_file.Seek(data_offset, SEEK_SET))
return false;
if (exception_list)
{
const std::optional<u64> exception_list_size = ReadExceptionListFromFile();
if (!exception_list_size)
return false;
data_size -= *exception_list_size;
}
if (!m_file.Seek(offset_in_data, SEEK_CUR) || !m_file.ReadBytes(out_ptr, size_in_data))
return false;
return true;
}
else
{
// TODO: Caching
std::vector<u8> buffer(decompressed_data_size);
if (!ReadCompressedData(decompressed_data_size, data_offset, data_size, buffer.data(),
exception_list))
{
return false;
}
std::memcpy(out_ptr, buffer.data() + offset_in_data, size_in_data);
return true;
}
}
std::optional<u64> WIAFileReader::ReadExceptionListFromFile()
{
u16 exceptions;
if (!m_file.ReadArray(&exceptions, 1))
return std::nullopt;
const u64 exception_list_size = Common::AlignUp(
sizeof(exceptions) + Common::swap16(exceptions) * sizeof(HashExceptionEntry), 4);
if (!m_file.Seek(exception_list_size - sizeof(exceptions), SEEK_CUR))
return std::nullopt;
// TODO: Actually handle the exceptions
return exception_list_size;
} }
std::string WIAFileReader::VersionToString(u32 version) std::string WIAFileReader::VersionToString(u32 version)
@ -469,54 +326,142 @@ std::string WIAFileReader::VersionToString(u32 version)
WIAFileReader::Decompressor::~Decompressor() = default; WIAFileReader::Decompressor::~Decompressor() = default;
// Pass-through "decompressor" used for uncompressed data: moves as many not-yet-consumed
// input bytes as fit into the free space of the output buffer.
//
// in            - buffer holding the input bytes read so far (in.bytes_written of them)
// out           - buffer to append to; out->bytes_written is advanced by the amount copied
// in_bytes_read - number of bytes of in already consumed; advanced by the amount copied
//
// Always returns true. m_done becomes true once the whole input buffer (in.data.size() bytes)
// has been consumed.
bool WIAFileReader::NoneDecompressor::Decompress(const DecompressionBuffer& in,
                                                 DecompressionBuffer* out, size_t* in_bytes_read)
{
  const size_t pending_input = in.bytes_written - *in_bytes_read;
  const size_t free_output = out->data.size() - out->bytes_written;
  const size_t copy_size = std::min(pending_input, free_output);

  const auto* const source = in.data.data() + *in_bytes_read;
  auto* const destination = out->data.data() + out->bytes_written;
  std::memcpy(destination, source, copy_size);

  *in_bytes_read += copy_size;
  out->bytes_written += copy_size;

  m_done = (in.data.size() == *in_bytes_read);
  return true;
}
WIAFileReader::PurgeDecompressor::PurgeDecompressor(u64 decompressed_size)
: m_decompressed_size(decompressed_size)
{
}
// Incrementally decodes "purge" data. The input is a sequence of PurgeSegment headers
// (offset and size, both byte-swapped with Common::swap32), each followed by `size` bytes of
// payload that belong at `offset` in the decompressed data. Any gap before a segment, and the
// space after the last segment up to m_decompressed_size, is filled with zeroes. The last
// sizeof(SHA1) bytes of the input are a hash, which is currently skipped without being checked.
//
// The function is resumable: it can be called repeatedly as more input becomes available in
// `in` or as more room becomes available in `out`; progress is kept in the member variables.
//
// in            - input buffer; in.bytes_written bytes of in.data are valid so far
// out           - output buffer to append to; out->bytes_written is advanced
// in_bytes_read - number of bytes of in consumed so far; advanced as input is consumed
//
// Every return path returns true. m_done is set once all output has been produced and the
// trailing hash has been skipped.
bool WIAFileReader::PurgeDecompressor::Decompress(const DecompressionBuffer& in,
DecompressionBuffer* out, size_t* in_bytes_read)
{
// Keep going while unconsumed input is available and there is something we can do with it:
// either the current segment header is still incomplete, or there is room left in out.
while (!m_done && in.bytes_written != *in_bytes_read &&
(m_segment_bytes_written < sizeof(m_segment) || out->data.size() != out->bytes_written))
{
// We are between segments and the only input left is the trailing hash: emit the final
// run of zeroes (as much of it as fits in out).
if (m_segment_bytes_written == 0 && *in_bytes_read == in.data.size() - sizeof(SHA1))
{
const size_t zeroes_to_write = std::min<size_t>(m_decompressed_size - m_out_bytes_written,
out->data.size() - out->bytes_written);
std::memset(out->data.data() + out->bytes_written, 0, zeroes_to_write);
out->bytes_written += zeroes_to_write;
m_out_bytes_written += zeroes_to_write;
if (m_out_bytes_written == m_decompressed_size)
{
// NOTE(review): the hash is skipped by advancing *in_bytes_read by sizeof(SHA1) without
// checking that all of those bytes are already present (in.bytes_written) - confirm that
// callers have supplied the complete input by the time this point is reached.
*in_bytes_read += sizeof(SHA1);
m_done = true;
// TODO: Check hash
}
return true;
}
// Collect the segment header. It may arrive split across several calls, so it is copied
// piecewise into m_segment.
if (m_segment_bytes_written < sizeof(m_segment))
{
const size_t bytes_to_copy =
std::min(in.bytes_written - *in_bytes_read, sizeof(m_segment) - m_segment_bytes_written);
std::memcpy(reinterpret_cast<u8*>(&m_segment) + m_segment_bytes_written,
in.data.data() + *in_bytes_read, bytes_to_copy);
*in_bytes_read += bytes_to_copy;
m_bytes_read += bytes_to_copy;
m_segment_bytes_written += bytes_to_copy;
}
// Header still incomplete: wait for the caller to provide more input.
if (m_segment_bytes_written < sizeof(m_segment))
return true;
const size_t offset = Common::swap32(m_segment.offset);
const size_t size = Common::swap32(m_segment.size);
// Zero-fill the gap between the data produced so far and the start of this segment.
if (m_out_bytes_written < offset)
{
const size_t zeroes_to_write =
std::min(offset - m_out_bytes_written, out->data.size() - out->bytes_written);
std::memset(out->data.data() + out->bytes_written, 0, zeroes_to_write);
out->bytes_written += zeroes_to_write;
m_out_bytes_written += zeroes_to_write;
}
// Copy the segment's payload, limited by what remains of the segment, by the free space in
// out and by the input that is available right now.
if (m_out_bytes_written >= offset && m_out_bytes_written < offset + size)
{
const size_t bytes_to_copy = std::min(
std::min(offset + size - m_out_bytes_written, out->data.size() - out->bytes_written),
in.bytes_written - *in_bytes_read);
std::memcpy(out->data.data() + out->bytes_written, in.data.data() + *in_bytes_read,
bytes_to_copy);
*in_bytes_read += bytes_to_copy;
m_bytes_read += bytes_to_copy;
out->bytes_written += bytes_to_copy;
m_out_bytes_written += bytes_to_copy;
}
// Segment fully emitted: reset the header state so the next iteration reads a new header.
if (m_out_bytes_written >= offset + size)
m_segment_bytes_written = 0;
}
return true;
}
WIAFileReader::Bzip2Decompressor::~Bzip2Decompressor() WIAFileReader::Bzip2Decompressor::~Bzip2Decompressor()
{
End();
}
bool WIAFileReader::Bzip2Decompressor::Start(const u8* in_ptr, u64 size)
{ {
if (m_started) if (m_started)
return false; BZ2_bzDecompressEnd(&m_stream);
m_stream.bzalloc = nullptr;
m_stream.bzfree = nullptr;
m_stream.opaque = nullptr;
m_started = BZ2_bzDecompressInit(&m_stream, 0, 0) == BZ_OK;
m_stream.next_in = reinterpret_cast<char*>(const_cast<u8*>(in_ptr));
m_stream.avail_in = size;
return m_started;
} }
u64 WIAFileReader::Bzip2Decompressor::Read(u8* out_ptr, u64 size) bool WIAFileReader::Bzip2Decompressor::Decompress(const DecompressionBuffer& in,
DecompressionBuffer* out, size_t* in_bytes_read)
{ {
if (!m_started || m_error_occurred || m_stream.avail_in == 0) if (!m_started)
return 0; {
if (BZ2_bzDecompressInit(&m_stream, 0, 0) != BZ_OK)
return false;
m_stream.next_out = reinterpret_cast<char*>(out_ptr); m_started = true;
m_stream.avail_out = size; }
constexpr auto clamped_cast = [](size_t x) {
return static_cast<unsigned int>(
std::min<size_t>(std::numeric_limits<unsigned int>().max(), x));
};
char* const in_ptr = reinterpret_cast<char*>(const_cast<u8*>(in.data.data() + *in_bytes_read));
m_stream.next_in = in_ptr;
m_stream.avail_in = clamped_cast(in.bytes_written - *in_bytes_read);
char* const out_ptr = reinterpret_cast<char*>(out->data.data() + out->bytes_written);
m_stream.next_out = out_ptr;
m_stream.avail_out = clamped_cast(out->data.size() - out->bytes_written);
const int result = BZ2_bzDecompress(&m_stream); const int result = BZ2_bzDecompress(&m_stream);
m_error_occurred = result != BZ_OK && result != BZ_STREAM_END;
return m_error_occurred ? 0 : m_stream.next_out - reinterpret_cast<char*>(out_ptr); *in_bytes_read += m_stream.next_in - in_ptr;
} out->bytes_written += m_stream.next_out - out_ptr;
bool WIAFileReader::Bzip2Decompressor::DoneReading() const m_done = result == BZ_STREAM_END;
{ return result == BZ_OK || result == BZ_STREAM_END;
return m_started && !m_error_occurred && m_stream.avail_in == 0;
}
void WIAFileReader::Bzip2Decompressor::End()
{
if (m_started && !m_ended)
{
BZ2_bzDecompressEnd(&m_stream);
m_ended = true;
}
} }
WIAFileReader::LZMADecompressor::LZMADecompressor(bool lzma2, const u8* filter_options, WIAFileReader::LZMADecompressor::LZMADecompressor(bool lzma2, const u8* filter_options,
@ -564,48 +509,162 @@ WIAFileReader::LZMADecompressor::LZMADecompressor(bool lzma2, const u8* filter_o
WIAFileReader::LZMADecompressor::~LZMADecompressor() WIAFileReader::LZMADecompressor::~LZMADecompressor()
{ {
End(); if (m_started)
lzma_end(&m_stream);
} }
bool WIAFileReader::LZMADecompressor::Start(const u8* in_ptr, u64 size) bool WIAFileReader::LZMADecompressor::Decompress(const DecompressionBuffer& in,
DecompressionBuffer* out, size_t* in_bytes_read)
{ {
if (m_started || m_error_occurred) if (!m_started)
return false; {
if (m_error_occurred || lzma_raw_decoder(&m_stream, m_filters) != LZMA_OK)
return false;
m_started = lzma_raw_decoder(&m_stream, m_filters) == LZMA_OK; m_started = true;
}
const u8* const in_ptr = in.data.data() + *in_bytes_read;
m_stream.next_in = in_ptr; m_stream.next_in = in_ptr;
m_stream.avail_in = size; m_stream.avail_in = in.bytes_written - *in_bytes_read;
return m_started;
}
u64 WIAFileReader::LZMADecompressor::Read(u8* out_ptr, u64 size)
{
if (!m_started || m_error_occurred || m_stream.avail_in == 0)
return 0;
u8* const out_ptr = out->data.data() + out->bytes_written;
m_stream.next_out = out_ptr; m_stream.next_out = out_ptr;
m_stream.avail_out = size; m_stream.avail_out = out->data.size() - out->bytes_written;
const lzma_ret result = lzma_code(&m_stream, LZMA_RUN); const lzma_ret result = lzma_code(&m_stream, LZMA_RUN);
m_error_occurred = result != LZMA_OK && result != LZMA_STREAM_END;
return m_error_occurred ? 0 : m_stream.next_out - out_ptr; *in_bytes_read += m_stream.next_in - in_ptr;
out->bytes_written += m_stream.next_out - out_ptr;
m_done = result == LZMA_STREAM_END;
return result == LZMA_OK || result == LZMA_STREAM_END;
} }
bool WIAFileReader::LZMADecompressor::DoneReading() const WIAFileReader::Chunk::Chunk() = default;
WIAFileReader::Chunk::Chunk(File::IOFile* file, u64 offset_in_file, u64 compressed_size,
u64 decompressed_size, bool exception_list,
bool compressed_exception_list,
std::unique_ptr<Decompressor> decompressor)
: m_file(file), m_offset_in_file(offset_in_file), m_exception_list(exception_list),
m_compressed_exception_list(compressed_exception_list),
m_decompressor(std::move(decompressor))
{ {
return m_started && !m_error_occurred && m_stream.avail_in == 0; m_in.data.resize(compressed_size);
m_out.data.resize(decompressed_size);
} }
void WIAFileReader::LZMADecompressor::End() bool WIAFileReader::Chunk::Read(u64 offset, u64 size, u8* out_ptr)
{ {
if (m_started && !m_ended) if (offset + size > m_out.data.size() || !m_decompressor || !m_file)
return false;
if (m_exception_list && !m_compressed_exception_list)
{ {
lzma_end(&m_stream); u16 exceptions;
m_ended = true; if (!m_file->Seek(m_offset_in_file, SEEK_SET) || !m_file->ReadArray(&exceptions, 1))
return false;
m_exceptions.data.resize(Common::swap16(exceptions) * sizeof(HashExceptionEntry));
if (!m_file->ReadBytes(m_exceptions.data.data(), m_exceptions.data.size()))
return false;
m_exceptions.bytes_written = m_exceptions.data.size();
m_in.bytes_written = Common::AlignUp(sizeof(exceptions) + m_exceptions.data.size(), 4);
m_in_bytes_read = m_in.bytes_written;
m_exception_list = false;
// TODO: Actually handle the exceptions
} }
while (offset + size > m_out.bytes_written)
{
u64 bytes_to_read;
if (offset + size == m_out.data.size())
{
// Read all the remaining data.
bytes_to_read = m_in.data.size() - m_in.bytes_written;
}
else
{
// Pick a suitable amount of compressed data to read. The std::min line has to
// be as it is, but the rest is a bit arbitrary and can be changed if desired.
// The compressed data is probably not much bigger than the decompressed data.
// Add a few bytes for possible compression overhead and for the exception list.
bytes_to_read = offset + size - m_out.bytes_written + 0x100;
// Align the access in an attempt to gain speed. But we don't actually know the
// block size of the underlying storage device, so we just use the Wii block size.
bytes_to_read =
Common::AlignUp(bytes_to_read + m_offset_in_file, VolumeWii::BLOCK_TOTAL_SIZE) -
m_offset_in_file;
// Ensure we don't read too much.
bytes_to_read = std::min<u64>(m_in.data.size() - m_in.bytes_written, bytes_to_read);
}
if (bytes_to_read == 0)
{
// Compressed size is larger than expected or decompressed size is smaller than expected
return false;
}
if (!m_file->Seek(m_offset_in_file, SEEK_SET))
return false;
if (!m_file->ReadBytes(m_in.data.data() + m_in.bytes_written, bytes_to_read))
return false;
m_offset_in_file += bytes_to_read;
m_in.bytes_written += bytes_to_read;
if (m_exception_list)
{
if (m_exceptions.data.empty())
m_exceptions.data.resize(sizeof(u16));
if (m_exceptions.data.size() == sizeof(u16))
{
if (!m_decompressor->Decompress(m_in, &m_exceptions, &m_in_bytes_read))
return false;
if (m_exceptions.bytes_written == m_exceptions.data.size())
{
u16 exceptions;
std::memcpy(&exceptions, m_exceptions.data.data(), sizeof(exceptions));
m_exceptions.data.resize(Common::swap16(exceptions) * sizeof(HashExceptionEntry));
m_exceptions.bytes_written = 0;
}
}
if (m_exceptions.data.size() != sizeof(u16))
{
if (!m_decompressor->Decompress(m_in, &m_exceptions, &m_in_bytes_read))
return false;
if (m_exceptions.bytes_written == m_exceptions.data.size())
m_exception_list = false;
// TODO: Actually handle the exceptions
}
}
if (!m_exception_list)
{
if (!m_decompressor->Decompress(m_in, &m_out, &m_in_bytes_read))
return false;
if (m_out.bytes_written == m_out.data.size() && !m_decompressor->Done())
return false; // Decompressed size is larger than expected
if (m_decompressor->Done() && m_in_bytes_read != m_in.data.size())
return false; // Compressed size is smaller than expected
}
}
std::memcpy(out_ptr, m_out.data.data() + offset, size);
return true;
} }
} // namespace DiscIO } // namespace DiscIO

View File

@ -5,6 +5,7 @@
#pragma once #pragma once
#include <array> #include <array>
#include <limits>
#include <memory> #include <memory>
#include <utility> #include <utility>
@ -41,25 +42,18 @@ public:
bool ReadWiiDecrypted(u64 offset, u64 size, u8* out_ptr, u64 partition_data_offset) override; bool ReadWiiDecrypted(u64 offset, u64 size, u8* out_ptr, u64 partition_data_offset) override;
private: private:
explicit WIAFileReader(File::IOFile file, const std::string& path);
bool Initialize(const std::string& path);
bool ReadFromGroups(u64* offset, u64* size, u8** out_ptr, u64 chunk_size, u32 sector_size,
u64 data_offset, u64 data_size, u32 group_index, u32 number_of_groups,
bool exception_list);
bool ReadCompressedData(u32 decompressed_data_size, u64 data_offset, u64 data_size, u8* out_ptr,
bool exception_list);
bool ReadCompressedData(u32 decompressed_data_size, u64 data_offset, u64 data_size,
u64 offset_in_data, u64 size_in_data, u8* out_ptr, bool exception_list);
// Returns the number of bytes read
std::optional<u64> ReadExceptionListFromFile();
static std::string VersionToString(u32 version);
using SHA1 = std::array<u8, 20>; using SHA1 = std::array<u8, 20>;
using WiiKey = std::array<u8, 16>; using WiiKey = std::array<u8, 16>;
// Compression method of the data in a WIA file. ReadCompressedData switches on this value to
// pick the Decompressor implementation to use.
// NOTE: The numeric order matters. ReadCompressedData treats every value greater than Purge as
// a method whose exception list is stored inside the compressed stream (and therefore has to be
// obtained through the decompressor instead of being read directly from the file).
enum class CompressionType : u32
{
None = 0,
Purge = 1,
Bzip2 = 2,
LZMA = 3,
LZMA2 = 4,
};
#pragma pack(push, 1) #pragma pack(push, 1)
struct WIAHeader1 struct WIAHeader1
{ {
@ -148,13 +142,10 @@ private:
static_assert(sizeof(PurgeSegment) == 0x08, "Wrong size for WIA purge segment"); static_assert(sizeof(PurgeSegment) == 0x08, "Wrong size for WIA purge segment");
#pragma pack(pop) #pragma pack(pop)
enum class CompressionType : u32 struct DecompressionBuffer
{ {
None = 0, std::vector<u8> data;
Purge = 1, size_t bytes_written = 0;
Bzip2 = 2,
LZMA = 3,
LZMA2 = 4,
}; };
class Decompressor class Decompressor
@ -162,18 +153,36 @@ private:
public: public:
virtual ~Decompressor(); virtual ~Decompressor();
// Specifies the compressed data to read. The data must still be in memory when calling Read. virtual bool Decompress(const DecompressionBuffer& in, DecompressionBuffer* out,
virtual bool Start(const u8* in_ptr, u64 size) = 0; size_t* in_bytes_read) = 0;
virtual bool Done() const { return m_done; };
// Reads the specified number of bytes into out_ptr (or less, if there aren't that many bytes protected:
// to output). Returns the number of bytes read. Start must be called before this. bool m_done = false;
virtual u64 Read(u8* out_ptr, u64 size) = 0; };
// Returns whether every byte of the input data has been read. class NoneDecompressor final : public Decompressor
virtual bool DoneReading() const = 0; {
public:
bool Decompress(const DecompressionBuffer& in, DecompressionBuffer* out,
size_t* in_bytes_read) override;
};
// Will be called automatically upon destruction, but can be called earlier if desired. // This class assumes that more bytes won't be added to in once in.bytes_written == in.data.size()
virtual void End() = 0; class PurgeDecompressor final : public Decompressor
{
public:
PurgeDecompressor(u64 decompressed_size);
bool Decompress(const DecompressionBuffer& in, DecompressionBuffer* out,
size_t* in_bytes_read) override;
private:
PurgeSegment m_segment = {};
size_t m_bytes_read = 0;
size_t m_segment_bytes_written = 0;
size_t m_out_bytes_written = 0;
const u64 m_decompressed_size;
}; };
class Bzip2Decompressor final : public Decompressor class Bzip2Decompressor final : public Decompressor
@ -181,16 +190,12 @@ private:
public: public:
~Bzip2Decompressor(); ~Bzip2Decompressor();
bool Start(const u8* in_ptr, u64 size) override; bool Decompress(const DecompressionBuffer& in, DecompressionBuffer* out,
u64 Read(u8* out_ptr, u64 size) override; size_t* in_bytes_read) override;
bool DoneReading() const override;
void End() override;
private: private:
bz_stream m_stream; bz_stream m_stream = {};
bool m_started = false; bool m_started = false;
bool m_ended = false;
bool m_error_occurred = false;
}; };
class LZMADecompressor final : public Decompressor class LZMADecompressor final : public Decompressor
@ -199,24 +204,63 @@ private:
LZMADecompressor(bool lzma2, const u8* filter_options, size_t filter_options_size); LZMADecompressor(bool lzma2, const u8* filter_options, size_t filter_options_size);
~LZMADecompressor(); ~LZMADecompressor();
bool Start(const u8* in_ptr, u64 size) override; bool Decompress(const DecompressionBuffer& in, DecompressionBuffer* out,
u64 Read(u8* out_ptr, u64 size) override; size_t* in_bytes_read) override;
bool DoneReading() const override;
void End() override;
private: private:
lzma_stream m_stream = LZMA_STREAM_INIT; lzma_stream m_stream = LZMA_STREAM_INIT;
lzma_options_lzma m_options = {}; lzma_options_lzma m_options = {};
lzma_filter m_filters[2]; lzma_filter m_filters[2];
bool m_started = false; bool m_started = false;
bool m_ended = false;
bool m_error_occurred = false; bool m_error_occurred = false;
}; };
// A region of the file holding (possibly compressed) data, decompressed on demand. Read only
// reads and decompresses as much input as is needed to satisfy the requested range, and keeps
// what has been decompressed so far so that later calls can reuse it.
class Chunk
{
public:
// Creates an empty chunk. Read always fails on such a chunk because it has no file and no
// decompressor.
Chunk();
// file                      - file to read from; stored as a non-owning pointer
// offset_in_file            - position in the file where the chunk's data starts
// compressed_size           - size of the chunk's data in the file (size of the input buffer)
// decompressed_size         - size of the data after decompression (size of the output buffer)
// exception_list            - whether the data is preceded by a hash exception list
// compressed_exception_list - whether that exception list is part of the compressed stream
//                             rather than stored uncompressed in front of it
// decompressor              - decompressor to run the data through
Chunk(File::IOFile* file, u64 offset_in_file, u64 compressed_size, u64 decompressed_size,
bool exception_list, bool compressed_exception_list,
std::unique_ptr<Decompressor> decompressor);
// Copies size bytes, starting at offset within the decompressed data, to out_ptr. Returns false
// if the range lies outside the decompressed data, if the chunk is empty, or if reading or
// decompressing fails.
bool Read(u64 offset, u64 size, u8* out_ptr);
// Fills the vector's existing elements with data from the start of the chunk. The vector must
// already have been resized to the desired number of elements; it is not resized here.
template <typename T>
bool ReadAll(std::vector<T>* vector)
{
return Read(0, vector->size() * sizeof(T), reinterpret_cast<u8*>(vector->data()));
}
private:
// Data read from the file so far
DecompressionBuffer m_in;
// Data decompressed so far
DecompressionBuffer m_out;
// Hash exception list data (not yet acted upon, see the TODOs in Read)
DecompressionBuffer m_exceptions;
// Number of bytes of m_in that have been consumed so far
size_t m_in_bytes_read = 0;
std::unique_ptr<Decompressor> m_decompressor = nullptr;
// Not owned by the chunk
File::IOFile* m_file = nullptr;
// File position used for the next read; advanced by Read as input is read
u64 m_offset_in_file = 0;
// True while an exception list still has to be read; cleared by Read once it has been read
bool m_exception_list = false;
bool m_compressed_exception_list = false;
};
explicit WIAFileReader(File::IOFile file, const std::string& path);
bool Initialize(const std::string& path);
bool ReadFromGroups(u64* offset, u64* size, u8** out_ptr, u64 chunk_size, u32 sector_size,
u64 data_offset, u64 data_size, u32 group_index, u32 number_of_groups,
bool exception_list);
// Returns the chunk for the data located at offset_in_file. If offset_in_file equals the offset
// of the cached chunk, the cached chunk is returned unchanged (the other arguments are not
// compared); otherwise the cached chunk is replaced by a newly created one. Creating the chunk
// does not read any data yet - that happens when Read/ReadAll is called on the result.
// The returned reference refers to m_cached_chunk, so it only refers to this data until the
// next call with a different offset replaces the cached chunk.
Chunk& ReadCompressedData(u64 offset_in_file, u64 compressed_size, u64 decompressed_size,
bool exception_list);
static std::string VersionToString(u32 version);
bool m_valid; bool m_valid;
CompressionType m_compression_type; CompressionType m_compression_type;
File::IOFile m_file; File::IOFile m_file;
Chunk m_cached_chunk;
u64 m_cached_chunk_offset = std::numeric_limits<u64>::max();
WIAHeader1 m_header_1; WIAHeader1 m_header_1;
WIAHeader2 m_header_2; WIAHeader2 m_header_2;