diff --git a/src/util/sockets.cpp b/src/util/sockets.cpp index 7734f4f8d..95e6a7c41 100644 --- a/src/util/sockets.cpp +++ b/src/util/sockets.cpp @@ -321,7 +321,13 @@ void SocketMultiplexer::AddOpenSocket(std::shared_ptr socket) std::unique_lock lock(m_open_sockets_lock); DebugAssert(m_open_sockets.find(socket->GetDescriptor()) == m_open_sockets.end()); - m_open_sockets.emplace(socket->GetDescriptor(), socket); + m_open_sockets.emplace(socket->GetDescriptor(), std::move(socket)); +} + +void SocketMultiplexer::AddClientSocket(std::shared_ptr socket) +{ + AddOpenSocket(std::move(socket)); + m_client_socket_count.fetch_add(1, std::memory_order_acq_rel); } void SocketMultiplexer::RemoveOpenSocket(BaseSocket* socket) @@ -349,12 +355,29 @@ void SocketMultiplexer::RemoveOpenSocket(BaseSocket* socket) m_poll_array_active_size = new_active_size; } +void SocketMultiplexer::RemoveClientSocket(BaseSocket* socket) +{ + DebugAssert(m_client_socket_count.load(std::memory_order_acquire) > 0); + m_client_socket_count.fetch_sub(1, std::memory_order_acq_rel); + RemoveOpenSocket(socket); +} + bool SocketMultiplexer::HasAnyOpenSockets() { std::unique_lock lock(m_open_sockets_lock); return !m_open_sockets.empty(); } +bool SocketMultiplexer::HasAnyClientSockets() +{ + return (GetClientSocketCount() > 0); +} + +size_t SocketMultiplexer::GetClientSocketCount() +{ + return m_client_socket_count.load(std::memory_order_acquire); +} + void SocketMultiplexer::CloseAll() { std::unique_lock lock(m_open_sockets_lock); @@ -559,7 +582,7 @@ u32 StreamSocket::GetSocketProtocolForAddress(const SocketAddress& sa) void StreamSocket::InitialSetup() { // register for notifications - m_multiplexer.AddOpenSocket(shared_from_this()); + m_multiplexer.AddClientSocket(shared_from_this()); m_multiplexer.SetNotificationMask(this, m_descriptor, POLLIN); // trigger connected notification @@ -679,7 +702,7 @@ void StreamSocket::Close() return; m_multiplexer.SetNotificationMask(this, m_descriptor, 0); - m_multiplexer.RemoveOpenSocket(this); + m_multiplexer.RemoveClientSocket(this); shutdown(m_descriptor, SD_BOTH); closesocket(m_descriptor); m_descriptor = INVALID_SOCKET; @@ -701,7 +724,7 @@ void StreamSocket::CloseWithError() error.SetSocket(error_code); m_multiplexer.SetNotificationMask(this, m_descriptor, 0); - m_multiplexer.RemoveOpenSocket(this); + m_multiplexer.RemoveClientSocket(this); closesocket(m_descriptor); m_descriptor = INVALID_SOCKET; m_connected = false; diff --git a/src/util/sockets.h b/src/util/sockets.h index ddd608eb2..4d5833dc4 100644 --- a/src/util/sockets.h +++ b/src/util/sockets.h @@ -13,8 +13,8 @@ #include #include #include -#include #include +#include #ifdef _WIN32 using SocketDescriptor = uintptr_t; @@ -108,6 +108,12 @@ public: // Returns true if any sockets are currently registered. bool HasAnyOpenSockets(); + // Returns true if any client sockets are currently connected. + bool HasAnyClientSockets(); + + // Returns the number of current client sockets. + size_t GetClientSocketCount(); + // Close all sockets on this multiplexer. void CloseAll(); @@ -127,7 +133,9 @@ private: // Tracking of open sockets. void AddOpenSocket(std::shared_ptr socket); + void AddClientSocket(std::shared_ptr socket); void RemoveOpenSocket(BaseSocket* socket); + void RemoveClientSocket(BaseSocket* socket); // Register for notifications void SetNotificationMask(BaseSocket* socket, SocketDescriptor descriptor, u32 events); @@ -143,6 +151,7 @@ private: std::mutex m_open_sockets_lock; SocketMap m_open_sockets; + std::atomic_size_t m_client_socket_count{0}; }; template