diff --git a/src/util/sockets.cpp b/src/util/sockets.cpp index 95e6a7c41..9492a76a1 100644 --- a/src/util/sockets.cpp +++ b/src/util/sockets.cpp @@ -33,6 +33,7 @@ using nfds_t = ULONG; #include #include #include +#include #include #include #include @@ -76,6 +77,12 @@ void SocketAddress::SetFromSockaddr(const void* sa, size_t length) std::memset(m_data + m_length, 0, sizeof(m_data) - m_length); } +bool SocketAddress::IsIPAddress() const +{ + const sockaddr* addr = reinterpret_cast(m_data); + return (addr->sa_family == AF_INET || addr->sa_family == AF_INET6); +} + std::optional SocketAddress::Parse(Type type, const char* address, u32 port, Error* error) { std::optional ret = SocketAddress(); @@ -695,6 +702,24 @@ size_t StreamSocket::WriteVector(const void** buffers, const size_t* buffer_leng #endif } +bool StreamSocket::SetNagleBuffering(bool enabled, Error* error /* = nullptr */) +{ + if (!m_local_address.IsIPAddress()) + { + Error::SetStringView(error, "Attempting to disable nagle on a non-IP socket."); + return false; + } + + int disable = enabled ? 0 : 1; + if (setsockopt(m_descriptor, IPPROTO_TCP, TCP_NODELAY, reinterpret_cast(&disable), sizeof(disable)) != 0) + { + Error::SetSocket(error, "setsockopt(TCP_NODELAY) failed: ", WSAGetLastError()); + return false; + } + + return true; +} + void StreamSocket::Close() { std::unique_lock lock(m_lock); diff --git a/src/util/sockets.h b/src/util/sockets.h index 4d5833dc4..b01918664 100644 --- a/src/util/sockets.h +++ b/src/util/sockets.h @@ -56,6 +56,9 @@ struct SocketAddress final // initializers void SetFromSockaddr(const void* sa, size_t length); + /// Returns true if the address is IP. + bool IsIPAddress() const; + private: u8 m_data[128] = {}; u32 m_length = 0; @@ -218,6 +221,9 @@ public: size_t Write(const void* buffer, size_t buffer_size); size_t WriteVector(const void** buffers, const size_t* buffer_lengths, size_t num_buffers); + /// Disables Nagle's buffering algorithm, i.e. TCP_NODELAY. + bool SetNagleBuffering(bool enabled, Error* error = nullptr); + protected: virtual void OnConnected() = 0; virtual void OnDisconnected(const Error& error) = 0;