From 1eaa9380dd64ec5f501bcd2437cb65309bd81154 Mon Sep 17 00:00:00 2001 From: sowens99 Date: Mon, 2 Oct 2023 00:56:35 -0400 Subject: [PATCH] Savestates: Use LZ4 algorithm for faster decompression --- Source/Core/Core/State.cpp | 456 ++++++++++++++++++++++++++----------- Source/Core/Core/State.h | 61 ++++- 2 files changed, 370 insertions(+), 147 deletions(-) diff --git a/Source/Core/Core/State.cpp b/Source/Core/Core/State.cpp index c298dd121e..ce0c636f66 100644 --- a/Source/Core/Core/State.cpp +++ b/Source/Core/Core/State.cpp @@ -16,6 +16,7 @@ #include +#include #include #include "Common/ChunkFile.h" @@ -60,11 +61,6 @@ static const u32 OUT_LEN = IN_LEN + (IN_LEN / 16) + 64 + 3; static unsigned char __LZO_MMODEL out[OUT_LEN]; -#define HEAP_ALLOC(var, size) \ - lzo_align_t __LZO_MMODEL var[((size) + (sizeof(lzo_align_t) - 1)) / sizeof(lzo_align_t)] - -static HEAP_ALLOC(wrkmem, LZO1X_1_MEM_COMPRESS); - static AfterLoadCallbackFunc s_on_after_load_callback; // Temporary undo state buffer @@ -96,7 +92,15 @@ static size_t s_state_writes_in_queue; static std::condition_variable s_state_write_queue_is_empty; // Don't forget to increase this after doing changes on the savestate system -constexpr u32 STATE_VERSION = 162; // Last changed in PR 11767 +constexpr u32 STATE_VERSION = 163; // Last changed in PR 12217 + +// Increase this if the StateExtendedHeader definition changes +constexpr u32 EXTENDED_HEADER_VERSION = 1; // Last changed in PR 12217 + +// Change this if we ever need to store more data in the extended header +constexpr u32 COMPRESSED_DATA_OFFSET = 0; + +constexpr u32 COOKIE_BASE = 0xBAADBABE; // Maps savestate versions to Dolphin versions. // Versions after 42 don't need to be added to this list, @@ -128,60 +132,8 @@ void EnableCompression(bool compression) s_use_compression = compression; } -// Returns true if state version matches current Dolphin state version, false otherwise. -static bool DoStateVersion(PointerWrap& p, std::string* version_created_by) -{ - u32 version = STATE_VERSION; - { - static const u32 COOKIE_BASE = 0xBAADBABE; - u32 cookie = version + COOKIE_BASE; - p.Do(cookie); - version = cookie - COOKIE_BASE; - } - - *version_created_by = Common::GetScmRevStr(); - if (version > 42) - p.Do(*version_created_by); - else - version_created_by->clear(); - - if (version != STATE_VERSION) - { - if (version_created_by->empty() && s_old_versions.count(version)) - { - // The savestate is from an old version that doesn't - // save the Dolphin version number to savestates, but - // by looking up the savestate version number, it is possible - // to know approximately which Dolphin version was used. - - std::pair version_range = s_old_versions.find(version)->second; - std::string oldest_version = version_range.first; - std::string newest_version = version_range.second; - - *version_created_by = "Dolphin " + oldest_version + " - " + newest_version; - } - - return false; - } - - p.DoMarker("Version"); - return true; -} - static void DoState(PointerWrap& p) { - std::string version_created_by; - if (!DoStateVersion(p, &version_created_by)) - { - const std::string message = - version_created_by.empty() ? - "This savestate was created using an incompatible version of Dolphin" : - "This savestate was created using the incompatible version " + version_created_by; - Core::DisplayMessage(message, OSD::Duration::NORMAL); - p.SetMeasureMode(); - return; - } - bool is_wii = SConfig::GetInstance().bWii || SConfig::GetInstance().m_is_mios; const bool is_wii_currently = is_wii; p.Do(is_wii); @@ -341,7 +293,7 @@ static std::map GetSavedStates() { if (ReadHeader(filename, header)) { - double d = GetSystemTimeAsDouble() - header.time; + double d = GetSystemTimeAsDouble() - header.legacy_header.time; // increase time until unique value is obtained while (m.find(d) != m.end()) @@ -354,6 +306,72 @@ static std::map GetSavedStates() return m; } +static void CompressBufferToFile(const u8* raw_buffer, u64 size, File::IOFile& f) +{ + u64 total_bytes_compressed = 0; + + while (true) + { + u64 bytes_left_to_compress = size - total_bytes_compressed; + + int bytes_to_compress = + static_cast(std::min(static_cast(LZ4_MAX_INPUT_SIZE), bytes_left_to_compress)); + int compressed_buffer_size = LZ4_compressBound(bytes_to_compress); + auto compressed_buffer = std::make_unique(compressed_buffer_size); + s32 compressed_len = + LZ4_compress_default(reinterpret_cast(raw_buffer) + total_bytes_compressed, + compressed_buffer.get(), bytes_to_compress, compressed_buffer_size); + + if (compressed_len == 0) + { + PanicAlertFmtT("Internal LZ4 Error - compression failed"); + break; + } + + // The size of the data to write is 'compressed_len' + f.WriteArray(&compressed_len, 1); + f.WriteBytes(compressed_buffer.get(), compressed_len); + + total_bytes_compressed += bytes_to_compress; + if (total_bytes_compressed == size) + break; + } +} + +static void CreateExtendedHeader(StateExtendedHeader& extended_header, size_t uncompressed_size) +{ + StateExtendedBaseHeader& base_header = extended_header.base_header; + base_header.header_version = EXTENDED_HEADER_VERSION; + base_header.compression_type = + s_use_compression ? CompressionType::LZ4 : CompressionType::Uncompressed; + base_header.payload_offset = COMPRESSED_DATA_OFFSET; + base_header.uncompressed_size = uncompressed_size; + + // If more fields are added to StateExtendedHeader, set them here. +} + +static void WriteHeadersToFile(size_t uncompressed_size, File::IOFile& f) +{ + StateHeader header{}; + SConfig::GetInstance().GetGameID().copy(header.legacy_header.game_id, + std::size(header.legacy_header.game_id)); + header.legacy_header.time = GetSystemTimeAsDouble(); + + header.version_header.version_cookie = COOKIE_BASE + STATE_VERSION; + header.version_string = Common::GetScmRevStr(); + header.version_header.version_string_length = static_cast(header.version_string.length()); + + StateExtendedHeader extended_header{}; + CreateExtendedHeader(extended_header, uncompressed_size); + + f.WriteArray(&header.legacy_header, 1); + f.WriteArray(&header.version_header, 1); + f.WriteString(header.version_string); + + f.WriteArray(&extended_header.base_header, 1); + // If StateExtendedHeader is amended to include more than the base, add WriteBytes() calls here. +} + static void CompressAndDumpState(CompressAndDumpState_args& save_args) { const u8* const buffer_data = save_args.buffer_vector.data(); @@ -378,48 +396,12 @@ static void CompressAndDumpState(CompressAndDumpState_args& save_args) return; } - // Setting up the header - StateHeader header{}; - SConfig::GetInstance().GetGameID().copy(header.gameID, std::size(header.gameID)); - header.size = s_use_compression ? (u32)buffer_size : 0; - header.time = GetSystemTimeAsDouble(); + WriteHeadersToFile(buffer_size, f); - f.WriteArray(&header, 1); - - if (header.size != 0) // non-zero header size means the state is compressed - { - lzo_uint i = 0; - while (true) - { - lzo_uint32 cur_len = 0; - lzo_uint out_len = 0; - - if ((i + IN_LEN) >= buffer_size) - { - cur_len = (lzo_uint32)(buffer_size - i); - } - else - { - cur_len = IN_LEN; - } - - if (lzo1x_1_compress(buffer_data + i, cur_len, out, &out_len, wrkmem) != LZO_E_OK) - PanicAlertFmtT("Internal LZO Error - compression failed"); - - // The size of the data to write is 'out_len' - f.WriteArray((lzo_uint32*)&out_len, 1); - f.WriteBytes(out, out_len); - - if (cur_len != IN_LEN) - break; - - i += cur_len; - } - } - else // uncompressed - { + if (s_use_compression) + CompressBufferToFile(buffer_data, buffer_size, f); + else f.WriteBytes(buffer_data, buffer_size); - } const std::string last_state_filename = File::GetUserPath(D_STATESAVES_IDX) + "lastState.sav"; const std::string last_state_dtmname = last_state_filename + ".dtm"; @@ -525,13 +507,108 @@ void SaveAs(const std::string& filename, bool wait) true); } +static bool GetVersionFromLZO(StateHeader& header, File::IOFile& f) +{ + // Just read the first block, since it will contain the full revision string + lzo_uint32 cur_len = 0; // size of compressed bytes + lzo_uint new_len = 0; // size of uncompressed bytes + std::vector buffer; + buffer.resize(header.legacy_header.lzo_size); + + if (!f.ReadArray(&cur_len, 1) || !f.ReadBytes(out, cur_len)) + return false; + + const int res = lzo1x_decompress(out, cur_len, buffer.data(), &new_len, nullptr); + if (res != LZO_E_OK) + { + // This doesn't seem to happen anymore. + PanicAlertFmtT("Internal LZO Error - decompression failed ({0}) ({1}) \n" + "Unable to retrieve outdated savestate version info.", + res, new_len); + return false; + } + + // Read in cookie and string length + if (buffer.size() >= sizeof(StateHeaderVersion)) + { + memcpy(&header.version_header, buffer.data(), sizeof(StateHeaderVersion)); + } + else + { + PanicAlertFmtT("Internal LZO Error - failed to parse decompressed version cookie and version " + "string length ({0})", + buffer.size()); + return false; + } + + // Read in the string + if (buffer.size() >= sizeof(StateHeaderVersion) + header.version_header.version_string_length) + { + auto version_buffer = std::make_unique(header.version_header.version_string_length); + memcpy(version_buffer.get(), buffer.data() + sizeof(StateHeaderVersion), + header.version_header.version_string_length); + header.version_string = + std::string(version_buffer.get(), header.version_header.version_string_length); + } + else + { + PanicAlertFmtT("Internal LZO Error - failed to parse decompressed version string ({0} / {1})", + header.version_header.version_string_length, buffer.size()); + return false; + } + + return true; +} + +static bool ReadStateHeaderFromFile(StateHeader& header, File::IOFile& f) +{ + if (!f.IsOpen()) + { + Core::DisplayMessage("State not found", 2000); + return false; + } + + if (!f.ReadArray(&header.legacy_header, 1)) + { + Core::DisplayMessage("Failed to read state legacy header", 2000); + return false; + } + + if (header.legacy_header.lzo_size != 0) + { + // Parse out version from legacy LZO compressed states + if (!GetVersionFromLZO(header, f)) + return false; + } + else + { + if (!f.ReadArray(&header.version_header, 1)) + { + Core::DisplayMessage("Failed to read state version header", 2000); + return false; + } + + auto version_buffer = std::make_unique(header.version_header.version_string_length); + if (!f.ReadBytes(version_buffer.get(), header.version_header.version_string_length)) + { + Core::DisplayMessage("Failed to read state version string", 2000); + return false; + } + + header.version_string = + std::string(version_buffer.get(), header.version_header.version_string_length); + } + + return true; +} + bool ReadHeader(const std::string& filename, StateHeader& header) { // ensure that the savestate write thread isn't moving around states while we do this std::lock_guard lk(s_save_thread_mutex); File::IOFile f(filename, "rb"); - return f.ReadArray(&header, 1); + return ReadStateHeaderFromFile(header, f); } std::string GetInfoStringOfSlot(int slot, bool translate) @@ -544,7 +621,7 @@ std::string GetInfoStringOfSlot(int slot, bool translate) if (!ReadHeader(filename, header)) return translate ? Common::GetStringT("Unknown") : "Unknown"; - return SystemTimeAsDoubleToString(header.time); + return SystemTimeAsDoubleToString(header.legacy_header.time); } u64 GetUnixTimeOfSlot(int slot) @@ -554,7 +631,111 @@ u64 GetUnixTimeOfSlot(int slot) return 0; constexpr u64 MS_PER_SEC = 1000; - return static_cast(header.time * MS_PER_SEC) + (DOUBLE_TIME_OFFSET * MS_PER_SEC); + return static_cast(header.legacy_header.time * MS_PER_SEC) + + (DOUBLE_TIME_OFFSET * MS_PER_SEC); +} + +static bool DecompressLZ4(std::vector& raw_buffer, u64 size, File::IOFile& f) +{ + raw_buffer.resize(size); + + u64 total_bytes_read = 0; + while (true) + { + s32 compressed_data_len; + if (!f.ReadArray(&compressed_data_len, 1)) + { + PanicAlertFmt("Could not read state data length"); + return false; + } + + if (compressed_data_len <= 0) + { + PanicAlertFmtT("Internal LZ4 Error - Tried decompressing {0} bytes", compressed_data_len); + return false; + } + + auto compressed_data = std::make_unique(compressed_data_len); + if (!f.ReadBytes(compressed_data.get(), compressed_data_len)) + { + PanicAlertFmt("Could not read state data"); + return false; + } + + u32 max_decompress_size = + static_cast(std::min((u64)LZ4_MAX_INPUT_SIZE, size - total_bytes_read)); + + int bytes_read = LZ4_decompress_safe( + compressed_data.get(), reinterpret_cast(raw_buffer.data()) + total_bytes_read, + compressed_data_len, max_decompress_size); + + if (bytes_read < 0) + { + PanicAlertFmtT("Internal LZ4 Error - decompression failed ({0}, {1}, {2})", bytes_read, + compressed_data_len, max_decompress_size); + return false; + } + + total_bytes_read += static_cast(bytes_read); + + if (total_bytes_read == size) + { + return true; + } + else if (total_bytes_read > size) + { + PanicAlertFmtT("Internal LZ4 Error - payload size mismatch ({0} / {1}))", total_bytes_read, + size); + return false; + } + } +} + +static bool ValidateHeaders(const StateHeader& header) +{ + bool success = true; + + // Game ID + if (strncmp(SConfig::GetInstance().GetGameID().c_str(), header.legacy_header.game_id, 6)) + { + Core::DisplayMessage(fmt::format("State belongs to a different game (ID {})", + std::string_view{header.legacy_header.game_id, + std::size(header.legacy_header.game_id)}), + 2000); + return false; + } + + // Check both the state version and the revision string + std::string current_str = Common::GetScmRevStr(); + std::string loaded_str = header.version_string; + const u32 loaded_version = header.version_header.version_cookie - COOKIE_BASE; + + if (s_old_versions.count(loaded_version)) + { + // This is a REALLY old version, before we started writing the version string to file + success = false; + + std::pair version_range = s_old_versions.find(loaded_version)->second; + std::string oldest_version = version_range.first; + std::string newest_version = version_range.second; + + loaded_str = "Dolphin " + oldest_version + " - " + newest_version; + } + else if (loaded_version != STATE_VERSION) + { + success = false; + } + + if (!success) + { + const std::string message = + loaded_str.empty() ? + "This savestate was created using an incompatible version of Dolphin" : + "This savestate was created using the incompatible version " + loaded_str; + Core::DisplayMessage(message, OSD::Duration::NORMAL); + } + + return success; } static void LoadFileStateData(const std::string& filename, std::vector& ret_data) @@ -578,61 +759,61 @@ static void LoadFileStateData(const std::string& filename, std::vector& ret_ } StateHeader header; - if (!f.ReadArray(&header, 1)) + if (!ReadStateHeaderFromFile(header, f) || !ValidateHeaders(header)) + return; + + StateExtendedHeader extended_header; + if (!f.ReadArray(&extended_header.base_header, 1)) { - Core::DisplayMessage("State not found", 2000); + PanicAlertFmt("Unable to read state header"); return; } + // If StateExtendedHeader is amended to include more than the base, add ReadBytes() calls here. - if (strncmp(SConfig::GetInstance().GetGameID().c_str(), header.gameID, 6)) + if (extended_header.base_header.header_version != EXTENDED_HEADER_VERSION) { - Core::DisplayMessage(fmt::format("State belongs to a different game (ID {})", - std::string_view{header.gameID, std::size(header.gameID)}), - 2000); + PanicAlertFmt("State header corrupted"); return; } std::vector buffer; - if (header.size != 0) // non-zero size means the state is compressed + switch (extended_header.base_header.compression_type) + { + case CompressionType::LZ4: { Core::DisplayMessage("Decompressing State...", 500); + if (!DecompressLZ4(buffer, extended_header.base_header.uncompressed_size, f)) + return; - buffer.resize(header.size); - - lzo_uint i = 0; - while (true) - { - lzo_uint32 cur_len = 0; // number of bytes to read - lzo_uint new_len = 0; // number of bytes to write - - if (!f.ReadArray(&cur_len, 1)) - break; - - f.ReadBytes(out, cur_len); - const int res = lzo1x_decompress(out, cur_len, &buffer[i], &new_len, nullptr); - if (res != LZO_E_OK) - { - // This doesn't seem to happen anymore. - PanicAlertFmtT("Internal LZO Error - decompression failed ({0}) ({1}, {2}) \n" - "Try loading the state again", - res, i, new_len); - return; - } - - i += new_len; - } + break; } - else // uncompressed + case CompressionType::Uncompressed: { - const auto size = static_cast(f.GetSize() - sizeof(StateHeader)); + u64 header_len = sizeof(StateHeaderLegacy) + sizeof(StateHeaderVersion) + + header.version_header.version_string_length + sizeof(StateExtendedBaseHeader) + + extended_header.base_header.payload_offset; + + u64 file_size = f.GetSize(); + if (file_size < header_len) + { + PanicAlertFmt("State header length corrupted"); + return; + } + + const auto size = static_cast(file_size - header_len); buffer.resize(size); - if (!f.ReadBytes(&buffer[0], size)) + if (!f.ReadBytes(buffer.data(), size)) { PanicAlertFmt("Error reading bytes: {0}", size); return; } + break; + } + default: + PanicAlertFmt("Unknown compression type {0}", extended_header.base_header.compression_type); + return; } // all good @@ -721,9 +902,6 @@ void SetOnAfterLoadCallback(AfterLoadCallbackFunc callback) void Init() { - if (lzo_init() != LZO_E_OK) - PanicAlertFmtT("Internal LZO Error - lzo_init() failed"); - s_save_thread.Reset("Savestate Worker", [](CompressAndDumpState_args args) { CompressAndDumpState(args); diff --git a/Source/Core/Core/State.h b/Source/Core/Core/State.h index 7aa75b2a75..7d592b2d88 100644 --- a/Source/Core/Core/State.h +++ b/Source/Core/Core/State.h @@ -8,6 +8,7 @@ #include #include #include +#include #include #include "Common/CommonTypes.h" @@ -17,18 +18,62 @@ namespace State // number of states static const u32 NUM_STATES = 10; -struct StateHeader +struct StateHeaderLegacy { - char gameID[6]; - u16 reserved1; - u32 size; - u32 reserved2; + char game_id[6]; + char reserved1[2]; + u32 lzo_size = 0; // Must be zero for new states. Used to support legacy decompression algorithm. + char reserved2[4]; double time; }; -constexpr size_t STATE_HEADER_SIZE = sizeof(StateHeader); +constexpr size_t STATE_HEADER_SIZE = sizeof(StateHeaderLegacy); static_assert(STATE_HEADER_SIZE == 24); -static_assert(offsetof(StateHeader, size) == 8); -static_assert(offsetof(StateHeader, time) == 16); +static_assert(offsetof(StateHeaderLegacy, lzo_size) == 8); +static_assert(offsetof(StateHeaderLegacy, time) == 16); +static_assert(std::is_trivially_copyable_v); + +struct StateHeaderVersion +{ + u32 version_cookie; + u32 version_string_length; +}; +static_assert(std::is_trivially_copyable_v); + +struct StateHeader +{ + StateHeaderLegacy legacy_header; + StateHeaderVersion version_header; + std::string version_string; +}; + +enum CompressionType : u16 +{ + Uncompressed = 0, + LZ4 = 1, + // Add new compression types after this, as the compression type + // is numerically stored in the state file. +}; + +struct StateExtendedBaseHeader +{ + u16 header_version; + u16 compression_type; + u32 payload_offset; + u64 uncompressed_size; +}; +constexpr size_t EXTENDED_BASE_HEADER_SIZE = sizeof(StateExtendedBaseHeader); +static_assert(EXTENDED_BASE_HEADER_SIZE == 16); +static_assert(offsetof(StateExtendedBaseHeader, payload_offset) == 4); +static_assert(offsetof(StateExtendedBaseHeader, uncompressed_size) == 8); +static_assert(std::is_trivially_copyable_v); + +struct StateExtendedHeader +{ + StateExtendedBaseHeader base_header; + // Feel free to add new fields here, adjusting COMPRESSED_DATA_OFFSET accordingly, as well as + // CreateExtendedHeader(). Add the appropriate IOFile read/write calls within LoadFileStateData() + // and WriteHeadersToFile() +}; void Init();