diff --git a/src/common-tests/string_tests.cpp b/src/common-tests/string_tests.cpp index 4f3da18d6..ddf84ba7e 100644 --- a/src/common-tests/string_tests.cpp +++ b/src/common-tests/string_tests.cpp @@ -33,3 +33,41 @@ TEST(StringUtil, EllipsiseInPlace) StringUtil::EllipsiseInPlace(s, 10, "..."); ASSERT_EQ(s, "Hello"); } + +TEST(StringUtil, Base64EncodeDecode) +{ + struct TestCase + { + const char* hexString; + const char* base64String; + }; + static const TestCase testCases[] = { + {"33326a6f646933326a68663937683732383368", "MzJqb2RpMzJqaGY5N2g3MjgzaA=="}, + {"32753965333268756979386672677537366967723839683432703075693132393065755c5d0931325c335c31323439303438753839333272", + "MnU5ZTMyaHVpeThmcmd1NzZpZ3I4OWg0MnAwdWkxMjkwZXVcXQkxMlwzXDEyNDkwNDh1ODkzMnI="}, + {"3332726a33323738676838666233326830393233386637683938323139", "MzJyajMyNzhnaDhmYjMyaDA5MjM4ZjdoOTgyMTk="}, + {"9956967BE9C96E10B27FF8897A5B768A2F4B103CE934718D020FE6B5B770", "mVaWe+nJbhCyf/iJelt2ii9LEDzpNHGNAg/mtbdw"}, + {"BC94251814827A5D503D62D5EE6CBAB0FD55D2E2FCEDBB2261D6010084B95DD648766D8983F03AFA3908956D8201E26BB09FE52B515A61A9E" + "1D3ADC207BD9E622128F22929CDED456B595A410F7168B0BA6370289E6291E38E47C18278561C79A7297C21D23C06BB2F694DC2F65FAAF994" + "59E3FC14B1FA415A3320AF00ACE54C00BE", + "vJQlGBSCel1QPWLV7my6sP1V0uL87bsiYdYBAIS5XdZIdm2Jg/A6+jkIlW2CAeJrsJ/" + "lK1FaYanh063CB72eYiEo8ikpze1Fa1laQQ9xaLC6Y3AonmKR445HwYJ4Vhx5pyl8IdI8BrsvaU3C9l+q+ZRZ4/wUsfpBWjMgrwCs5UwAvg=="}, + {"192B42CB0F66F69BE8A5", "GStCyw9m9pvopQ=="}, + {"38ABD400F3BB6960EB60C056719B5362", "OKvUAPO7aWDrYMBWcZtTYg=="}, + {"776FAB27DC7F8DA86F298D55B69F8C278D53871F8CBCCF", "d2+rJ9x/jahvKY1Vtp+MJ41Thx+MvM8="}, + {"B1ED3EA2E35EE69C7E16707B05042A", "se0+ouNe5px+FnB7BQQq"}, + }; + + for (const TestCase& tc : testCases) + { + std::optional> bytes = StringUtil::DecodeHex(tc.hexString); + ASSERT_TRUE(bytes.has_value()); + + std::string encoded_b64 = StringUtil::EncodeBase64(bytes.value()); + ASSERT_EQ(encoded_b64, tc.base64String); + + std::optional> dbytes = StringUtil::DecodeBase64(tc.base64String); + ASSERT_TRUE(dbytes.has_value()); + ASSERT_EQ(dbytes.value(), bytes.value()); + } +} diff --git a/src/common/string_util.cpp b/src/common/string_util.cpp index ccb7b4181..18cc8110d 100644 --- a/src/common/string_util.cpp +++ b/src/common/string_util.cpp @@ -197,6 +197,107 @@ std::string StringUtil::EncodeHex(const void* data, size_t length) return ret; } +size_t StringUtil::EncodeBase64(const std::span dest, const std::span data) +{ + const size_t expected_length = EncodedBase64Length(data); + Assert(dest.size() <= expected_length); + + static constexpr std::array table = { + {'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', + 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', + 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '+', '/'}}; + + const size_t dataLength = data.size(); + size_t dest_pos = 0; + + for (size_t i = 0; i < dataLength;) + { + const size_t bytes_in_sequence = std::min(dataLength - i, 3); + switch (bytes_in_sequence) + { + case 1: + dest[dest_pos++] = table[(data[i] >> 2) & 63]; + dest[dest_pos++] = table[(data[i] & 3) << 4]; + dest[dest_pos++] = '='; + dest[dest_pos++] = '='; + break; + + case 2: + dest[dest_pos++] = table[(data[i] >> 2) & 63]; + dest[dest_pos++] = table[((data[i] & 3) << 4) | ((data[i + 1] >> 4) & 15)]; + dest[dest_pos++] = table[(data[i + 1] & 15) << 2]; + dest[dest_pos++] = '='; + break; + + case 3: + dest[dest_pos++] = table[(data[i] >> 2) & 63]; + dest[dest_pos++] = table[((data[i] & 3) << 4) | ((data[i + 1] >> 4) & 15)]; + dest[dest_pos++] = table[((data[i + 1] & 15) << 2) | ((data[i + 2] >> 6) & 3)]; + dest[dest_pos++] = table[data[i + 2] & 63]; + break; + + DefaultCaseIsUnreachable(); + } + + i += bytes_in_sequence; + } + + DebugAssert(dest_pos == expected_length); + return dest_pos; +} + +size_t StringUtil::DecodeBase64(const std::span data, const std::string_view str) +{ + static constexpr std::array table = { + 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, + 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 62, 64, 64, 64, 63, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 64, 64, 64, 0, 64, 64, 64, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 64, 64, 64, 64, 64, 64, 26, 27, 28, 29, 30, 31, 32, + 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 64, 64, 64, 64, 64}; + + const size_t str_length = str.length(); + if ((str_length % 4) != 0) + return 0; + + size_t data_pos = 0; + for (size_t i = 0; i < str_length;) + { + const u8 byte1 = table[str[i++] & 0x7F]; + const u8 byte2 = table[str[i++] & 0x7F]; + const u8 byte3 = table[str[i++] & 0x7F]; + const u8 byte4 = table[str[i++] & 0x7F]; + + if (byte1 == 64 || byte2 == 64 || byte3 == 64 || byte4 == 64) + break; + + data[data_pos++] = (byte1 << 2) | (byte2 >> 4); + if (str[i - 2] != '=') + data[data_pos++] = ((byte2 << 4) | (byte3 >> 2)); + if (str[i - 1] != '=') + data[data_pos++] = ((byte3 << 6) | byte4); + } + + return data_pos; +} + +std::optional> StringUtil::DecodeBase64(const std::string_view str) +{ + std::vector ret; + const size_t len = DecodedBase64Length(str); + ret.resize(len); + if (DecodeBase64(ret, str) != len) + ret = {}; + return ret; +} + +std::string StringUtil::EncodeBase64(const std::span data) +{ + std::string ret; + ret.resize(EncodedBase64Length(data)); + ret.resize(EncodeBase64(ret, data)); + return ret; +} + std::string_view StringUtil::StripWhitespace(const std::string_view str) { std::string_view::size_type start = 0; diff --git a/src/common/string_util.h b/src/common/string_util.h index fee55fa51..8c01ec17e 100644 --- a/src/common/string_util.h +++ b/src/common/string_util.h @@ -4,6 +4,7 @@ #pragma once #include "types.h" +#include #include #include #include @@ -270,6 +271,33 @@ static constexpr std::array ParseFixedHexString(const char str[]) return h; } +/// Encode/decode Base64 buffers. +static constexpr size_t DecodedBase64Length(const std::string_view str) +{ + // Should be a multiple of 4. + const size_t str_length = str.length(); + if ((str_length % 4) != 0) + return 0; + + // Reverse padding. + size_t padding = 0; + if (str.length() >= 2) + { + padding += static_cast(str[str_length - 1] == '='); + padding += static_cast(str[str_length - 2] == '='); + } + + return (str_length / 4) * 3 - padding; +} +static constexpr size_t EncodedBase64Length(const std::span data) +{ + return ((data.size() + 2) / 3) * 4; +} +size_t DecodeBase64(const std::span data, const std::string_view str); +size_t EncodeBase64(const std::span dest, const std::span data); +std::string EncodeBase64(const std::span data); +std::optional> DecodeBase64(const std::string_view str); + /// StartsWith/EndsWith variants which aren't case sensitive. ALWAYS_INLINE static bool StartsWithNoCase(const std::string_view str, const std::string_view prefix) {