Sockets: Use epoll on Linux
This commit is contained in:
parent
78800870bc
commit
ad374ef5e2
|
@ -42,6 +42,10 @@ using nfds_t = ULONG;
|
|||
#include <sys/un.h>
|
||||
#include <unistd.h>
|
||||
|
||||
#ifdef __linux__
|
||||
#include <sys/epoll.h>
|
||||
#endif
|
||||
|
||||
#define ioctlsocket ioctl
|
||||
#define closesocket close
|
||||
#define WSAEWOULDBLOCK EAGAIN
|
||||
|
@ -227,16 +231,42 @@ SocketMultiplexer::~SocketMultiplexer()
|
|||
{
|
||||
CloseAll();
|
||||
|
||||
#ifdef __linux__
|
||||
if (m_epoll_fd >= 0)
|
||||
close(m_epoll_fd);
|
||||
#else
|
||||
if (m_poll_array)
|
||||
std::free(m_poll_array);
|
||||
#endif
|
||||
}
|
||||
|
||||
std::unique_ptr<SocketMultiplexer> SocketMultiplexer::Create(Error* error)
|
||||
{
|
||||
if (!PlatformMisc::InitializeSocketSupport(error))
|
||||
return {};
|
||||
std::unique_ptr<SocketMultiplexer> ret;
|
||||
if (PlatformMisc::InitializeSocketSupport(error))
|
||||
{
|
||||
ret = std::unique_ptr<SocketMultiplexer>(new SocketMultiplexer());
|
||||
if (!ret->Initialize(error))
|
||||
ret.reset();
|
||||
}
|
||||
|
||||
return std::unique_ptr<SocketMultiplexer>(new SocketMultiplexer());
|
||||
return ret;
|
||||
}
|
||||
|
||||
bool SocketMultiplexer::Initialize(Error* error)
|
||||
{
|
||||
#ifdef __linux__
|
||||
m_epoll_fd = epoll_create1(0);
|
||||
if (m_epoll_fd < 0)
|
||||
{
|
||||
Error::SetErrno(error, "epoll_create1() failed: ", errno);
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
#else
|
||||
return true;
|
||||
#endif
|
||||
}
|
||||
|
||||
std::shared_ptr<ListenSocket> SocketMultiplexer::InternalCreateListenSocket(const SocketAddress& address,
|
||||
|
@ -325,8 +355,13 @@ std::shared_ptr<StreamSocket> SocketMultiplexer::InternalConnectStreamSocket(con
|
|||
|
||||
void SocketMultiplexer::AddOpenSocket(std::shared_ptr<BaseSocket> socket)
|
||||
{
|
||||
std::unique_lock lock(m_open_sockets_lock);
|
||||
#ifdef __linux__
|
||||
struct epoll_event ev = {.events = 0u, .data = {.fd = socket->GetDescriptor()}};
|
||||
if (epoll_ctl(m_epoll_fd, EPOLL_CTL_ADD, socket->GetDescriptor(), &ev) != 0) [[unlikely]]
|
||||
ERROR_LOG("epoll_ctl() to add socket failed: {}", Error::CreateErrno(errno).GetDescription());
|
||||
#endif
|
||||
|
||||
std::unique_lock lock(m_open_sockets_lock);
|
||||
DebugAssert(m_open_sockets.find(socket->GetDescriptor()) == m_open_sockets.end());
|
||||
m_open_sockets.emplace(socket->GetDescriptor(), std::move(socket));
|
||||
}
|
||||
|
@ -339,27 +374,29 @@ void SocketMultiplexer::AddClientSocket(std::shared_ptr<BaseSocket> socket)
|
|||
|
||||
void SocketMultiplexer::RemoveOpenSocket(BaseSocket* socket)
|
||||
{
|
||||
#ifdef _DEBUG
|
||||
{
|
||||
std::unique_lock lock(m_poll_array_lock);
|
||||
for (size_t i = 0; i < m_poll_array_active_size; i++)
|
||||
{
|
||||
pollfd& pfd = m_poll_array[i];
|
||||
DebugAssert(pfd.fd != socket->GetDescriptor());
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
std::unique_lock lock(m_open_sockets_lock);
|
||||
const auto iter = m_open_sockets.find(socket->GetDescriptor());
|
||||
Assert(iter != m_open_sockets.end());
|
||||
m_open_sockets.erase(iter);
|
||||
|
||||
#ifdef __linux__
|
||||
if (epoll_ctl(m_epoll_fd, EPOLL_CTL_DEL, socket->GetDescriptor(), nullptr) != 0) [[unlikely]]
|
||||
ERROR_LOG("epoll_ctl() to remove socket failed: {}", Error::CreateErrno(errno).GetDescription());
|
||||
#else
|
||||
#ifdef _DEBUG
|
||||
for (size_t i = 0; i < m_poll_array_active_size; i++)
|
||||
{
|
||||
pollfd& pfd = m_poll_array[i];
|
||||
DebugAssert(pfd.fd != socket->GetDescriptor());
|
||||
}
|
||||
#endif
|
||||
|
||||
// Update size.
|
||||
size_t new_active_size = 0;
|
||||
for (size_t i = 0; i < m_poll_array_active_size; i++)
|
||||
new_active_size = (m_poll_array[i].fd != INVALID_SOCKET) ? (i + 1) : new_active_size;
|
||||
m_poll_array_active_size = new_active_size;
|
||||
#endif
|
||||
}
|
||||
|
||||
void SocketMultiplexer::RemoveClientSocket(BaseSocket* socket)
|
||||
|
@ -400,6 +437,11 @@ void SocketMultiplexer::CloseAll()
|
|||
|
||||
void SocketMultiplexer::SetNotificationMask(BaseSocket* socket, SocketDescriptor descriptor, u32 events)
|
||||
{
|
||||
#ifdef __linux__
|
||||
struct epoll_event ev = {.events = events, .data = {.fd = descriptor}};
|
||||
if (epoll_ctl(m_epoll_fd, EPOLL_CTL_MOD, descriptor, &ev) != 0) [[unlikely]]
|
||||
ERROR_LOG("epoll_ctl() for events 0x{:x} failed: {}", events, Error::CreateErrno(errno).GetDescription());
|
||||
#else
|
||||
std::unique_lock lock(m_poll_array_lock);
|
||||
size_t free_slot = m_poll_array_active_size;
|
||||
for (size_t i = 0; i < m_poll_array_active_size; i++)
|
||||
|
@ -440,10 +482,64 @@ void SocketMultiplexer::SetNotificationMask(BaseSocket* socket, SocketDescriptor
|
|||
|
||||
m_poll_array[free_slot] = {.fd = descriptor, .events = static_cast<short>(events), .revents = 0};
|
||||
m_poll_array_active_size = free_slot + 1;
|
||||
#endif
|
||||
}
|
||||
|
||||
bool SocketMultiplexer::PollEventsWithTimeout(u32 milliseconds)
|
||||
{
|
||||
#ifdef __linux__
|
||||
constexpr int MAX_EVENTS = 128;
|
||||
struct epoll_event events[MAX_EVENTS];
|
||||
|
||||
const int nevents = epoll_wait(m_epoll_fd, events, MAX_EVENTS, static_cast<int>(milliseconds));
|
||||
if (nevents <= 0)
|
||||
return false;
|
||||
|
||||
// find sockets that triggered, we use an array here so we can avoid holding the lock, and if a socket disconnects
|
||||
using PendingSocketPair = std::pair<std::shared_ptr<BaseSocket>, u32>;
|
||||
PendingSocketPair* triggered_sockets =
|
||||
reinterpret_cast<PendingSocketPair*>(alloca(sizeof(PendingSocketPair) * static_cast<size_t>(nevents)));
|
||||
size_t num_triggered_sockets = 0;
|
||||
{
|
||||
std::unique_lock open_lock(m_open_sockets_lock);
|
||||
for (int i = 0; i < nevents; i++)
|
||||
{
|
||||
const epoll_event& ev = events[i];
|
||||
const auto iter = m_open_sockets.find(ev.data.fd);
|
||||
if (iter == m_open_sockets.end()) [[unlikely]]
|
||||
{
|
||||
ERROR_LOG("Attempting to look up unknown socket {}, this should never happen.", ev.data.fd);
|
||||
continue;
|
||||
}
|
||||
|
||||
// we add a reference here in case the read kills it with a write pending, or something like that
|
||||
new (&triggered_sockets[num_triggered_sockets++]) PendingSocketPair(iter->second->shared_from_this(), ev.events);
|
||||
}
|
||||
}
|
||||
|
||||
// fire events
|
||||
for (size_t i = 0; i < num_triggered_sockets; i++)
|
||||
{
|
||||
PendingSocketPair& psp = triggered_sockets[i];
|
||||
|
||||
// fire events
|
||||
if (psp.second & (EPOLLRDHUP | EPOLLHUP | EPOLLERR))
|
||||
{
|
||||
psp.first->OnHangupEvent();
|
||||
}
|
||||
else
|
||||
{
|
||||
if (psp.second & EPOLLIN)
|
||||
psp.first->OnReadEvent();
|
||||
if (psp.second & EPOLLOUT)
|
||||
psp.first->OnWriteEvent();
|
||||
}
|
||||
|
||||
psp.first.~shared_ptr();
|
||||
}
|
||||
|
||||
return true;
|
||||
#else
|
||||
std::unique_lock lock(m_poll_array_lock);
|
||||
if (m_poll_array_active_size == 0)
|
||||
return false;
|
||||
|
@ -454,7 +550,8 @@ bool SocketMultiplexer::PollEventsWithTimeout(u32 milliseconds)
|
|||
|
||||
// find sockets that triggered, we use an array here so we can avoid holding the lock, and if a socket disconnects
|
||||
using PendingSocketPair = std::pair<std::shared_ptr<BaseSocket>, u32>;
|
||||
PendingSocketPair* triggered_sockets = reinterpret_cast<PendingSocketPair*>(alloca(sizeof(PendingSocketPair) * res));
|
||||
PendingSocketPair* triggered_sockets =
|
||||
reinterpret_cast<PendingSocketPair*>(alloca(sizeof(PendingSocketPair) * static_cast<size_t>(res)));
|
||||
size_t num_triggered_sockets = 0;
|
||||
{
|
||||
std::unique_lock open_lock(m_open_sockets_lock);
|
||||
|
@ -467,7 +564,7 @@ bool SocketMultiplexer::PollEventsWithTimeout(u32 milliseconds)
|
|||
const auto iter = m_open_sockets.find(pfd.fd);
|
||||
if (iter == m_open_sockets.end()) [[unlikely]]
|
||||
{
|
||||
ERROR_LOG("Attempting to look up known socket {}, this should never happen.", pfd.fd);
|
||||
ERROR_LOG("Attempting to look up unknown socket {}, this should never happen.", pfd.fd);
|
||||
continue;
|
||||
}
|
||||
|
||||
|
@ -481,7 +578,7 @@ bool SocketMultiplexer::PollEventsWithTimeout(u32 milliseconds)
|
|||
lock.unlock();
|
||||
|
||||
// fire events
|
||||
for (u32 i = 0; i < num_triggered_sockets; i++)
|
||||
for (size_t i = 0; i < num_triggered_sockets; i++)
|
||||
{
|
||||
PendingSocketPair& psp = triggered_sockets[i];
|
||||
|
||||
|
@ -502,6 +599,7 @@ bool SocketMultiplexer::PollEventsWithTimeout(u32 milliseconds)
|
|||
}
|
||||
|
||||
return true;
|
||||
#endif
|
||||
}
|
||||
|
||||
ListenSocket::ListenSocket(SocketMultiplexer& multiplexer, SocketDescriptor descriptor,
|
||||
|
|
|
@ -135,6 +135,9 @@ private:
|
|||
// Hide the constructor.
|
||||
SocketMultiplexer();
|
||||
|
||||
// Initialization.
|
||||
bool Initialize(Error* error);
|
||||
|
||||
// Tracking of open sockets.
|
||||
void AddOpenSocket(std::shared_ptr<BaseSocket> socket);
|
||||
void AddClientSocket(std::shared_ptr<BaseSocket> socket);
|
||||
|
@ -148,10 +151,14 @@ private:
|
|||
// We store the fd in the struct to avoid the cache miss reading the object.
|
||||
using SocketMap = std::unordered_map<SocketDescriptor, std::shared_ptr<BaseSocket>>;
|
||||
|
||||
#ifdef __linux__
|
||||
int m_epoll_fd = -1;
|
||||
#else
|
||||
std::mutex m_poll_array_lock;
|
||||
pollfd* m_poll_array = nullptr;
|
||||
size_t m_poll_array_active_size = 0;
|
||||
size_t m_poll_array_max_size = 0;
|
||||
#endif
|
||||
|
||||
std::mutex m_open_sockets_lock;
|
||||
SocketMap m_open_sockets;
|
||||
|
|
Loading…
Reference in New Issue