ByteStream: Move zstd wrapper to util project

Removes zstd dependency from common, fixes updater running on Mac.
This commit is contained in:
Stenzek 2024-02-05 14:26:42 +10:00
parent ac1fd7f0cf
commit d5fb5645fc
No known key found for this signature in database
7 changed files with 327 additions and 313 deletions

View File

@ -61,7 +61,7 @@ add_library(common
target_include_directories(common PRIVATE "${CMAKE_CURRENT_SOURCE_DIR}/..") target_include_directories(common PRIVATE "${CMAKE_CURRENT_SOURCE_DIR}/..")
target_include_directories(common PUBLIC "${CMAKE_CURRENT_SOURCE_DIR}/..") target_include_directories(common PUBLIC "${CMAKE_CURRENT_SOURCE_DIR}/..")
target_link_libraries(common PUBLIC fmt Threads::Threads fast_float) target_link_libraries(common PUBLIC fmt Threads::Threads fast_float)
target_link_libraries(common PRIVATE stb ZLIB::ZLIB minizip Zstd::Zstd "${CMAKE_DL_LIBS}") target_link_libraries(common PRIVATE stb ZLIB::ZLIB minizip "${CMAKE_DL_LIBS}")
if(WIN32) if(WIN32)
target_sources(common PRIVATE target_sources(common PRIVATE

View File

@ -1,4 +1,4 @@
// SPDX-FileCopyrightText: 2019-2022 Connor McLaughlin <stenzek@gmail.com> // SPDX-FileCopyrightText: 2019-2024 Connor McLaughlin <stenzek@gmail.com>
// SPDX-License-Identifier: (GPL-3.0 OR CC-BY-NC-ND-4.0) // SPDX-License-Identifier: (GPL-3.0 OR CC-BY-NC-ND-4.0)
#include "byte_stream.h" #include "byte_stream.h"
@ -6,8 +6,7 @@
#include "file_system.h" #include "file_system.h"
#include "log.h" #include "log.h"
#include "string_util.h" #include "string_util.h"
#include "zstd.h"
#include "zstd_errors.h"
#include <algorithm> #include <algorithm>
#include <cerrno> #include <cerrno>
#include <cstdio> #include <cstdio>
@ -1345,310 +1344,3 @@ bool ByteStream::WriteBinaryToStream(ByteStream* stream, const void* data, size_
return stream->Write2(data, static_cast<u32>(data_length)); return stream->Write2(data, static_cast<u32>(data_length));
} }
namespace {
class ZstdCompressStream final : public ByteStream
{
public:
ZstdCompressStream(ByteStream* dst_stream, int compression_level) : m_dst_stream(dst_stream)
{
m_cstream = ZSTD_createCStream();
ZSTD_CCtx_setParameter(m_cstream, ZSTD_c_compressionLevel, compression_level);
}
~ZstdCompressStream() override
{
if (!m_done)
Compress(ZSTD_e_end);
ZSTD_freeCStream(m_cstream);
}
bool ReadByte(u8* pDestByte) override { return false; }
u32 Read(void* pDestination, u32 ByteCount) override { return 0; }
bool Read2(void* pDestination, u32 ByteCount, u32* pNumberOfBytesRead = nullptr) override { return false; }
bool WriteByte(u8 SourceByte) override
{
if (m_input_buffer_wpos == INPUT_BUFFER_SIZE && !Compress(ZSTD_e_continue))
return false;
m_input_buffer[m_input_buffer_wpos++] = SourceByte;
return true;
}
u32 Write(const void* pSource, u32 ByteCount) override
{
u32 remaining = ByteCount;
const u8* read_ptr = static_cast<const u8*>(pSource);
for (;;)
{
const u32 copy_size = std::min(INPUT_BUFFER_SIZE - m_input_buffer_wpos, remaining);
std::memcpy(&m_input_buffer[m_input_buffer_wpos], read_ptr, copy_size);
read_ptr += copy_size;
remaining -= copy_size;
m_input_buffer_wpos += copy_size;
if (remaining == 0 || !Compress(ZSTD_e_continue))
break;
}
return ByteCount - remaining;
}
bool Write2(const void* pSource, u32 ByteCount, u32* pNumberOfBytesWritten = nullptr) override
{
const u32 bytes_written = Write(pSource, ByteCount);
if (pNumberOfBytesWritten)
*pNumberOfBytesWritten = bytes_written;
return (bytes_written == ByteCount);
}
bool SeekAbsolute(u64 Offset) override { return false; }
bool SeekRelative(s64 Offset) override { return (Offset == 0); }
bool SeekToEnd() override { return false; }
u64 GetPosition() const override { return m_position; }
u64 GetSize() const override { return 0; }
bool Flush() override { return Compress(ZSTD_e_flush); }
bool Discard() override { return true; }
bool Commit() override { return Compress(ZSTD_e_end); }
private:
enum : u32
{
INPUT_BUFFER_SIZE = 131072,
OUTPUT_BUFFER_SIZE = 65536,
};
bool Compress(ZSTD_EndDirective action)
{
if (m_errorState || m_done)
return false;
ZSTD_inBuffer inbuf = {m_input_buffer, m_input_buffer_wpos, 0};
for (;;)
{
ZSTD_outBuffer outbuf = {m_output_buffer, OUTPUT_BUFFER_SIZE, 0};
const size_t ret = ZSTD_compressStream2(m_cstream, &outbuf, &inbuf, action);
if (ZSTD_isError(ret))
{
Log_ErrorPrintf("ZSTD_compressStream2() failed: %u (%s)", static_cast<unsigned>(ZSTD_getErrorCode(ret)),
ZSTD_getErrorString(ZSTD_getErrorCode(ret)));
SetErrorState();
return false;
}
if (outbuf.pos > 0)
{
if (!m_dst_stream->Write2(m_output_buffer, static_cast<u32>(outbuf.pos)))
{
SetErrorState();
return false;
}
outbuf.pos = 0;
}
if (action == ZSTD_e_end)
{
// break when compression output has finished
if (ret == 0)
{
m_done = true;
break;
}
}
else
{
// break when all input data is consumed
if (inbuf.pos == inbuf.size)
break;
}
}
m_position += m_input_buffer_wpos;
m_input_buffer_wpos = 0;
return true;
}
ByteStream* m_dst_stream;
ZSTD_CStream* m_cstream = nullptr;
u64 m_position = 0;
u32 m_input_buffer_wpos = 0;
bool m_done = false;
u8 m_input_buffer[INPUT_BUFFER_SIZE];
u8 m_output_buffer[OUTPUT_BUFFER_SIZE];
};
} // namespace
std::unique_ptr<ByteStream> ByteStream::CreateZstdCompressStream(ByteStream* src_stream, int compression_level)
{
return std::make_unique<ZstdCompressStream>(src_stream, compression_level);
}
namespace {
class ZstdDecompressStream final : public ByteStream
{
public:
ZstdDecompressStream(ByteStream* src_stream, u32 compressed_size)
: m_src_stream(src_stream), m_bytes_remaining(compressed_size)
{
m_cstream = ZSTD_createDStream();
m_in_buffer.src = m_input_buffer;
Decompress();
}
~ZstdDecompressStream() override { ZSTD_freeDStream(m_cstream); }
bool ReadByte(u8* pDestByte) override { return Read(pDestByte, sizeof(u8)) == sizeof(u8); }
u32 Read(void* pDestination, u32 ByteCount) override
{
u8* write_ptr = static_cast<u8*>(pDestination);
u32 remaining = ByteCount;
for (;;)
{
const u32 copy_size = std::min<u32>(m_output_buffer_wpos - m_output_buffer_rpos, remaining);
std::memcpy(write_ptr, &m_output_buffer[m_output_buffer_rpos], copy_size);
m_output_buffer_rpos += copy_size;
write_ptr += copy_size;
remaining -= copy_size;
if (remaining == 0 || !Decompress())
break;
}
return ByteCount - remaining;
}
bool Read2(void* pDestination, u32 ByteCount, u32* pNumberOfBytesRead = nullptr) override
{
const u32 bytes_read = Read(pDestination, ByteCount);
if (pNumberOfBytesRead)
*pNumberOfBytesRead = bytes_read;
return (bytes_read == ByteCount);
}
bool WriteByte(u8 SourceByte) override { return false; }
u32 Write(const void* pSource, u32 ByteCount) override { return 0; }
bool Write2(const void* pSource, u32 ByteCount, u32* pNumberOfBytesWritten = nullptr) override { return false; }
bool SeekAbsolute(u64 Offset) override { return false; }
bool SeekRelative(s64 Offset) override
{
if (Offset < 0)
return false;
else if (Offset == 0)
return true;
s64 remaining = Offset;
for (;;)
{
const s64 skip = std::min<s64>(m_output_buffer_wpos - m_output_buffer_rpos, remaining);
remaining -= skip;
m_output_buffer_rpos += static_cast<u32>(skip);
if (remaining == 0)
return true;
else if (!Decompress())
return false;
}
}
bool SeekToEnd() override { return false; }
u64 GetPosition() const override { return 0; }
u64 GetSize() const override { return 0; }
bool Flush() override { return true; }
bool Discard() override { return true; }
bool Commit() override { return true; }
private:
enum : u32
{
INPUT_BUFFER_SIZE = 65536,
OUTPUT_BUFFER_SIZE = 131072,
};
bool Decompress()
{
if (m_output_buffer_rpos != m_output_buffer_wpos)
{
const u32 move_size = m_output_buffer_wpos - m_output_buffer_rpos;
std::memmove(&m_output_buffer[0], &m_output_buffer[m_output_buffer_rpos], move_size);
m_output_buffer_rpos = move_size;
m_output_buffer_wpos = move_size;
}
else
{
m_output_buffer_rpos = 0;
m_output_buffer_wpos = 0;
}
ZSTD_outBuffer outbuf = {m_output_buffer, OUTPUT_BUFFER_SIZE - m_output_buffer_wpos, 0};
while (outbuf.pos == 0)
{
if (m_in_buffer.pos == m_in_buffer.size && !m_errorState)
{
const u32 requested_size = std::min<u32>(m_bytes_remaining, INPUT_BUFFER_SIZE);
const u32 bytes_read = m_src_stream->Read(m_input_buffer, requested_size);
m_in_buffer.size = bytes_read;
m_in_buffer.pos = 0;
m_bytes_remaining -= bytes_read;
if (bytes_read != requested_size || m_bytes_remaining == 0)
{
m_errorState = true;
break;
}
}
size_t ret = ZSTD_decompressStream(m_cstream, &outbuf, &m_in_buffer);
if (ZSTD_isError(ret))
{
Log_ErrorPrintf("ZSTD_decompressStream() failed: %u (%s)", static_cast<unsigned>(ZSTD_getErrorCode(ret)),
ZSTD_getErrorString(ZSTD_getErrorCode(ret)));
m_in_buffer.pos = m_in_buffer.size;
m_output_buffer_rpos = 0;
m_output_buffer_wpos = 0;
m_errorState = true;
return false;
}
}
m_output_buffer_wpos = static_cast<u32>(outbuf.pos);
return true;
}
ByteStream* m_src_stream;
ZSTD_DStream* m_cstream = nullptr;
ZSTD_inBuffer m_in_buffer = {};
u32 m_output_buffer_rpos = 0;
u32 m_output_buffer_wpos = 0;
u32 m_bytes_remaining;
bool m_errorState = false;
u8 m_input_buffer[INPUT_BUFFER_SIZE];
u8 m_output_buffer[OUTPUT_BUFFER_SIZE];
};
} // namespace
std::unique_ptr<ByteStream> ByteStream::CreateZstdDecompressStream(ByteStream* src_stream, u32 compressed_size)
{
return std::make_unique<ZstdDecompressStream>(src_stream, compressed_size);
}

View File

@ -1,8 +1,10 @@
// SPDX-FileCopyrightText: 2019-2022 Connor McLaughlin <stenzek@gmail.com> // SPDX-FileCopyrightText: 2019-2024 Connor McLaughlin <stenzek@gmail.com>
// SPDX-License-Identifier: (GPL-3.0 OR CC-BY-NC-ND-4.0) // SPDX-License-Identifier: (GPL-3.0 OR CC-BY-NC-ND-4.0)
#pragma once #pragma once
#include "types.h" #include "types.h"
#include <memory> #include <memory>
#include <string> #include <string>
#include <string_view> #include <string_view>
@ -120,7 +122,7 @@ public:
// null memory stream // null memory stream
static std::unique_ptr<NullByteStream> CreateNullStream(); static std::unique_ptr<NullByteStream> CreateNullStream();
// zstd stream // zstd stream, actually defined in util/zstd_byte_stream.cpp, to avoid common dependency on libzstd
static std::unique_ptr<ByteStream> CreateZstdCompressStream(ByteStream* src_stream, int compression_level); static std::unique_ptr<ByteStream> CreateZstdCompressStream(ByteStream* src_stream, int compression_level);
static std::unique_ptr<ByteStream> CreateZstdDecompressStream(ByteStream* src_stream, u32 compressed_size); static std::unique_ptr<ByteStream> CreateZstdDecompressStream(ByteStream* src_stream, u32 compressed_size);

View File

@ -67,6 +67,7 @@ add_library(util
wav_writer.h wav_writer.h
window_info.cpp window_info.cpp
window_info.h window_info.h
zstd_byte_stream.cpp
) )
target_precompile_headers(util PRIVATE "pch.h") target_precompile_headers(util PRIVATE "pch.h")

View File

@ -214,6 +214,7 @@
<ClCompile Include="window_info.cpp" /> <ClCompile Include="window_info.cpp" />
<ClCompile Include="xaudio2_audio_stream.cpp" /> <ClCompile Include="xaudio2_audio_stream.cpp" />
<ClCompile Include="xinput_source.cpp" /> <ClCompile Include="xinput_source.cpp" />
<ClCompile Include="zstd_byte_stream.cpp" />
</ItemGroup> </ItemGroup>
<ItemGroup> <ItemGroup>
<ProjectReference Include="..\..\dep\cubeb\cubeb.vcxproj"> <ProjectReference Include="..\..\dep\cubeb\cubeb.vcxproj">

View File

@ -153,6 +153,7 @@
<ClCompile Include="http_downloader.cpp" /> <ClCompile Include="http_downloader.cpp" />
<ClCompile Include="metal_device.mm" /> <ClCompile Include="metal_device.mm" />
<ClCompile Include="metal_stream_buffer.mm" /> <ClCompile Include="metal_stream_buffer.mm" />
<ClCompile Include="zstd_byte_stream.cpp" />
</ItemGroup> </ItemGroup>
<ItemGroup> <ItemGroup>
<Filter Include="gl"> <Filter Include="gl">

View File

@ -0,0 +1,317 @@
// SPDX-FileCopyrightText: 2019-2024 Connor McLaughlin <stenzek@gmail.com>
// SPDX-License-Identifier: (GPL-3.0 OR CC-BY-NC-ND-4.0)
#include "common/byte_stream.h"
#include "common/log.h"
#include <zstd.h>
#include <zstd_errors.h>
Log_SetChannel(ByteStream);
namespace {
class ZstdCompressStream final : public ByteStream
{
public:
ZstdCompressStream(ByteStream* dst_stream, int compression_level) : m_dst_stream(dst_stream)
{
m_cstream = ZSTD_createCStream();
ZSTD_CCtx_setParameter(m_cstream, ZSTD_c_compressionLevel, compression_level);
}
~ZstdCompressStream() override
{
if (!m_done)
Compress(ZSTD_e_end);
ZSTD_freeCStream(m_cstream);
}
bool ReadByte(u8* pDestByte) override { return false; }
u32 Read(void* pDestination, u32 ByteCount) override { return 0; }
bool Read2(void* pDestination, u32 ByteCount, u32* pNumberOfBytesRead = nullptr) override { return false; }
bool WriteByte(u8 SourceByte) override
{
if (m_input_buffer_wpos == INPUT_BUFFER_SIZE && !Compress(ZSTD_e_continue))
return false;
m_input_buffer[m_input_buffer_wpos++] = SourceByte;
return true;
}
u32 Write(const void* pSource, u32 ByteCount) override
{
u32 remaining = ByteCount;
const u8* read_ptr = static_cast<const u8*>(pSource);
for (;;)
{
const u32 copy_size = std::min(INPUT_BUFFER_SIZE - m_input_buffer_wpos, remaining);
std::memcpy(&m_input_buffer[m_input_buffer_wpos], read_ptr, copy_size);
read_ptr += copy_size;
remaining -= copy_size;
m_input_buffer_wpos += copy_size;
if (remaining == 0 || !Compress(ZSTD_e_continue))
break;
}
return ByteCount - remaining;
}
bool Write2(const void* pSource, u32 ByteCount, u32* pNumberOfBytesWritten = nullptr) override
{
const u32 bytes_written = Write(pSource, ByteCount);
if (pNumberOfBytesWritten)
*pNumberOfBytesWritten = bytes_written;
return (bytes_written == ByteCount);
}
bool SeekAbsolute(u64 Offset) override { return false; }
bool SeekRelative(s64 Offset) override { return (Offset == 0); }
bool SeekToEnd() override { return false; }
u64 GetPosition() const override { return m_position; }
u64 GetSize() const override { return 0; }
bool Flush() override { return Compress(ZSTD_e_flush); }
bool Discard() override { return true; }
bool Commit() override { return Compress(ZSTD_e_end); }
private:
enum : u32
{
INPUT_BUFFER_SIZE = 131072,
OUTPUT_BUFFER_SIZE = 65536,
};
bool Compress(ZSTD_EndDirective action)
{
if (m_errorState || m_done)
return false;
ZSTD_inBuffer inbuf = {m_input_buffer, m_input_buffer_wpos, 0};
for (;;)
{
ZSTD_outBuffer outbuf = {m_output_buffer, OUTPUT_BUFFER_SIZE, 0};
const size_t ret = ZSTD_compressStream2(m_cstream, &outbuf, &inbuf, action);
if (ZSTD_isError(ret))
{
Log_ErrorPrintf("ZSTD_compressStream2() failed: %u (%s)", static_cast<unsigned>(ZSTD_getErrorCode(ret)),
ZSTD_getErrorString(ZSTD_getErrorCode(ret)));
SetErrorState();
return false;
}
if (outbuf.pos > 0)
{
if (!m_dst_stream->Write2(m_output_buffer, static_cast<u32>(outbuf.pos)))
{
SetErrorState();
return false;
}
outbuf.pos = 0;
}
if (action == ZSTD_e_end)
{
// break when compression output has finished
if (ret == 0)
{
m_done = true;
break;
}
}
else
{
// break when all input data is consumed
if (inbuf.pos == inbuf.size)
break;
}
}
m_position += m_input_buffer_wpos;
m_input_buffer_wpos = 0;
return true;
}
ByteStream* m_dst_stream;
ZSTD_CStream* m_cstream = nullptr;
u64 m_position = 0;
u32 m_input_buffer_wpos = 0;
bool m_done = false;
u8 m_input_buffer[INPUT_BUFFER_SIZE];
u8 m_output_buffer[OUTPUT_BUFFER_SIZE];
};
} // namespace
std::unique_ptr<ByteStream> ByteStream::CreateZstdCompressStream(ByteStream* src_stream, int compression_level)
{
return std::make_unique<ZstdCompressStream>(src_stream, compression_level);
}
namespace {
class ZstdDecompressStream final : public ByteStream
{
public:
ZstdDecompressStream(ByteStream* src_stream, u32 compressed_size)
: m_src_stream(src_stream), m_bytes_remaining(compressed_size)
{
m_cstream = ZSTD_createDStream();
m_in_buffer.src = m_input_buffer;
Decompress();
}
~ZstdDecompressStream() override { ZSTD_freeDStream(m_cstream); }
bool ReadByte(u8* pDestByte) override { return Read(pDestByte, sizeof(u8)) == sizeof(u8); }
u32 Read(void* pDestination, u32 ByteCount) override
{
u8* write_ptr = static_cast<u8*>(pDestination);
u32 remaining = ByteCount;
for (;;)
{
const u32 copy_size = std::min<u32>(m_output_buffer_wpos - m_output_buffer_rpos, remaining);
std::memcpy(write_ptr, &m_output_buffer[m_output_buffer_rpos], copy_size);
m_output_buffer_rpos += copy_size;
write_ptr += copy_size;
remaining -= copy_size;
if (remaining == 0 || !Decompress())
break;
}
return ByteCount - remaining;
}
bool Read2(void* pDestination, u32 ByteCount, u32* pNumberOfBytesRead = nullptr) override
{
const u32 bytes_read = Read(pDestination, ByteCount);
if (pNumberOfBytesRead)
*pNumberOfBytesRead = bytes_read;
return (bytes_read == ByteCount);
}
bool WriteByte(u8 SourceByte) override { return false; }
u32 Write(const void* pSource, u32 ByteCount) override { return 0; }
bool Write2(const void* pSource, u32 ByteCount, u32* pNumberOfBytesWritten = nullptr) override { return false; }
bool SeekAbsolute(u64 Offset) override { return false; }
bool SeekRelative(s64 Offset) override
{
if (Offset < 0)
return false;
else if (Offset == 0)
return true;
s64 remaining = Offset;
for (;;)
{
const s64 skip = std::min<s64>(m_output_buffer_wpos - m_output_buffer_rpos, remaining);
remaining -= skip;
m_output_buffer_rpos += static_cast<u32>(skip);
if (remaining == 0)
return true;
else if (!Decompress())
return false;
}
}
bool SeekToEnd() override { return false; }
u64 GetPosition() const override { return 0; }
u64 GetSize() const override { return 0; }
bool Flush() override { return true; }
bool Discard() override { return true; }
bool Commit() override { return true; }
private:
enum : u32
{
INPUT_BUFFER_SIZE = 65536,
OUTPUT_BUFFER_SIZE = 131072,
};
bool Decompress()
{
if (m_output_buffer_rpos != m_output_buffer_wpos)
{
const u32 move_size = m_output_buffer_wpos - m_output_buffer_rpos;
std::memmove(&m_output_buffer[0], &m_output_buffer[m_output_buffer_rpos], move_size);
m_output_buffer_rpos = move_size;
m_output_buffer_wpos = move_size;
}
else
{
m_output_buffer_rpos = 0;
m_output_buffer_wpos = 0;
}
ZSTD_outBuffer outbuf = {m_output_buffer, OUTPUT_BUFFER_SIZE - m_output_buffer_wpos, 0};
while (outbuf.pos == 0)
{
if (m_in_buffer.pos == m_in_buffer.size && !m_errorState)
{
const u32 requested_size = std::min<u32>(m_bytes_remaining, INPUT_BUFFER_SIZE);
const u32 bytes_read = m_src_stream->Read(m_input_buffer, requested_size);
m_in_buffer.size = bytes_read;
m_in_buffer.pos = 0;
m_bytes_remaining -= bytes_read;
if (bytes_read != requested_size || m_bytes_remaining == 0)
{
m_errorState = true;
break;
}
}
size_t ret = ZSTD_decompressStream(m_cstream, &outbuf, &m_in_buffer);
if (ZSTD_isError(ret))
{
Log_ErrorPrintf("ZSTD_decompressStream() failed: %u (%s)", static_cast<unsigned>(ZSTD_getErrorCode(ret)),
ZSTD_getErrorString(ZSTD_getErrorCode(ret)));
m_in_buffer.pos = m_in_buffer.size;
m_output_buffer_rpos = 0;
m_output_buffer_wpos = 0;
m_errorState = true;
return false;
}
}
m_output_buffer_wpos = static_cast<u32>(outbuf.pos);
return true;
}
ByteStream* m_src_stream;
ZSTD_DStream* m_cstream = nullptr;
ZSTD_inBuffer m_in_buffer = {};
u32 m_output_buffer_rpos = 0;
u32 m_output_buffer_wpos = 0;
u32 m_bytes_remaining;
bool m_errorState = false;
u8 m_input_buffer[INPUT_BUFFER_SIZE];
u8 m_output_buffer[OUTPUT_BUFFER_SIZE];
};
} // namespace
std::unique_ptr<ByteStream> ByteStream::CreateZstdDecompressStream(ByteStream* src_stream, u32 compressed_size)
{
return std::make_unique<ZstdDecompressStream>(src_stream, compressed_size);
}