From 779f275486dc720858c7f91416803fe5f6b43e74 Mon Sep 17 00:00:00 2001 From: Ziek Date: Mon, 16 Feb 2015 22:53:50 -0800 Subject: [PATCH] Added TraversalServer.cpp to Core/Common --- Externals/enet/enet.vcxproj | 5 + Source/Core/Common/CMakeLists.txt | 1 + Source/Core/Common/TraversalClient.cpp | 12 +- Source/Core/Common/TraversalClient.h | 4 +- Source/Core/Common/TraversalServer.cpp | 458 +++++++++++++++++++++++++ Source/Core/DolphinWX/NetWindow.cpp | 4 +- 6 files changed, 472 insertions(+), 12 deletions(-) create mode 100644 Source/Core/Common/TraversalServer.cpp diff --git a/Externals/enet/enet.vcxproj b/Externals/enet/enet.vcxproj index b188b20267..f19b35a8a8 100644 --- a/Externals/enet/enet.vcxproj +++ b/Externals/enet/enet.vcxproj @@ -76,6 +76,11 @@ $(ExternalsDir)enet\include;%(AdditionalIncludeDirectories) + + + $(ExternalsDir)enet\include;%(AdditionalIncludeDirectories) + + diff --git a/Source/Core/Common/CMakeLists.txt b/Source/Core/Common/CMakeLists.txt index e4054d286e..bd6f21a81f 100644 --- a/Source/Core/Common/CMakeLists.txt +++ b/Source/Core/Common/CMakeLists.txt @@ -67,3 +67,4 @@ if(NOT APPLE AND NOT ANDROID) endif() add_dolphin_library(common "${SRCS}" "${LIBS}") +add_executable(traversal_server TraversalServer.cpp) diff --git a/Source/Core/Common/TraversalClient.cpp b/Source/Core/Common/TraversalClient.cpp index e9e6c5939b..574f002e4e 100644 --- a/Source/Core/Common/TraversalClient.cpp +++ b/Source/Core/Common/TraversalClient.cpp @@ -1,6 +1,6 @@ // This file is public domain, in case it's useful to anyone. -comex -#include "Timer.h" +#include "Common/Timer.h" #include "Common/TraversalClient.h" static void GetRandomishBytes(u8* buf, size_t size) @@ -207,7 +207,7 @@ void TraversalClient::HandleServerPacket(TraversalPacket* packet) } } -void TraversalClient::OnFailure(int reason) +void TraversalClient::OnFailure(FailureReason reason) { m_State = Failure; m_FailureReason = reason; @@ -216,8 +216,7 @@ void TraversalClient::OnFailure(int reason) { case TraversalClient::BadHost: { - auto server = "dolphin-emu.org"; - PanicAlertT("Couldn't look up central server %s", server); + PanicAlertT("Couldn't look up central server %s", m_Server.c_str()); break; } case TraversalClient::VersionTooOld: @@ -232,9 +231,6 @@ void TraversalClient::OnFailure(int reason) case TraversalClient::ResendTimeout: PanicAlertT("Timeout connecting to traversal server"); break; - default: - PanicAlertT("Unknown error %x", reason); - break; } if (m_Client) @@ -279,7 +275,7 @@ void TraversalClient::HandlePing() enet_uint32 now = enet_time_get(); if (m_State == Connected && now - m_PingTime >= 500) { - TraversalPacket ping = {0}; + TraversalPacket ping = {}; ping.type = TraversalPacketPing; ping.ping.hostId = m_HostId; SendTraversalPacket(ping); diff --git a/Source/Core/Common/TraversalClient.h b/Source/Core/Common/TraversalClient.h index 804f3226f8..7dbc92286e 100644 --- a/Source/Core/Common/TraversalClient.h +++ b/Source/Core/Common/TraversalClient.h @@ -4,10 +4,10 @@ #include #include #include +#include #include "Common/Common.h" #include "Common/Thread.h" #include "Common/TraversalProto.h" -#include "enet/include/enet/enet.h" class TraversalClientClient { @@ -62,7 +62,7 @@ private: void HandleServerPacket(TraversalPacket* packet); void ResendPacket(OutgoingTraversalPacketInfo* info); TraversalRequestId SendTraversalPacket(const TraversalPacket& packet); - void OnFailure(int reason); + void OnFailure(FailureReason reason); void HandlePing(); static int ENET_CALLBACK InterceptCallback(ENetHost* host, ENetEvent* event); TraversalRequestId m_ConnectRequestId; diff --git a/Source/Core/Common/TraversalServer.cpp b/Source/Core/Common/TraversalServer.cpp new file mode 100644 index 0000000000..777078bccb --- /dev/null +++ b/Source/Core/Common/TraversalServer.cpp @@ -0,0 +1,458 @@ +// This file is public domain, in case it's useful to anyone. -comex + +// The central server implementation. +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "Common/TraversalProto.h" + +#define DEBUG 1 +#define NUMBER_OF_TRIES 5 + +static u64 currentTime; + +struct OutgoingPacketInfo +{ + TraversalPacket packet; + TraversalRequestId misc; + sockaddr_in6 dest; + int tries; + u64 sendTime; +}; + +template +struct EvictEntry +{ + u64 updateTime; + T value; +}; + +template +struct EvictFindResult +{ + bool found; + V* value; +}; + +template +EvictFindResult EvictFind(std::unordered_map>& map, const K& key, bool refresh = false) +{ + retry: + const u64 expiryTime = 30 * 1000000; // 30s + EvictFindResult result; + if (map.bucket_count()) + { + auto bucket = map.bucket(key); + auto it = map.begin(bucket); + for (; it != map.end(bucket); ++it) + { + if (currentTime - it->second.updateTime > expiryTime) + { + map.erase(it->first); + goto retry; + } + if (it->first == key) + { + if (refresh) + it->second.updateTime = currentTime; + result.found = true; + result.value = &it->second.value; + return result; + } + } + } +#if DEBUG + printf("failed to find key '"); + for (size_t i = 0; i < sizeof(key); i++) { + printf("%02x", ((u8 *) &key)[i]); + } + printf("'\n"); +#endif + result.found = false; + return result; +} + +template +V* EvictSet(std::unordered_map>& map, const K& key) +{ + // can't use a local_iterator to emplace... + auto& result = map[key]; + result.updateTime = currentTime; + return &result.value; +} + +namespace std +{ + template <> + struct hash + { + size_t operator()(const TraversalHostId& id) const + { + auto p = (u32*) id.data(); + return p[0] ^ ((p[1] << 13) | (p[1] >> 19)); + } + }; +} + +static int sock; +static int urandomFd; +static std::unordered_map< + TraversalRequestId, + OutgoingPacketInfo +> outgoingPackets; +static std::unordered_map< + TraversalHostId, + EvictEntry +> connectedClients; + +static TraversalInetAddress MakeInetAddress(const sockaddr_in6& addr) +{ + if (addr.sin6_family != AF_INET6) + { + fprintf(stderr, "bad sockaddr_in6\n"); + exit(1); + } + u32* words = (u32*) addr.sin6_addr.s6_addr; + TraversalInetAddress result = {0}; + if (words[0] == 0 && words[1] == 0 && words[2] == 0xffff0000) + { + result.isIPV6 = false; + result.address[0] = words[3]; + } + else + { + result.isIPV6 = true; + memcpy(result.address, words, sizeof(result.address)); + } + result.port = addr.sin6_port; + return result; +} + +static sockaddr_in6 MakeSinAddr(const TraversalInetAddress& addr) +{ + sockaddr_in6 result; +#ifdef SIN6_LEN + result.sin6_len = sizeof(result); +#endif + result.sin6_family = AF_INET6; + result.sin6_port = addr.port; + result.sin6_flowinfo = 0; + if (addr.isIPV6) + { + memcpy(&result.sin6_addr, addr.address, 16); + } + else + { + u32* words = (u32*) result.sin6_addr.s6_addr; + words[0] = 0; + words[1] = 0; + words[2] = 0xffff0000; + words[3] = addr.address[0]; + } + result.sin6_scope_id = 0; + return result; +} + +static void GetRandomBytes(void* output, size_t size) +{ + static u8 bytes[8192]; + static size_t bytesLeft = 0; + if (bytesLeft < size) + { + ssize_t rv = read(urandomFd, bytes, sizeof(bytes)); + if (rv != sizeof(bytes)) + { + perror("read from /dev/urandom"); + exit(1); + } + bytesLeft = sizeof(bytes); + } + memcpy(output, bytes + (bytesLeft -= size), size); +} + +static void GetRandomHostId(TraversalHostId* hostId) +{ + char buf[9]; + u32 num; + GetRandomBytes(&num, sizeof(num)); + sprintf(buf, "%08x", num); + memcpy(hostId->data(), buf, 8); +} + +static const char* SenderName(sockaddr_in6* addr) +{ + static char buf[INET6_ADDRSTRLEN + 10]; + inet_ntop(PF_INET6, &addr->sin6_addr, buf, sizeof(buf)); + sprintf(buf + strlen(buf), ":%d", ntohs(addr->sin6_port)); + return buf; +} + +static void TrySend(const void* buffer, size_t size, sockaddr_in6* addr) +{ +#if DEBUG + printf("-> %d %lu %s\n", ((TraversalPacket*) buffer)->type, ((TraversalPacket*) buffer)->requestId, SenderName(addr)); +#endif + if ((size_t) sendto(sock, buffer, size, 0, (sockaddr*) addr, sizeof(*addr)) != size) + { + perror("sendto"); + } +} + +static TraversalPacket* AllocPacket(const sockaddr_in6& dest, TraversalRequestId misc = 0) +{ + TraversalRequestId requestId; + GetRandomBytes(&requestId, sizeof(requestId)); + OutgoingPacketInfo* info = &outgoingPackets[requestId]; + info->dest = dest; + info->misc = misc; + info->tries = 0; + info->sendTime = currentTime; + TraversalPacket* result = &info->packet; + memset(result, 0, sizeof(*result)); + result->requestId = requestId; + return result; +} + +static void SendPacket(OutgoingPacketInfo* info) +{ + info->tries++; + info->sendTime = currentTime; + TrySend(&info->packet, sizeof(info->packet), &info->dest); +} + + +static void ResendPackets() +{ + std::vector> todoFailures; + todoFailures.clear(); + for (auto it = outgoingPackets.begin(); it != outgoingPackets.end();) + { + OutgoingPacketInfo* info = &it->second; + if (currentTime - info->sendTime >= (u64) (300000 * info->tries)) + { + if (info->tries >= NUMBER_OF_TRIES) + { + if (info->packet.type == TraversalPacketPleaseSendPacket) + { + todoFailures.push_back(std::make_pair(info->packet.pleaseSendPacket.address, info->misc)); + } + it = outgoingPackets.erase(it); + continue; + } + else + { + SendPacket(info); + } + } + ++it; + } + + for (const auto& p : todoFailures) + { + TraversalPacket* fail = AllocPacket(MakeSinAddr(p.first)); + fail->type = TraversalPacketConnectFailed; + fail->connectFailed.requestId = p.second; + fail->connectFailed.reason = TraversalConnectFailedClientDidntRespond; + } +} + +static void HandlePacket(TraversalPacket* packet, sockaddr_in6* addr) +{ +#if DEBUG + printf("<- %d %lu %s\n", packet->type, packet->requestId, SenderName(addr)); +#endif + bool packetOk = true; + switch (packet->type) + { + case TraversalPacketAck: + { + auto it = outgoingPackets.find(packet->requestId); + if (it == outgoingPackets.end()) + break; + + OutgoingPacketInfo* info = &it->second; + + if (info->packet.type == TraversalPacketPleaseSendPacket) + { + TraversalPacket* ready = AllocPacket(MakeSinAddr(info->packet.pleaseSendPacket.address)); + if (packet->ack.ok) + { + ready->type = TraversalPacketConnectReady; + ready->connectReady.requestId = info->misc; + ready->connectReady.address = MakeInetAddress(info->dest); + } + else + { + ready->type = TraversalPacketConnectFailed; + ready->connectFailed.requestId = info->misc; + ready->connectFailed.reason = TraversalConnectFailedClientFailure; + } + } + + outgoingPackets.erase(it); + break; + } + case TraversalPacketPing: + { + auto r = EvictFind(connectedClients, packet->ping.hostId, true); + packetOk = r.found; + break; + } + case TraversalPacketHelloFromClient: + { + u8 ok = packet->helloFromClient.protoVersion <= TraversalProtoVersion; + TraversalPacket* reply = AllocPacket(*addr); + reply->type = TraversalPacketHelloFromServer; + reply->helloFromServer.ok = ok; + if (ok) + { + TraversalHostId hostId; + TraversalInetAddress* iaddr; + // not that there is any significant change of + // duplication, but... + GetRandomHostId(&hostId); + while (true) + { + auto r = EvictFind(connectedClients, hostId); + if (!r.found) + { + iaddr = EvictSet(connectedClients, hostId); + break; + } + } + + *iaddr = MakeInetAddress(*addr); + + reply->helloFromServer.yourAddress = *iaddr; + reply->helloFromServer.yourHostId = hostId; + } + break; + } + case TraversalPacketConnectPlease: + { + TraversalHostId& hostId = packet->connectPlease.hostId; + auto r = EvictFind(connectedClients, hostId); + if (!r.found) + { + TraversalPacket* reply = AllocPacket(*addr); + reply->type = TraversalPacketConnectFailed; + reply->connectFailed.requestId = packet->requestId; + reply->connectFailed.reason = TraversalConnectFailedNoSuchClient; + } + else + { + TraversalPacket* please = AllocPacket(MakeSinAddr(*r.value), packet->requestId); + please->type = TraversalPacketPleaseSendPacket; + please->pleaseSendPacket.address = MakeInetAddress(*addr); + } + break; + } + default: + fprintf(stderr, "received unknown packet type %d from %s\n", packet->type, SenderName(addr)); + } + if (packet->type != TraversalPacketAck) + { + TraversalPacket ack = {}; + ack.type = TraversalPacketAck; + ack.requestId = packet->requestId; + ack.ack.ok = packetOk; + TrySend(&ack, sizeof(ack), addr); + } +} + +int main() +{ + int rv; + + urandomFd = open("/dev/urandom", O_RDONLY); + if (urandomFd < 0) + { + perror("open /dev/urandom"); + return 1; + } + + sock = socket(PF_INET6, SOCK_DGRAM, 0); + if (sock == -1) + { + perror("socket"); + return 1; + } + int no = 0; + rv = setsockopt(sock, IPPROTO_IPV6, IPV6_V6ONLY, &no, sizeof(no)); + if (rv < 0) + { + perror("setsockopt IPV6_V6ONLY"); + return 1; + } + in6_addr any = IN6ADDR_ANY_INIT; + sockaddr_in6 addr; +#ifdef SIN6_LEN + addr.sin6_len = sizeof(addr); +#endif + addr.sin6_family = AF_INET6; + addr.sin6_port = htons(6262); + addr.sin6_flowinfo = 0; + addr.sin6_addr = any; + addr.sin6_scope_id = 0; + + rv = bind(sock, (sockaddr*) &addr, sizeof(addr)); + if (rv < 0) + { + perror("bind"); + return 1; + } + + timeval tv; + tv.tv_sec = 0; + tv.tv_usec = 300000; + rv = setsockopt(sock, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv)); + if (rv < 0) + { + perror("setsockopt SO_RCVTIMEO"); + return 1; + } + + while (true) + { + sockaddr_in6 raddr; + socklen_t addrLen = sizeof(raddr); + TraversalPacket packet; + // note: switch to recvmmsg (yes, mmsg) if this becomes + // expensive + rv = recvfrom(sock, &packet, sizeof(packet), 0, (sockaddr*) &raddr, &addrLen); + if (gettimeofday(&tv, NULL) < 0) + { + perror("gettimeofday"); + exit(1); + } + currentTime = (u64) tv.tv_sec * 1000000 + tv.tv_usec; + if (rv < 0) + { + if (errno != EINTR && errno != EAGAIN) + { + perror("recvfrom"); + return 1; + } + } + else if ((size_t) rv < sizeof(packet)) + { + fprintf(stderr, "received short packet from %s\n", SenderName(&raddr)); + } + else + { + HandlePacket(&packet, &raddr); + } + ResendPackets(); + } +} diff --git a/Source/Core/DolphinWX/NetWindow.cpp b/Source/Core/DolphinWX/NetWindow.cpp index 0768b615ed..02e16cda3b 100644 --- a/Source/Core/DolphinWX/NetWindow.cpp +++ b/Source/Core/DolphinWX/NetWindow.cpp @@ -135,14 +135,14 @@ NetPlaySetupDiag::NetPlaySetupDiag(wxWindow* const parent, const CGameListCtrl* nick_szr->Add(m_nickname_text, 0, wxALL, 5); std::string centralServer; - netplay_section.Get("TraversalServer", ¢ralServer, "vps.qoid.us"); + netplay_section.Get("TraversalServer", ¢ralServer, ""); m_traversal_server_lbl = new wxStaticText(panel, wxID_ANY, _("Traversal:")); m_traversal_server = new wxTextCtrl(panel, wxID_ANY, StrToWxStr(centralServer)); nick_szr->Add(m_traversal_server_lbl, 0, wxCENTER); nick_szr->Add(m_traversal_server, 0, wxALL, 5); std::string centralPort; - netplay_section.Get("TraversalPort", ¢ralPort, "6262"); + netplay_section.Get("TraversalPort", ¢ralPort, ""); m_traversal_port_lbl = new wxStaticText(panel, wxID_ANY, _("Port:")); m_traversal_port = new wxTextCtrl(panel, wxID_ANY, StrToWxStr(centralPort)); nick_szr->Add(m_traversal_port_lbl, 0, wxCENTER);