Sockets: Add HasAnyClientSockets()

This commit is contained in:
Stenzek 2024-07-06 21:45:48 +10:00
parent b06fceffa4
commit 1fd8d2701d
No known key found for this signature in database
2 changed files with 37 additions and 5 deletions

View File

@ -321,7 +321,13 @@ void SocketMultiplexer::AddOpenSocket(std::shared_ptr<BaseSocket> socket)
std::unique_lock lock(m_open_sockets_lock); std::unique_lock lock(m_open_sockets_lock);
DebugAssert(m_open_sockets.find(socket->GetDescriptor()) == m_open_sockets.end()); DebugAssert(m_open_sockets.find(socket->GetDescriptor()) == m_open_sockets.end());
m_open_sockets.emplace(socket->GetDescriptor(), socket); m_open_sockets.emplace(socket->GetDescriptor(), std::move(socket));
}
void SocketMultiplexer::AddClientSocket(std::shared_ptr<BaseSocket> socket)
{
AddOpenSocket(std::move(socket));
m_client_socket_count.fetch_add(1, std::memory_order_acq_rel);
} }
void SocketMultiplexer::RemoveOpenSocket(BaseSocket* socket) void SocketMultiplexer::RemoveOpenSocket(BaseSocket* socket)
@ -349,12 +355,29 @@ void SocketMultiplexer::RemoveOpenSocket(BaseSocket* socket)
m_poll_array_active_size = new_active_size; m_poll_array_active_size = new_active_size;
} }
void SocketMultiplexer::RemoveClientSocket(BaseSocket* socket)
{
DebugAssert(m_client_socket_count.load(std::memory_order_acquire) > 0);
m_client_socket_count.fetch_sub(1, std::memory_order_acq_rel);
RemoveOpenSocket(socket);
}
bool SocketMultiplexer::HasAnyOpenSockets() bool SocketMultiplexer::HasAnyOpenSockets()
{ {
std::unique_lock lock(m_open_sockets_lock); std::unique_lock lock(m_open_sockets_lock);
return !m_open_sockets.empty(); return !m_open_sockets.empty();
} }
bool SocketMultiplexer::HasAnyClientSockets()
{
return (GetClientSocketCount() > 0);
}
size_t SocketMultiplexer::GetClientSocketCount()
{
return m_client_socket_count.load(std::memory_order_acquire);
}
void SocketMultiplexer::CloseAll() void SocketMultiplexer::CloseAll()
{ {
std::unique_lock lock(m_open_sockets_lock); std::unique_lock lock(m_open_sockets_lock);
@ -559,7 +582,7 @@ u32 StreamSocket::GetSocketProtocolForAddress(const SocketAddress& sa)
void StreamSocket::InitialSetup() void StreamSocket::InitialSetup()
{ {
// register for notifications // register for notifications
m_multiplexer.AddOpenSocket(shared_from_this()); m_multiplexer.AddClientSocket(shared_from_this());
m_multiplexer.SetNotificationMask(this, m_descriptor, POLLIN); m_multiplexer.SetNotificationMask(this, m_descriptor, POLLIN);
// trigger connected notification // trigger connected notification
@ -679,7 +702,7 @@ void StreamSocket::Close()
return; return;
m_multiplexer.SetNotificationMask(this, m_descriptor, 0); m_multiplexer.SetNotificationMask(this, m_descriptor, 0);
m_multiplexer.RemoveOpenSocket(this); m_multiplexer.RemoveClientSocket(this);
shutdown(m_descriptor, SD_BOTH); shutdown(m_descriptor, SD_BOTH);
closesocket(m_descriptor); closesocket(m_descriptor);
m_descriptor = INVALID_SOCKET; m_descriptor = INVALID_SOCKET;
@ -701,7 +724,7 @@ void StreamSocket::CloseWithError()
error.SetSocket(error_code); error.SetSocket(error_code);
m_multiplexer.SetNotificationMask(this, m_descriptor, 0); m_multiplexer.SetNotificationMask(this, m_descriptor, 0);
m_multiplexer.RemoveOpenSocket(this); m_multiplexer.RemoveClientSocket(this);
closesocket(m_descriptor); closesocket(m_descriptor);
m_descriptor = INVALID_SOCKET; m_descriptor = INVALID_SOCKET;
m_connected = false; m_connected = false;

View File

@ -13,8 +13,8 @@
#include <memory> #include <memory>
#include <mutex> #include <mutex>
#include <optional> #include <optional>
#include <unordered_map>
#include <span> #include <span>
#include <unordered_map>
#ifdef _WIN32 #ifdef _WIN32
using SocketDescriptor = uintptr_t; using SocketDescriptor = uintptr_t;
@ -108,6 +108,12 @@ public:
// Returns true if any sockets are currently registered. // Returns true if any sockets are currently registered.
bool HasAnyOpenSockets(); bool HasAnyOpenSockets();
// Returns true if any client sockets are currently connected.
bool HasAnyClientSockets();
// Returns the number of current client sockets.
size_t GetClientSocketCount();
// Close all sockets on this multiplexer. // Close all sockets on this multiplexer.
void CloseAll(); void CloseAll();
@ -127,7 +133,9 @@ private:
// Tracking of open sockets. // Tracking of open sockets.
void AddOpenSocket(std::shared_ptr<BaseSocket> socket); void AddOpenSocket(std::shared_ptr<BaseSocket> socket);
void AddClientSocket(std::shared_ptr<BaseSocket> socket);
void RemoveOpenSocket(BaseSocket* socket); void RemoveOpenSocket(BaseSocket* socket);
void RemoveClientSocket(BaseSocket* socket);
// Register for notifications // Register for notifications
void SetNotificationMask(BaseSocket* socket, SocketDescriptor descriptor, u32 events); void SetNotificationMask(BaseSocket* socket, SocketDescriptor descriptor, u32 events);
@ -143,6 +151,7 @@ private:
std::mutex m_open_sockets_lock; std::mutex m_open_sockets_lock;
SocketMap m_open_sockets; SocketMap m_open_sockets;
std::atomic_size_t m_client_socket_count{0};
}; };
template<class T> template<class T>