Added TraversalServer.cpp to Core/Common

This commit is contained in:
Ziek 2015-02-16 22:53:50 -08:00
parent 619a3a5171
commit 779f275486
6 changed files with 472 additions and 12 deletions

View File

@ -76,6 +76,11 @@
<AdditionalIncludeDirectories>$(ExternalsDir)enet\include;%(AdditionalIncludeDirectories)</AdditionalIncludeDirectories> <AdditionalIncludeDirectories>$(ExternalsDir)enet\include;%(AdditionalIncludeDirectories)</AdditionalIncludeDirectories>
</ClCompile> </ClCompile>
</ItemDefinitionGroup> </ItemDefinitionGroup>
<ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">
<ClCompile>
<AdditionalIncludeDirectories>$(ExternalsDir)enet\include;%(AdditionalIncludeDirectories)</AdditionalIncludeDirectories>
</ClCompile>
</ItemDefinitionGroup>
<Import Project="$(VCTargetsPath)\Microsoft.Cpp.targets" /> <Import Project="$(VCTargetsPath)\Microsoft.Cpp.targets" />
<ImportGroup Label="ExtensionTargets"> <ImportGroup Label="ExtensionTargets">
</ImportGroup> </ImportGroup>

View File

@ -67,3 +67,4 @@ if(NOT APPLE AND NOT ANDROID)
endif() endif()
add_dolphin_library(common "${SRCS}" "${LIBS}") add_dolphin_library(common "${SRCS}" "${LIBS}")
add_executable(traversal_server TraversalServer.cpp)

View File

