Sockets: Use epoll on Linux
This commit is contained in:
parent
78800870bc
commit
ad374ef5e2
|
@ -42,6 +42,10 @@ using nfds_t = ULONG;
|
||||||
#include <sys/un.h>
|
#include <sys/un.h>
|
||||||
#include <unistd.h>
|
#include <unistd.h>
|
||||||
|
|
||||||
|
#ifdef __linux__
|
||||||
|
#include <sys/epoll.h>
|
||||||
|
#endif
|
||||||
|
|
||||||
#define ioctlsocket ioctl
|
#define ioctlsocket ioctl
|
||||||
#define closesocket close
|
#define closesocket close
|
||||||
#define WSAEWOULDBLOCK EAGAIN
|
#define WSAEWOULDBLOCK EAGAIN
|
||||||
|
@ -227,16 +231,42 @@ SocketMultiplexer::~SocketMultiplexer()
|
||||||
{
|
{
|
||||||
CloseAll();
|
CloseAll();
|
||||||
|
|
||||||
|
#ifdef __linux__
|
||||||
|
if (m_epoll_fd >= 0)
|
||||||
|
close(m_epoll_fd);
|
||||||
|
#else
|
||||||
if (m_poll_array)
|
if (m_poll_array)
|
||||||
std::free(m_poll_array);
|
std::free(m_poll_array);
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
std::unique_ptr<SocketMultiplexer> SocketMultiplexer::Create(Error* error)
|
std::unique_ptr<SocketMultiplexer> SocketMultiplexer::Create(Error* error)
|
||||||
{
|
{
|
||||||
if (!PlatformMisc::InitializeSocketSupport(error))
|
std::unique_ptr<SocketMultiplexer> ret;
|
||||||
return {};
|
if (PlatformMisc::InitializeSocketSupport(error))
|
||||||
|
{
|
||||||
|
ret = std::unique_ptr<SocketMultiplexer>(new SocketMultiplexer());
|
||||||
|
if (!ret->Initialize(error))
|
||||||
|
ret.reset();
|
||||||
|
}
|
||||||
|
|
||||||
return std::unique_ptr<SocketMultiplexer>(new SocketMultiplexer());
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool SocketMultiplexer::Initialize(Error* error)
|
||||||
|
{
|
||||||
|
#ifdef __linux__
|
||||||
|
m_epoll_fd = epoll_create1(0);
|
||||||
|
if (m_epoll_fd < 0)
|
||||||
|
{
|
||||||
|
Error::SetErrno(error, "epoll_create1() failed: ", errno);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
return true;
|
||||||
|
#else
|
||||||
|
return true;
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
std::shared_ptr<ListenSocket> SocketMultiplexer::InternalCreateListenSocket(const SocketAddress& address,
|
std::shared_ptr<ListenSocket> SocketMultiplexer::InternalCreateListenSocket(const SocketAddress& address,
|
||||||
|
@ -325,8 +355,13 @@ std::shared_ptr<StreamSocket> SocketMultiplexer::InternalConnectStreamSocket(con
|
||||||
|
|
||||||
void SocketMultiplexer::AddOpenSocket(std::shared_ptr<BaseSocket> socket)
|
void SocketMultiplexer::AddOpenSocket(std::shared_ptr<BaseSocket> socket)
|
||||||
{
|
{
|
||||||
std::unique_lock lock(m_open_sockets_lock);
|
#ifdef __linux__
|
||||||
|
struct epoll_event ev = {.events = 0u, .data = {.fd = socket->GetDescriptor()}};
|
||||||
|
if (epoll_ctl(m_epoll_fd, EPOLL_CTL_ADD, socket->GetDescriptor(), &ev) != 0) [[unlikely]]
|
||||||
|
ERROR_LOG("epoll_ctl() to add socket failed: {}", Error::CreateErrno(errno).GetDescription());
|
||||||
|
#endif
|
||||||
|
|
||||||
|
std::unique_lock lock(m_open_sockets_lock);
|
||||||
DebugAssert(m_open_sockets.find(socket->GetDescriptor()) == m_open_sockets.end());
|
DebugAssert(m_open_sockets.find(socket->GetDescriptor()) == m_open_sockets.end());
|
||||||
m_open_sockets.emplace(socket->GetDescriptor(), std::move(socket));
|
m_open_sockets.emplace(socket->GetDescriptor(), std::move(socket));
|
||||||
}
|
}
|
||||||
|
@ -339,27 +374,29 @@ void SocketMultiplexer::AddClientSocket(std::shared_ptr<BaseSocket> socket)
|
||||||
|
|
||||||
void SocketMultiplexer::RemoveOpenSocket(BaseSocket* socket)
|
void SocketMultiplexer::RemoveOpenSocket(BaseSocket* socket)
|
||||||
{
|
{
|
||||||
#ifdef _DEBUG
|
|
||||||
{
|
|
||||||
std::unique_lock lock(m_poll_array_lock);
|
|
||||||
for (size_t i = 0; i < m_poll_array_active_size; i++)
|
|
||||||
{
|
|
||||||
pollfd& pfd = m_poll_array[i];
|
|
||||||
DebugAssert(pfd.fd != socket->GetDescriptor());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
#endif
|
|
||||||
|
|
||||||
std::unique_lock lock(m_open_sockets_lock);
|
std::unique_lock lock(m_open_sockets_lock);
|
||||||
const auto iter = m_open_sockets.find(socket->GetDescriptor());
|
const auto iter = m_open_sockets.find(socket->GetDescriptor());
|
||||||
Assert(iter != m_open_sockets.end());
|
Assert(iter != m_open_sockets.end());
|
||||||
m_open_sockets.erase(iter);
|
m_open_sockets.erase(iter);
|
||||||
|
|
||||||
|
#ifdef __linux__
|
||||||
|
if (epoll_ctl(m_epoll_fd, EPOLL_CTL_DEL, socket->GetDescriptor(), nullptr) != 0) [[unlikely]]
|
||||||
|
ERROR_LOG("epoll_ctl() to remove socket failed: {}", Error::CreateErrno(errno).GetDescription());
|
||||||
|
#else
|
||||||
|
#ifdef _DEBUG
|
||||||
|
for (size_t i = 0; i < m_poll_array_active_size; i++)
|
||||||
|
{
|
||||||
|
pollfd& pfd = m_poll_array[i];
|
||||||
|
DebugAssert(pfd.fd != socket->GetDescriptor());
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
// Update size.
|
// Update size.
|
||||||
size_t new_active_size = 0;
|
size_t new_active_size = 0;
|
||||||
for (size_t i = 0; i < m_poll_array_active_size; i++)
|
for (size_t i = 0; i < m_poll_array_active_size; i++)
|
||||||
new_active_size = (m_poll_array[i].fd != INVALID_SOCKET) ? (i + 1) : new_active_size;
|
new_active_size = (m_poll_array[i].fd != INVALID_SOCKET) ? (i + 1) : new_active_size;
|
||||||
m_poll_array_active_size = new_active_size;
|
m_poll_array_active_size = new_active_size;
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
void SocketMultiplexer::RemoveClientSocket(BaseSocket* socket)
|
void SocketMultiplexer::RemoveClientSocket(BaseSocket* socket)
|
||||||
|
@ -400,6 +437,11 @@ void SocketMultiplexer::CloseAll()
|
||||||
|
|
||||||
void SocketMultiplexer::SetNotificationMask(BaseSocket* socket, SocketDescriptor descriptor, u32 events)
|
void SocketMultiplexer::SetNotificationMask(BaseSocket* socket, SocketDescriptor descriptor, u32 events)
|
||||||
{
|
{
|
||||||
|
#ifdef __linux__
|
||||||
|
struct epoll_event ev = {.events = events, .data = {.fd = descriptor}};
|
||||||
|
if (epoll_ctl(m_epoll_fd, EPOLL_CTL_MOD, descriptor, &ev) != 0) [[unlikely]]
|
||||||
|
ERROR_LOG("epoll_ctl() for events 0x{:x} failed: {}", events, Error::CreateErrno(errno).GetDescription());
|
||||||
|
#else
|
||||||
std::unique_lock lock(m_poll_array_lock);
|
std::unique_lock lock(m_poll_array_lock);
|
||||||
size_t free_slot = m_poll_array_active_size;
|
size_t free_slot = m_poll_array_active_size;
|
||||||
for (size_t i = 0; i < m_poll_array_active_size; i++)
|
for (size_t i = 0; i < m_poll_array_active_size; i++)
|
||||||
|
@ -440,10 +482,64 @@ void SocketMultiplexer::SetNotificationMask(BaseSocket* socket, SocketDescriptor
|
||||||
|
|
||||||
m_poll_array[free_slot] = {.fd = descriptor, .events = static_cast<short>(events), .revents = 0};
|
m_poll_array[free_slot] = {.fd = descriptor, .events = static_cast<short>(events), .revents = 0};
|
||||||
m_poll_array_active_size = free_slot + 1;
|
m_poll_array_active_size = free_slot + 1;
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
bool SocketMultiplexer::PollEventsWithTimeout(u32 milliseconds)
|
bool SocketMultiplexer::PollEventsWithTimeout(u32 milliseconds)
|
||||||
{
|
{
|
||||||
|
#ifdef __linux__
|
||||||
|
constexpr int MAX_EVENTS = 128;
|
||||||
|
struct epoll_event events[MAX_EVENTS];
|
||||||
|
|
||||||
|
const int nevents = epoll_wait(m_epoll_fd, events, MAX_EVENTS, static_cast<int>(milliseconds));
|
||||||
|
if (nevents <= 0)
|
||||||
|
return false;
|
||||||
|
|
||||||
|
// find sockets that triggered, we use an array here so we can avoid holding the lock, and if a socket disconnects
|
||||||
|
using PendingSocketPair = std::pair<std::shared_ptr<BaseSocket>, u32>;
|
||||||
|
PendingSocketPair* triggered_sockets =
|
||||||
|
reinterpret_cast<PendingSocketPair*>(alloca(sizeof(PendingSocketPair) * static_cast<size_t>(nevents)));
|
||||||
|
size_t num_triggered_sockets = 0;
|
||||||
|
{
|
||||||
|
std::unique_lock open_lock(m_open_sockets_lock);
|
||||||
|
for (int i = 0; i < nevents; i++)
|
||||||
|
{
|
||||||
|
const epoll_event& ev = events[i];
|
||||||
|
const auto iter = m_open_sockets.find(ev.data.fd);
|
||||||
|
if (iter == m_open_sockets.end()) [[unlikely]]
|
||||||
|
{
|
||||||
|
ERROR_LOG("Attempting to look up unknown socket {}, this should never happen.", ev.data.fd);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// we add a reference here in case the read kills it with a write pending, or something like that
|
||||||
|
new (&triggered_sockets[num_triggered_sockets++]) PendingSocketPair(iter->second->shared_from_this(), ev.events);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// fire events
|
||||||
|
for (size_t i = 0; i < num_triggered_sockets; i++)
|
||||||
|
{
|
||||||
|
PendingSocketPair& psp = triggered_sockets[i];
|
||||||
|
|
||||||
|
// fire events
|
||||||
|
if (psp.second & (EPOLLRDHUP | EPOLLHUP | EPOLLERR))
|
||||||
|
{
|
||||||
|
psp.first->OnHangupEvent();
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
if (psp.second & EPOLLIN)
|
||||||
|
psp.first->OnReadEvent();
|
||||||
|
if (psp.second & EPOLLOUT)
|
||||||
|
psp.first->OnWriteEvent();
|
||||||
|
}
|
||||||
|
|
||||||
|
psp.first.~shared_ptr();
|
||||||
|
}
|
||||||
|
|
||||||
|
return true;
|
||||||
|
#else
|
||||||
std::unique_lock lock(m_poll_array_lock);
|
std::unique_lock lock(m_poll_array_lock);
|
||||||
if (m_poll_array_active_size == 0)
|
if (m_poll_array_active_size == 0)
|
||||||
return false;
|
return false;
|
||||||
|
@ -454,7 +550,8 @@ bool SocketMultiplexer::PollEventsWithTimeout(u32 milliseconds)
|
||||||
|
|
||||||
// find sockets that triggered, we use an array here so we can avoid holding the lock, and if a socket disconnects
|
// find sockets that triggered, we use an array here so we can avoid holding the lock, and if a socket disconnects
|
||||||
using PendingSocketPair = std::pair<std::shared_ptr<BaseSocket>, u32>;
|
using PendingSocketPair = std::pair<std::shared_ptr<BaseSocket>, u32>;
|
||||||
PendingSocketPair* triggered_sockets = reinterpret_cast<PendingSocketPair*>(alloca(sizeof(PendingSocketPair) * res));
|
PendingSocketPair* triggered_sockets =
|
||||||
|
reinterpret_cast<PendingSocketPair*>(alloca(sizeof(PendingSocketPair) * static_cast<size_t>(res)));
|
||||||
size_t num_triggered_sockets = 0;
|
size_t num_triggered_sockets = 0;
|
||||||
{
|
{
|
||||||
std::unique_lock open_lock(m_open_sockets_lock);
|
std::unique_lock open_lock(m_open_sockets_lock);
|
||||||
|
@ -467,7 +564,7 @@ bool SocketMultiplexer::PollEventsWithTimeout(u32 milliseconds)
|
||||||
const auto iter = m_open_sockets.find(pfd.fd);
|
const auto iter = m_open_sockets.find(pfd.fd);
|
||||||
if (iter == m_open_sockets.end()) [[unlikely]]
|
if (iter == m_open_sockets.end()) [[unlikely]]
|
||||||
{
|
{
|
||||||
ERROR_LOG("Attempting to look up known socket {}, this should never happen.", pfd.fd);
|
ERROR_LOG("Attempting to look up unknown socket {}, this should never happen.", pfd.fd);
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -481,7 +578,7 @@ bool SocketMultiplexer::PollEventsWithTimeout(u32 milliseconds)
|
||||||
lock.unlock();
|
lock.unlock();
|
||||||
|
|
||||||
// fire events
|
// fire events
|
||||||
for (u32 i = 0; i < num_triggered_sockets; i++)
|
for (size_t i = 0; i < num_triggered_sockets; i++)
|
||||||
{
|
{
|
||||||
PendingSocketPair& psp = triggered_sockets[i];
|
PendingSocketPair& psp = triggered_sockets[i];
|
||||||
|
|
||||||
|
@ -502,6 +599,7 @@ bool SocketMultiplexer::PollEventsWithTimeout(u32 milliseconds)
|
||||||
}
|
}
|
||||||
|
|
||||||
return true;
|
return true;
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
ListenSocket::ListenSocket(SocketMultiplexer& multiplexer, SocketDescriptor descriptor,
|
ListenSocket::ListenSocket(SocketMultiplexer& multiplexer, SocketDescriptor descriptor,
|
||||||
|
|
|
@ -135,6 +135,9 @@ private:
|
||||||
// Hide the constructor.
|
// Hide the constructor.
|
||||||
SocketMultiplexer();
|
SocketMultiplexer();
|
||||||
|
|
||||||
|
// Initialization.
|
||||||
|
bool Initialize(Error* error);
|
||||||
|
|
||||||
// Tracking of open sockets.
|
// Tracking of open sockets.
|
||||||
void AddOpenSocket(std::shared_ptr<BaseSocket> socket);
|
void AddOpenSocket(std::shared_ptr<BaseSocket> socket);
|
||||||
void AddClientSocket(std::shared_ptr<BaseSocket> socket);
|
void AddClientSocket(std::shared_ptr<BaseSocket> socket);
|
||||||
|
@ -148,10 +151,14 @@ private:
|
||||||
// We store the fd in the struct to avoid the cache miss reading the object.
|
// We store the fd in the struct to avoid the cache miss reading the object.
|
||||||
using SocketMap = std::unordered_map<SocketDescriptor, std::shared_ptr<BaseSocket>>;
|
using SocketMap = std::unordered_map<SocketDescriptor, std::shared_ptr<BaseSocket>>;
|
||||||
|
|
||||||
|
#ifdef __linux__
|
||||||
|
int m_epoll_fd = -1;
|
||||||
|
#else
|
||||||
std::mutex m_poll_array_lock;
|
std::mutex m_poll_array_lock;
|
||||||
pollfd* m_poll_array = nullptr;
|
pollfd* m_poll_array = nullptr;
|
||||||
size_t m_poll_array_active_size = 0;
|
size_t m_poll_array_active_size = 0;
|
||||||
size_t m_poll_array_max_size = 0;
|
size_t m_poll_array_max_size = 0;
|
||||||
|
#endif
|
||||||
|
|
||||||
std::mutex m_open_sockets_lock;
|
std::mutex m_open_sockets_lock;
|
||||||
SocketMap m_open_sockets;
|
SocketMap m_open_sockets;
|
||||||
|
|
Loading…
Reference in New Issue