From d5fb5645fcdc838c712d9eb0308f94a625d99fda Mon Sep 17 00:00:00 2001 From: Stenzek Date: Mon, 5 Feb 2024 14:26:42 +1000 Subject: [PATCH] ByteStream: Move zstd wrapper to util project Removes zstd dependency from common, fixes updater running on Mac. --- src/common/CMakeLists.txt | 2 +- src/common/byte_stream.cpp | 312 +-------------------------------- src/common/byte_stream.h | 6 +- src/util/CMakeLists.txt | 1 + src/util/util.vcxproj | 1 + src/util/util.vcxproj.filters | 1 + src/util/zstd_byte_stream.cpp | 317 ++++++++++++++++++++++++++++++++++ 7 files changed, 327 insertions(+), 313 deletions(-) create mode 100644 src/util/zstd_byte_stream.cpp diff --git a/src/common/CMakeLists.txt b/src/common/CMakeLists.txt index 35f75a128..a2143bc23 100644 --- a/src/common/CMakeLists.txt +++ b/src/common/CMakeLists.txt @@ -61,7 +61,7 @@ add_library(common target_include_directories(common PRIVATE "${CMAKE_CURRENT_SOURCE_DIR}/..") target_include_directories(common PUBLIC "${CMAKE_CURRENT_SOURCE_DIR}/..") target_link_libraries(common PUBLIC fmt Threads::Threads fast_float) -target_link_libraries(common PRIVATE stb ZLIB::ZLIB minizip Zstd::Zstd "${CMAKE_DL_LIBS}") +target_link_libraries(common PRIVATE stb ZLIB::ZLIB minizip "${CMAKE_DL_LIBS}") if(WIN32) target_sources(common PRIVATE diff --git a/src/common/byte_stream.cpp b/src/common/byte_stream.cpp index 3b0de49ab..af6610277 100644 --- a/src/common/byte_stream.cpp +++ b/src/common/byte_stream.cpp @@ -1,4 +1,4 @@ -// SPDX-FileCopyrightText: 2019-2022 Connor McLaughlin +// SPDX-FileCopyrightText: 2019-2024 Connor McLaughlin // SPDX-License-Identifier: (GPL-3.0 OR CC-BY-NC-ND-4.0) #include "byte_stream.h" @@ -6,8 +6,7 @@ #include "file_system.h" #include "log.h" #include "string_util.h" -#include "zstd.h" -#include "zstd_errors.h" + #include #include #include @@ -1345,310 +1344,3 @@ bool ByteStream::WriteBinaryToStream(ByteStream* stream, const void* data, size_ return stream->Write2(data, static_cast(data_length)); } - -namespace { -class ZstdCompressStream final : public ByteStream -{ -public: - ZstdCompressStream(ByteStream* dst_stream, int compression_level) : m_dst_stream(dst_stream) - { - m_cstream = ZSTD_createCStream(); - ZSTD_CCtx_setParameter(m_cstream, ZSTD_c_compressionLevel, compression_level); - } - - ~ZstdCompressStream() override - { - if (!m_done) - Compress(ZSTD_e_end); - - ZSTD_freeCStream(m_cstream); - } - - bool ReadByte(u8* pDestByte) override { return false; } - - u32 Read(void* pDestination, u32 ByteCount) override { return 0; } - - bool Read2(void* pDestination, u32 ByteCount, u32* pNumberOfBytesRead = nullptr) override { return false; } - - bool WriteByte(u8 SourceByte) override - { - if (m_input_buffer_wpos == INPUT_BUFFER_SIZE && !Compress(ZSTD_e_continue)) - return false; - - m_input_buffer[m_input_buffer_wpos++] = SourceByte; - return true; - } - - u32 Write(const void* pSource, u32 ByteCount) override - { - u32 remaining = ByteCount; - const u8* read_ptr = static_cast(pSource); - for (;;) - { - const u32 copy_size = std::min(INPUT_BUFFER_SIZE - m_input_buffer_wpos, remaining); - std::memcpy(&m_input_buffer[m_input_buffer_wpos], read_ptr, copy_size); - read_ptr += copy_size; - remaining -= copy_size; - m_input_buffer_wpos += copy_size; - if (remaining == 0 || !Compress(ZSTD_e_continue)) - break; - } - - return ByteCount - remaining; - } - - bool Write2(const void* pSource, u32 ByteCount, u32* pNumberOfBytesWritten = nullptr) override - { - const u32 bytes_written = Write(pSource, ByteCount); - if (pNumberOfBytesWritten) - *pNumberOfBytesWritten = bytes_written; - return (bytes_written == ByteCount); - } - - bool SeekAbsolute(u64 Offset) override { return false; } - - bool SeekRelative(s64 Offset) override { return (Offset == 0); } - - bool SeekToEnd() override { return false; } - - u64 GetPosition() const override { return m_position; } - - u64 GetSize() const override { return 0; } - - bool Flush() override { return Compress(ZSTD_e_flush); } - - bool Discard() override { return true; } - - bool Commit() override { return Compress(ZSTD_e_end); } - -private: - enum : u32 - { - INPUT_BUFFER_SIZE = 131072, - OUTPUT_BUFFER_SIZE = 65536, - }; - - bool Compress(ZSTD_EndDirective action) - { - if (m_errorState || m_done) - return false; - - ZSTD_inBuffer inbuf = {m_input_buffer, m_input_buffer_wpos, 0}; - - for (;;) - { - ZSTD_outBuffer outbuf = {m_output_buffer, OUTPUT_BUFFER_SIZE, 0}; - - const size_t ret = ZSTD_compressStream2(m_cstream, &outbuf, &inbuf, action); - if (ZSTD_isError(ret)) - { - Log_ErrorPrintf("ZSTD_compressStream2() failed: %u (%s)", static_cast(ZSTD_getErrorCode(ret)), - ZSTD_getErrorString(ZSTD_getErrorCode(ret))); - SetErrorState(); - return false; - } - - if (outbuf.pos > 0) - { - if (!m_dst_stream->Write2(m_output_buffer, static_cast(outbuf.pos))) - { - SetErrorState(); - return false; - } - - outbuf.pos = 0; - } - - if (action == ZSTD_e_end) - { - // break when compression output has finished - if (ret == 0) - { - m_done = true; - break; - } - } - else - { - // break when all input data is consumed - if (inbuf.pos == inbuf.size) - break; - } - } - - m_position += m_input_buffer_wpos; - m_input_buffer_wpos = 0; - return true; - } - - ByteStream* m_dst_stream; - ZSTD_CStream* m_cstream = nullptr; - u64 m_position = 0; - u32 m_input_buffer_wpos = 0; - bool m_done = false; - - u8 m_input_buffer[INPUT_BUFFER_SIZE]; - u8 m_output_buffer[OUTPUT_BUFFER_SIZE]; -}; -} // namespace - -std::unique_ptr ByteStream::CreateZstdCompressStream(ByteStream* src_stream, int compression_level) -{ - return std::make_unique(src_stream, compression_level); -} - -namespace { -class ZstdDecompressStream final : public ByteStream -{ -public: - ZstdDecompressStream(ByteStream* src_stream, u32 compressed_size) - : m_src_stream(src_stream), m_bytes_remaining(compressed_size) - { - m_cstream = ZSTD_createDStream(); - m_in_buffer.src = m_input_buffer; - Decompress(); - } - - ~ZstdDecompressStream() override { ZSTD_freeDStream(m_cstream); } - - bool ReadByte(u8* pDestByte) override { return Read(pDestByte, sizeof(u8)) == sizeof(u8); } - - u32 Read(void* pDestination, u32 ByteCount) override - { - u8* write_ptr = static_cast(pDestination); - u32 remaining = ByteCount; - for (;;) - { - const u32 copy_size = std::min(m_output_buffer_wpos - m_output_buffer_rpos, remaining); - std::memcpy(write_ptr, &m_output_buffer[m_output_buffer_rpos], copy_size); - m_output_buffer_rpos += copy_size; - write_ptr += copy_size; - remaining -= copy_size; - if (remaining == 0 || !Decompress()) - break; - } - - return ByteCount - remaining; - } - - bool Read2(void* pDestination, u32 ByteCount, u32* pNumberOfBytesRead = nullptr) override - { - const u32 bytes_read = Read(pDestination, ByteCount); - if (pNumberOfBytesRead) - *pNumberOfBytesRead = bytes_read; - return (bytes_read == ByteCount); - } - - bool WriteByte(u8 SourceByte) override { return false; } - - u32 Write(const void* pSource, u32 ByteCount) override { return 0; } - - bool Write2(const void* pSource, u32 ByteCount, u32* pNumberOfBytesWritten = nullptr) override { return false; } - - bool SeekAbsolute(u64 Offset) override { return false; } - - bool SeekRelative(s64 Offset) override - { - if (Offset < 0) - return false; - else if (Offset == 0) - return true; - - s64 remaining = Offset; - for (;;) - { - const s64 skip = std::min(m_output_buffer_wpos - m_output_buffer_rpos, remaining); - remaining -= skip; - m_output_buffer_rpos += static_cast(skip); - if (remaining == 0) - return true; - else if (!Decompress()) - return false; - } - } - - bool SeekToEnd() override { return false; } - - u64 GetPosition() const override { return 0; } - - u64 GetSize() const override { return 0; } - - bool Flush() override { return true; } - - bool Discard() override { return true; } - - bool Commit() override { return true; } - -private: - enum : u32 - { - INPUT_BUFFER_SIZE = 65536, - OUTPUT_BUFFER_SIZE = 131072, - }; - - bool Decompress() - { - if (m_output_buffer_rpos != m_output_buffer_wpos) - { - const u32 move_size = m_output_buffer_wpos - m_output_buffer_rpos; - std::memmove(&m_output_buffer[0], &m_output_buffer[m_output_buffer_rpos], move_size); - m_output_buffer_rpos = move_size; - m_output_buffer_wpos = move_size; - } - else - { - m_output_buffer_rpos = 0; - m_output_buffer_wpos = 0; - } - - ZSTD_outBuffer outbuf = {m_output_buffer, OUTPUT_BUFFER_SIZE - m_output_buffer_wpos, 0}; - while (outbuf.pos == 0) - { - if (m_in_buffer.pos == m_in_buffer.size && !m_errorState) - { - const u32 requested_size = std::min(m_bytes_remaining, INPUT_BUFFER_SIZE); - const u32 bytes_read = m_src_stream->Read(m_input_buffer, requested_size); - m_in_buffer.size = bytes_read; - m_in_buffer.pos = 0; - m_bytes_remaining -= bytes_read; - if (bytes_read != requested_size || m_bytes_remaining == 0) - { - m_errorState = true; - break; - } - } - - size_t ret = ZSTD_decompressStream(m_cstream, &outbuf, &m_in_buffer); - if (ZSTD_isError(ret)) - { - Log_ErrorPrintf("ZSTD_decompressStream() failed: %u (%s)", static_cast(ZSTD_getErrorCode(ret)), - ZSTD_getErrorString(ZSTD_getErrorCode(ret))); - m_in_buffer.pos = m_in_buffer.size; - m_output_buffer_rpos = 0; - m_output_buffer_wpos = 0; - m_errorState = true; - return false; - } - } - - m_output_buffer_wpos = static_cast(outbuf.pos); - return true; - } - - ByteStream* m_src_stream; - ZSTD_DStream* m_cstream = nullptr; - ZSTD_inBuffer m_in_buffer = {}; - u32 m_output_buffer_rpos = 0; - u32 m_output_buffer_wpos = 0; - u32 m_bytes_remaining; - bool m_errorState = false; - - u8 m_input_buffer[INPUT_BUFFER_SIZE]; - u8 m_output_buffer[OUTPUT_BUFFER_SIZE]; -}; -} // namespace - -std::unique_ptr ByteStream::CreateZstdDecompressStream(ByteStream* src_stream, u32 compressed_size) -{ - return std::make_unique(src_stream, compressed_size); -} diff --git a/src/common/byte_stream.h b/src/common/byte_stream.h index 125888d0e..a9506e95b 100644 --- a/src/common/byte_stream.h +++ b/src/common/byte_stream.h @@ -1,8 +1,10 @@ -// SPDX-FileCopyrightText: 2019-2022 Connor McLaughlin +// SPDX-FileCopyrightText: 2019-2024 Connor McLaughlin // SPDX-License-Identifier: (GPL-3.0 OR CC-BY-NC-ND-4.0) #pragma once + #include "types.h" + #include #include #include @@ -120,7 +122,7 @@ public: // null memory stream static std::unique_ptr CreateNullStream(); - // zstd stream + // zstd stream, actually defined in util/zstd_byte_stream.cpp, to avoid common dependency on libzstd static std::unique_ptr CreateZstdCompressStream(ByteStream* src_stream, int compression_level); static std::unique_ptr CreateZstdDecompressStream(ByteStream* src_stream, u32 compressed_size); diff --git a/src/util/CMakeLists.txt b/src/util/CMakeLists.txt index d7862a42a..86149aa2a 100644 --- a/src/util/CMakeLists.txt +++ b/src/util/CMakeLists.txt @@ -67,6 +67,7 @@ add_library(util wav_writer.h window_info.cpp window_info.h + zstd_byte_stream.cpp ) target_precompile_headers(util PRIVATE "pch.h") diff --git a/src/util/util.vcxproj b/src/util/util.vcxproj index 9ffd36a4a..db2983f77 100644 --- a/src/util/util.vcxproj +++ b/src/util/util.vcxproj @@ -214,6 +214,7 @@ + diff --git a/src/util/util.vcxproj.filters b/src/util/util.vcxproj.filters index ceaa70861..724e927c0 100644 --- a/src/util/util.vcxproj.filters +++ b/src/util/util.vcxproj.filters @@ -153,6 +153,7 @@ + diff --git a/src/util/zstd_byte_stream.cpp b/src/util/zstd_byte_stream.cpp new file mode 100644 index 000000000..122c60765 --- /dev/null +++ b/src/util/zstd_byte_stream.cpp @@ -0,0 +1,317 @@ +// SPDX-FileCopyrightText: 2019-2024 Connor McLaughlin +// SPDX-License-Identifier: (GPL-3.0 OR CC-BY-NC-ND-4.0) + +#include "common/byte_stream.h" +#include "common/log.h" + +#include +#include + +Log_SetChannel(ByteStream); + +namespace { +class ZstdCompressStream final : public ByteStream +{ +public: + ZstdCompressStream(ByteStream* dst_stream, int compression_level) : m_dst_stream(dst_stream) + { + m_cstream = ZSTD_createCStream(); + ZSTD_CCtx_setParameter(m_cstream, ZSTD_c_compressionLevel, compression_level); + } + + ~ZstdCompressStream() override + { + if (!m_done) + Compress(ZSTD_e_end); + + ZSTD_freeCStream(m_cstream); + } + + bool ReadByte(u8* pDestByte) override { return false; } + + u32 Read(void* pDestination, u32 ByteCount) override { return 0; } + + bool Read2(void* pDestination, u32 ByteCount, u32* pNumberOfBytesRead = nullptr) override { return false; } + + bool WriteByte(u8 SourceByte) override + { + if (m_input_buffer_wpos == INPUT_BUFFER_SIZE && !Compress(ZSTD_e_continue)) + return false; + + m_input_buffer[m_input_buffer_wpos++] = SourceByte; + return true; + } + + u32 Write(const void* pSource, u32 ByteCount) override + { + u32 remaining = ByteCount; + const u8* read_ptr = static_cast(pSource); + for (;;) + { + const u32 copy_size = std::min(INPUT_BUFFER_SIZE - m_input_buffer_wpos, remaining); + std::memcpy(&m_input_buffer[m_input_buffer_wpos], read_ptr, copy_size); + read_ptr += copy_size; + remaining -= copy_size; + m_input_buffer_wpos += copy_size; + if (remaining == 0 || !Compress(ZSTD_e_continue)) + break; + } + + return ByteCount - remaining; + } + + bool Write2(const void* pSource, u32 ByteCount, u32* pNumberOfBytesWritten = nullptr) override + { + const u32 bytes_written = Write(pSource, ByteCount); + if (pNumberOfBytesWritten) + *pNumberOfBytesWritten = bytes_written; + return (bytes_written == ByteCount); + } + + bool SeekAbsolute(u64 Offset) override { return false; } + + bool SeekRelative(s64 Offset) override { return (Offset == 0); } + + bool SeekToEnd() override { return false; } + + u64 GetPosition() const override { return m_position; } + + u64 GetSize() const override { return 0; } + + bool Flush() override { return Compress(ZSTD_e_flush); } + + bool Discard() override { return true; } + + bool Commit() override { return Compress(ZSTD_e_end); } + +private: + enum : u32 + { + INPUT_BUFFER_SIZE = 131072, + OUTPUT_BUFFER_SIZE = 65536, + }; + + bool Compress(ZSTD_EndDirective action) + { + if (m_errorState || m_done) + return false; + + ZSTD_inBuffer inbuf = {m_input_buffer, m_input_buffer_wpos, 0}; + + for (;;) + { + ZSTD_outBuffer outbuf = {m_output_buffer, OUTPUT_BUFFER_SIZE, 0}; + + const size_t ret = ZSTD_compressStream2(m_cstream, &outbuf, &inbuf, action); + if (ZSTD_isError(ret)) + { + Log_ErrorPrintf("ZSTD_compressStream2() failed: %u (%s)", static_cast(ZSTD_getErrorCode(ret)), + ZSTD_getErrorString(ZSTD_getErrorCode(ret))); + SetErrorState(); + return false; + } + + if (outbuf.pos > 0) + { + if (!m_dst_stream->Write2(m_output_buffer, static_cast(outbuf.pos))) + { + SetErrorState(); + return false; + } + + outbuf.pos = 0; + } + + if (action == ZSTD_e_end) + { + // break when compression output has finished + if (ret == 0) + { + m_done = true; + break; + } + } + else + { + // break when all input data is consumed + if (inbuf.pos == inbuf.size) + break; + } + } + + m_position += m_input_buffer_wpos; + m_input_buffer_wpos = 0; + return true; + } + + ByteStream* m_dst_stream; + ZSTD_CStream* m_cstream = nullptr; + u64 m_position = 0; + u32 m_input_buffer_wpos = 0; + bool m_done = false; + + u8 m_input_buffer[INPUT_BUFFER_SIZE]; + u8 m_output_buffer[OUTPUT_BUFFER_SIZE]; +}; +} // namespace + +std::unique_ptr ByteStream::CreateZstdCompressStream(ByteStream* src_stream, int compression_level) +{ + return std::make_unique(src_stream, compression_level); +} + +namespace { +class ZstdDecompressStream final : public ByteStream +{ +public: + ZstdDecompressStream(ByteStream* src_stream, u32 compressed_size) + : m_src_stream(src_stream), m_bytes_remaining(compressed_size) + { + m_cstream = ZSTD_createDStream(); + m_in_buffer.src = m_input_buffer; + Decompress(); + } + + ~ZstdDecompressStream() override { ZSTD_freeDStream(m_cstream); } + + bool ReadByte(u8* pDestByte) override { return Read(pDestByte, sizeof(u8)) == sizeof(u8); } + + u32 Read(void* pDestination, u32 ByteCount) override + { + u8* write_ptr = static_cast(pDestination); + u32 remaining = ByteCount; + for (;;) + { + const u32 copy_size = std::min(m_output_buffer_wpos - m_output_buffer_rpos, remaining); + std::memcpy(write_ptr, &m_output_buffer[m_output_buffer_rpos], copy_size); + m_output_buffer_rpos += copy_size; + write_ptr += copy_size; + remaining -= copy_size; + if (remaining == 0 || !Decompress()) + break; + } + + return ByteCount - remaining; + } + + bool Read2(void* pDestination, u32 ByteCount, u32* pNumberOfBytesRead = nullptr) override + { + const u32 bytes_read = Read(pDestination, ByteCount); + if (pNumberOfBytesRead) + *pNumberOfBytesRead = bytes_read; + return (bytes_read == ByteCount); + } + + bool WriteByte(u8 SourceByte) override { return false; } + + u32 Write(const void* pSource, u32 ByteCount) override { return 0; } + + bool Write2(const void* pSource, u32 ByteCount, u32* pNumberOfBytesWritten = nullptr) override { return false; } + + bool SeekAbsolute(u64 Offset) override { return false; } + + bool SeekRelative(s64 Offset) override + { + if (Offset < 0) + return false; + else if (Offset == 0) + return true; + + s64 remaining = Offset; + for (;;) + { + const s64 skip = std::min(m_output_buffer_wpos - m_output_buffer_rpos, remaining); + remaining -= skip; + m_output_buffer_rpos += static_cast(skip); + if (remaining == 0) + return true; + else if (!Decompress()) + return false; + } + } + + bool SeekToEnd() override { return false; } + + u64 GetPosition() const override { return 0; } + + u64 GetSize() const override { return 0; } + + bool Flush() override { return true; } + + bool Discard() override { return true; } + + bool Commit() override { return true; } + +private: + enum : u32 + { + INPUT_BUFFER_SIZE = 65536, + OUTPUT_BUFFER_SIZE = 131072, + }; + + bool Decompress() + { + if (m_output_buffer_rpos != m_output_buffer_wpos) + { + const u32 move_size = m_output_buffer_wpos - m_output_buffer_rpos; + std::memmove(&m_output_buffer[0], &m_output_buffer[m_output_buffer_rpos], move_size); + m_output_buffer_rpos = move_size; + m_output_buffer_wpos = move_size; + } + else + { + m_output_buffer_rpos = 0; + m_output_buffer_wpos = 0; + } + + ZSTD_outBuffer outbuf = {m_output_buffer, OUTPUT_BUFFER_SIZE - m_output_buffer_wpos, 0}; + while (outbuf.pos == 0) + { + if (m_in_buffer.pos == m_in_buffer.size && !m_errorState) + { + const u32 requested_size = std::min(m_bytes_remaining, INPUT_BUFFER_SIZE); + const u32 bytes_read = m_src_stream->Read(m_input_buffer, requested_size); + m_in_buffer.size = bytes_read; + m_in_buffer.pos = 0; + m_bytes_remaining -= bytes_read; + if (bytes_read != requested_size || m_bytes_remaining == 0) + { + m_errorState = true; + break; + } + } + + size_t ret = ZSTD_decompressStream(m_cstream, &outbuf, &m_in_buffer); + if (ZSTD_isError(ret)) + { + Log_ErrorPrintf("ZSTD_decompressStream() failed: %u (%s)", static_cast(ZSTD_getErrorCode(ret)), + ZSTD_getErrorString(ZSTD_getErrorCode(ret))); + m_in_buffer.pos = m_in_buffer.size; + m_output_buffer_rpos = 0; + m_output_buffer_wpos = 0; + m_errorState = true; + return false; + } + } + + m_output_buffer_wpos = static_cast(outbuf.pos); + return true; + } + + ByteStream* m_src_stream; + ZSTD_DStream* m_cstream = nullptr; + ZSTD_inBuffer m_in_buffer = {}; + u32 m_output_buffer_rpos = 0; + u32 m_output_buffer_wpos = 0; + u32 m_bytes_remaining; + bool m_errorState = false; + + u8 m_input_buffer[INPUT_BUFFER_SIZE]; + u8 m_output_buffer[OUTPUT_BUFFER_SIZE]; +}; +} // namespace + +std::unique_ptr ByteStream::CreateZstdDecompressStream(ByteStream* src_stream, u32 compressed_size) +{ + return std::make_unique(src_stream, compressed_size); +}