@ -1,6 +1,6 @@
// This file is public domain, in case it's useful to anyone. -comex // This file is public domain, in case it's useful to anyone. -comex
#include "Timer.h" #include "Common/Timer.h"
#include "Common/TraversalClient.h" #include "Common/TraversalClient.h"
static void GetRandomishBytes(u8* buf, size_t size) static void GetRandomishBytes(u8* buf, size_t size)
@ -207,7 +207,7 @@ void TraversalClient::HandleServerPacket(TraversalPacket* packet)
} }
} }
void TraversalClient::OnFailure(int reason) void TraversalClient::OnFailure(FailureReason reason)
{ {
m_State = Failure; m_State = Failure;
m_FailureReason = reason; m_FailureReason = reason;
@ -216,8 +216,7 @@ void TraversalClient::OnFailure(int reason)
{ {
case TraversalClient::BadHost: case TraversalClient::BadHost:
{ {
auto server = "dolphin-emu.org"; PanicAlertT("Couldn't look up central server %s", m_Server.c_str());
PanicAlertT("Couldn't look up central server %s", server);
break; break;
} }
case TraversalClient::VersionTooOld: case TraversalClient::VersionTooOld:
@ -232,9 +231,6 @@ void TraversalClient::OnFailure(int reason)
case TraversalClient::ResendTimeout: case TraversalClient::ResendTimeout:
PanicAlertT("Timeout connecting to traversal server"); PanicAlertT("Timeout connecting to traversal server");
break; break;
default:
PanicAlertT("Unknown error %x", reason);
break;
} }
if (m_Client) if (m_Client)
@ -279,7 +275,7 @@ void TraversalClient::HandlePing()
enet_uint32 now = enet_time_get(); enet_uint32 now = enet_time_get();
if (m_State == Connected && now - m_PingTime >= 500) if (m_State == Connected && now - m_PingTime >= 500)
{ {
TraversalPacket ping = {0}; TraversalPacket ping = {};
ping.type = TraversalPacketPing; ping.type = TraversalPacketPing;
ping.ping.hostId = m_HostId; ping.ping.hostId = m_HostId;
SendTraversalPacket(ping); SendTraversalPacket(ping);

View File

@ -4,10 +4,10 @@
#include <functional> #include <functional>
#include <list> #include <list>
#include <memory> #include <memory>
#include <enet/include/enet/enet.h>
#include "Common/Common.h" #include "Common/Common.h"
#include "Common/Thread.h" #include "Common/Thread.h"
#include "Common/TraversalProto.h" #include "Common/TraversalProto.h"
#include "enet/include/enet/enet.h"
class TraversalClientClient class TraversalClientClient
{ {
@ -62,7 +62,7 @@ private:
void HandleServerPacket(TraversalPacket* packet); void HandleServerPacket(TraversalPacket* packet);
void ResendPacket(OutgoingTraversalPacketInfo* info); void ResendPacket(OutgoingTraversalPacketInfo* info);
TraversalRequestId SendTraversalPacket(const TraversalPacket& packet); TraversalRequestId SendTraversalPacket(const TraversalPacket& packet);
void OnFailure(int reason); void OnFailure(FailureReason reason);
void HandlePing(); void HandlePing();
static int ENET_CALLBACK InterceptCallback(ENetHost* host, ENetEvent* event); static int ENET_CALLBACK InterceptCallback(ENetHost* host, ENetEvent* event);
TraversalRequestId m_ConnectRequestId; TraversalRequestId m_ConnectRequestId;

View File

@ -0,0 +1,458 @@
// This file is public domain, in case it's useful to anyone. -comex
// The central server implementation.
#include <cerrno>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <fcntl.h>
#include <unistd.h>
#include <unordered_map>
#include <utility>
#include <vector>
#include <arpa/inet.h>
#include <netinet/in.h>
#include <sys/socket.h>
#include <sys/time.h>
#include <sys/types.h>
#include "Common/TraversalProto.h"
#define DEBUG 1
#define NUMBER_OF_TRIES 5
static u64 currentTime;
struct OutgoingPacketInfo
{
TraversalPacket packet;
TraversalRequestId misc;
sockaddr_in6 dest;
int tries;
u64 sendTime;
};
template <typename T>
struct EvictEntry
{
u64 updateTime;
T value;
};
template <typename V>
struct EvictFindResult
{
bool found;
V* value;
};
template <typename K, typename V>
EvictFindResult<V> EvictFind(std::unordered_map<K, EvictEntry<V>>& map, const K& key, bool refresh = false)
{
retry:
const u64 expiryTime = 30 * 1000000; // 30s
EvictFindResult<V> result;
if (map.bucket_count())
{
auto bucket = map.bucket(key);
auto it = map.begin(bucket);
for (; it != map.end(bucket); ++it)
{
if (currentTime - it->second.updateTime > expiryTime)
{
map.erase(it->first);
goto retry;
}
if (it->first == key)
{
if (refresh)
it->second.updateTime = currentTime;
result.found = true;
result.value = &it->second.value;
return result;
}
}
}
#if DEBUG
printf("failed to find key '");
for (size_t i = 0; i < sizeof(key); i++) {
printf("%02x", ((u8 *) &key)[i]);
}
printf("'\n");
#endif
result.found = false;
return result;
}
template <typename K, typename V>
V* EvictSet(std::unordered_map<K, EvictEntry<V>>& map, const K& key)
{
// can't use a local_iterator to emplace...
auto& result = map[key];
result.updateTime = currentTime;
return &result.value;
}
namespace std
{
template <>
struct hash<TraversalHostId>
{
size_t operator()(const TraversalHostId& id) const
{
auto p = (u32*) id.data();
return p[0] ^ ((p[1] << 13) | (p[1] >> 19));
}
};
}
static int sock;
static int urandomFd;
static std::unordered_map<
TraversalRequestId,
OutgoingPacketInfo
> outgoingPackets;
static std::unordered_map<
TraversalHostId,
EvictEntry<TraversalInetAddress>
> connectedClients;
static TraversalInetAddress MakeInetAddress(const sockaddr_in6& addr)
{
if (addr.sin6_family != AF_INET6)
{
fprintf(stderr, "bad sockaddr_in6\n");
exit(1);
}
u32* words = (u32*) addr.sin6_addr.s6_addr;
TraversalInetAddress result = {0};
if (words[0] == 0 && words[1] == 0 && words[2] == 0xffff0000)
{
result.isIPV6 = false;
result.address[0] = words[3];
}
else
{
result.isIPV6 = true;
memcpy(result.address, words, sizeof(result.address));
}
result.port = addr.sin6_port;
return result;
}
static sockaddr_in6 MakeSinAddr(const TraversalInetAddress& addr)
{
sockaddr_in6 result;
#ifdef SIN6_LEN
result.sin6_len = sizeof(result);
#endif
result.sin6_family = AF_INET6;
result.sin6_port = addr.port;
result.sin6_flowinfo = 0;
if (addr.isIPV6)
{
memcpy(&result.sin6_addr, addr.address, 16);
}
else
{
u32* words = (u32*) result.sin6_addr.s6_addr;
words[0] = 0;
words[1] = 0;
words[2] = 0xffff0000;
words[3] = addr.address[0];
}
result.sin6_scope_id = 0;
return result;
}
static void GetRandomBytes(void* output, size_t size)
{
static u8 bytes[8192];
static size_t bytesLeft = 0;
if (bytesLeft < size)
{
ssize_t rv = read(urandomFd, bytes, sizeof(bytes));
if (rv != sizeof(bytes))
{
perror("read from /dev/urandom");
exit(1);
}
bytesLeft = sizeof(bytes);
}
memcpy(output, bytes + (bytesLeft -= size), size);
}
static void GetRandomHostId(TraversalHostId* hostId)
{
char buf[9];
u32 num;
GetRandomBytes(&num, sizeof(num));
sprintf(buf, "%08x", num);
memcpy(hostId->data(), buf, 8);
}
static const char* SenderName(sockaddr_in6* addr)
{
static char buf[INET6_ADDRSTRLEN + 10];
inet_ntop(PF_INET6, &addr->sin6_addr, buf, sizeof(buf));
sprintf(buf + strlen(buf), ":%d", ntohs(addr->sin6_port));
return buf;
}
static void TrySend(const void* buffer, size_t size, sockaddr_in6* addr)
{
#if DEBUG
printf("-> %d %lu %s\n", ((TraversalPacket*) buffer)->type, ((TraversalPacket*) buffer)->requestId, SenderName(addr));
#endif
if ((size_t) sendto(sock, buffer, size, 0, (sockaddr*) addr, sizeof(*addr)) != size)
{
perror("sendto");
}
}
static TraversalPacket* AllocPacket(const sockaddr_in6& dest, TraversalRequestId misc = 0)
{
TraversalRequestId requestId;
GetRandomBytes(&requestId, sizeof(requestId));
OutgoingPacketInfo* info = &outgoingPackets[requestId];
info->dest = dest;
info->misc = misc;
info->tries = 0;
info->sendTime = currentTime;
TraversalPacket* result = &info->packet;
memset(result, 0, sizeof(*result));
result->requestId = requestId;
return result;
}
static void SendPacket(OutgoingPacketInfo* info)
{
info->tries++;
info->sendTime = currentTime;
TrySend(&info->packet, sizeof(info->packet), &info->dest);
}
static void ResendPackets()
{
std::vector<std::pair<TraversalInetAddress, TraversalRequestId>> todoFailures;
todoFailures.clear();
for (auto it = outgoingPackets.begin(); it != outgoingPackets.end();)
{
OutgoingPacketInfo* info = &it->second;
if (currentTime - info->sendTime >= (u64) (300000 * info->tries))
{
if (info->tries >= NUMBER_OF_TRIES)
{
if (info->packet.type == TraversalPacketPleaseSendPacket)
{
todoFailures.push_back(std::make_pair(info->packet.pleaseSendPacket.address, info->misc));
}
it = outgoingPackets.erase(it);
continue;
}
else
{
SendPacket(info);
}
}
++it;
}
for (const auto& p : todoFailures)
{
TraversalPacket* fail = AllocPacket(MakeSinAddr(p.first));
fail->type = TraversalPacketConnectFailed;
fail->connectFailed.requestId = p.second;
fail->connectFailed.reason = TraversalConnectFailedClientDidntRespond;
}
}
static void HandlePacket(TraversalPacket* packet, sockaddr_in6* addr)
{
#if DEBUG
printf("<- %d %lu %s\n", packet->type, packet->requestId, SenderName(addr));
#endif
bool packetOk = true;
switch (packet->type)
{
case TraversalPacketAck:
{
auto it = outgoingPackets.find(packet->requestId);
if (it == outgoingPackets.end())
break;
OutgoingPacketInfo* info = &it->second;
if (info->packet.type == TraversalPacketPleaseSendPacket)
{
TraversalPacket* ready = AllocPacket(MakeSinAddr(info->packet.pleaseSendPacket.address));
if (packet->ack.ok)
{
ready->type = TraversalPacketConnectReady;
ready->connectReady.requestId = info->misc;
ready->connectReady.address = MakeInetAddress(info->dest);
}
else
{
ready->type = TraversalPacketConnectFailed;
ready->connectFailed.requestId = info->misc;
ready->connectFailed.reason = TraversalConnectFailedClientFailure;
}
}
outgoingPackets.erase(it);
break;
}
case TraversalPacketPing:
{
auto r = EvictFind(connectedClients, packet->ping.hostId, true);
packetOk = r.found;
break;
}
case TraversalPacketHelloFromClient:
{
u8 ok = packet->helloFromClient.protoVersion <= TraversalProtoVersion;
TraversalPacket* reply = AllocPacket(*addr);
reply->type = TraversalPacketHelloFromServer;
reply->helloFromServer.ok = ok;
if (ok)
{
TraversalHostId hostId;
TraversalInetAddress* iaddr;
// not that there is any significant change of
// duplication, but...
GetRandomHostId(&hostId);
while (true)
{
auto r = EvictFind(connectedClients, hostId);
if (!r.found)
{
iaddr = EvictSet(connectedClients, hostId);
break;
}
}
*iaddr = MakeInetAddress(*addr);
reply->helloFromServer.yourAddress = *iaddr;
reply->helloFromServer.yourHostId = hostId;
}
break;
}
case TraversalPacketConnectPlease:
{
TraversalHostId& hostId = packet->connectPlease.hostId;
auto r = EvictFind(connectedClients, hostId);
if (!r.found)
{
TraversalPacket* reply = AllocPacket(*addr);
reply->type = TraversalPacketConnectFailed;
reply->connectFailed.requestId = packet->requestId;
reply->connectFailed.reason = TraversalConnectFailedNoSuchClient;
}
else
{
TraversalPacket* please = AllocPacket(MakeSinAddr(*r.value), packet->requestId);
please->type = TraversalPacketPleaseSendPacket;
please->pleaseSendPacket.address = MakeInetAddress(*addr);
}
break;
}
default:
fprintf(stderr, "received unknown packet type %d from %s\n", packet->type, SenderName(addr));
}
if (packet->type != TraversalPacketAck)
{
TraversalPacket ack = {};
ack.type = TraversalPacketAck;
ack.requestId = packet->requestId;
ack.ack.ok = packetOk;
TrySend(&ack, sizeof(ack), addr);
}
}
int main()
{
int rv;
urandomFd = open("/dev/urandom", O_RDONLY);
if (urandomFd < 0)
{
perror("open /dev/urandom");
return 1;
}
sock = socket(PF_INET6, SOCK_DGRAM, 0);
if (sock == -1)
{
perror("socket");
return 1;
}
int no = 0;
rv = setsockopt(sock, IPPROTO_IPV6, IPV6_V6ONLY, &no, sizeof(no));
if (rv < 0)
{
perror("setsockopt IPV6_V6ONLY");
return 1;
}
in6_addr any = IN6ADDR_ANY_INIT;
sockaddr_in6 addr;
#ifdef SIN6_LEN
addr.sin6_len = sizeof(addr);
#endif
addr.sin6_family = AF_INET6;
addr.sin6_port = htons(6262);
addr.sin6_flowinfo = 0;
addr.sin6_addr = any;
addr.sin6_scope_id = 0;
rv = bind(sock, (sockaddr*) &addr, sizeof(addr));
if (rv < 0)
{
perror("bind");
return 1;
}
timeval tv;
tv.tv_sec = 0;
tv.tv_usec = 300000;
rv = setsockopt(sock, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv));
if (rv < 0)
{
perror("setsockopt SO_RCVTIMEO");
return 1;
}
while (true)
{
sockaddr_in6 raddr;
socklen_t addrLen = sizeof(raddr);
TraversalPacket packet;
// note: switch to recvmmsg (yes, mmsg) if this becomes
// expensive
rv = recvfrom(sock, &packet, sizeof(packet), 0, (sockaddr*) &raddr, &addrLen);
if (gettimeofday(&tv, NULL) < 0)
{
perror("gettimeofday");
exit(1);
}
currentTime = (u64) tv.tv_sec * 1000000 + tv.tv_usec;
if (rv < 0)
{
if (errno != EINTR && errno != EAGAIN)
{
perror("recvfrom");
return 1;
}
}
else if ((size_t) rv < sizeof(packet))
{
fprintf(stderr, "received short packet from %s\n", SenderName(&raddr));
}
else
{
HandlePacket(&packet, &raddr);
}
ResendPackets();
}
}

View File

@ -135,14 +135,14 @@ NetPlaySetupDiag::NetPlaySetupDiag(wxWindow* const parent, const CGameListCtrl*
nick_szr->Add(m_nickname_text, 0, wxALL, 5); nick_szr->Add(m_nickname_text, 0, wxALL, 5);
std::string centralServer; std::string centralServer;
netplay_section.Get("TraversalServer", &centralServer, "vps.qoid.us"); netplay_section.Get("TraversalServer", &centralServer, "");
m_traversal_server_lbl = new wxStaticText(panel, wxID_ANY, _("Traversal:")); m_traversal_server_lbl = new wxStaticText(panel, wxID_ANY, _("Traversal:"));
m_traversal_server = new wxTextCtrl(panel, wxID_ANY, StrToWxStr(centralServer)); m_traversal_server = new wxTextCtrl(panel, wxID_ANY, StrToWxStr(centralServer));
nick_szr->Add(m_traversal_server_lbl, 0, wxCENTER); nick_szr->Add(m_traversal_server_lbl, 0, wxCENTER);
nick_szr->Add(m_traversal_server, 0, wxALL, 5); nick_szr->Add(m_traversal_server, 0, wxALL, 5);
std::string centralPort; std::string centralPort;
netplay_section.Get("TraversalPort", &centralPort, "6262"); netplay_section.Get("TraversalPort", &centralPort, "");
m_traversal_port_lbl = new wxStaticText(panel, wxID_ANY, _("Port:")); m_traversal_port_lbl = new wxStaticText(panel, wxID_ANY, _("Port:"));
m_traversal_port = new wxTextCtrl(panel, wxID_ANY, StrToWxStr(centralPort)); m_traversal_port = new wxTextCtrl(panel, wxID_ANY, StrToWxStr(centralPort));
nick_szr->Add(m_traversal_port_lbl, 0, wxCENTER); nick_szr->Add(m_traversal_port_lbl, 0, wxCENTER);