diff --git a/Source/Core/Common/Common.vcxproj b/Source/Core/Common/Common.vcxproj
index dfe91db48b..fdd3183587 100644
--- a/Source/Core/Common/Common.vcxproj
+++ b/Source/Core/Common/Common.vcxproj
@@ -80,6 +80,8 @@
+
+
@@ -116,6 +118,7 @@
+
@@ -142,4 +145,4 @@
-
+
\ No newline at end of file
diff --git a/Source/Core/Common/Common.vcxproj.filters b/Source/Core/Common/Common.vcxproj.filters
index 2f121ca882..da11bd379b 100644
--- a/Source/Core/Common/Common.vcxproj.filters
+++ b/Source/Core/Common/Common.vcxproj.filters
@@ -71,6 +71,9 @@
+
+
+
@@ -117,8 +120,10 @@
+
+
-
+
\ No newline at end of file
diff --git a/Source/Core/Common/TraversalClient.cpp b/Source/Core/Common/TraversalClient.cpp
new file mode 100644
index 0000000000..9ac5eca238
--- /dev/null
+++ b/Source/Core/Common/TraversalClient.cpp
@@ -0,0 +1,370 @@
+// This file is public domain, in case it's useful to anyone. -comex
+
+#include "Common/TraversalClient.h"
+#include "enet/enet.h"
+#include "Timer.h"
+
+static void GetRandomishBytes(u8* buf, size_t size)
+{
+ // We don't need high quality random numbers (which might not be available),
+ // just non-repeating numbers!
+ srand(enet_time_get());
+ for (size_t i = 0; i < size; i++)
+ buf[i] = rand() & 0xff;
+}
+
+TraversalClient::TraversalClient(ENetHost* netHost, const std::string& server)
+ : m_NetHost(netHost)
+ , m_Server(server)
+ , m_Client(nullptr)
+ , m_FailureReason(0)
+ , m_ConnectRequestId(0)
+ , m_PendingConnect(false)
+ , m_PingTime(0)
+{
+ netHost->intercept = TraversalClient::InterceptCallback;
+
+ Reset();
+
+ ReconnectToServer();
+}
+
+TraversalClient::~TraversalClient()
+{
+}
+
+void TraversalClient::ReconnectToServer()
+{
+ m_Server = "vps.qoid.us"; // XXX
+ if (enet_address_set_host(&m_ServerAddress, m_Server.c_str()))
+ {
+ OnFailure(BadHost);
+ return;
+ }
+ m_ServerAddress.port = 6262;
+
+ m_State = Connecting;
+
+ TraversalPacket hello = {};
+ hello.type = TraversalPacketHelloFromClient;
+ hello.helloFromClient.protoVersion = TraversalProtoVersion;
+ SendTraversalPacket(hello);
+ if (m_Client)
+ m_Client->OnTraversalStateChanged();
+}
+
+static ENetAddress MakeENetAddress(TraversalInetAddress* address)
+{
+ ENetAddress eaddr;
+ if (address->isIPV6)
+ {
+ eaddr.port = 0; // no support yet :(
+ }
+ else
+ {
+ eaddr.host = address->address[0];
+ eaddr.port = ntohs(address->port);
+ }
+ return eaddr;
+}
+
+void TraversalClient::ConnectToClient(const std::string& host)
+{
+ if (host.size() > sizeof(TraversalHostId))
+ {
+ PanicAlert("host too long");
+ return;
+ }
+ TraversalPacket packet = {};
+ packet.type = TraversalPacketConnectPlease;
+ memcpy(packet.connectPlease.hostId.data(), host.c_str(), host.size());
+ m_ConnectRequestId = SendTraversalPacket(packet);
+ m_PendingConnect = true;
+}
+
+bool TraversalClient::TestPacket(u8* data, size_t size, ENetAddress* from)
+{
+ if (from->host == m_ServerAddress.host &&
+ from->port == m_ServerAddress.port)
+ {
+ if (size < sizeof(TraversalPacket))
+ {
+ ERROR_LOG(NETPLAY, "Received too-short traversal packet.");
+ }
+ else
+ {
+ HandleServerPacket((TraversalPacket*) data);
+ return true;
+ }
+ }
+ return false;
+}
+
+//--Temporary until more of the old netplay branch is moved over
+void TraversalClient::Update()
+{
+ ENetEvent netEvent;
+ if (enet_host_service(m_NetHost, &netEvent, 4) > 0)
+ {
+ switch (netEvent.type)
+ {
+ case ENET_EVENT_TYPE_RECEIVE:
+ TestPacket(netEvent.packet->data, netEvent.packet->dataLength, &netEvent.peer->address);
+
+ enet_packet_destroy(netEvent.packet);
+ break;
+ }
+ }
+ HandleResends();
+}
+
+void TraversalClient::HandleServerPacket(TraversalPacket* packet)
+{
+ u8 ok = 1;
+ switch (packet->type)
+ {
+ case TraversalPacketAck:
+ if (!packet->ack.ok)
+ {
+ OnFailure(ServerForgotAboutUs);
+ break;
+ }
+ for (auto it = m_OutgoingTraversalPackets.begin(); it != m_OutgoingTraversalPackets.end(); ++it)
+ {
+ if (it->packet.requestId == packet->requestId)
+ {
+ m_OutgoingTraversalPackets.erase(it);
+ break;
+ }
+ }
+ break;
+ case TraversalPacketHelloFromServer:
+ if (m_State != Connecting)
+ break;
+ if (!packet->helloFromServer.ok)
+ {
+ OnFailure(VersionTooOld);
+ break;
+ }
+ m_HostId = packet->helloFromServer.yourHostId;
+ m_State = Connected;
+ if (m_Client)
+ m_Client->OnTraversalStateChanged();
+ break;
+ case TraversalPacketPleaseSendPacket:
+ {
+ // security is overrated.
+ ENetAddress addr = MakeENetAddress(&packet->pleaseSendPacket.address);
+ if (addr.port != 0)
+ {
+ char message[] = "Hello from Dolphin Netplay...";
+ ENetBuffer buf;
+ buf.data = message;
+ buf.dataLength = sizeof(message) - 1;
+ enet_socket_send(m_NetHost->socket, &addr, &buf, 1);
+ }
+ else
+ {
+ // invalid IPV6
+ ok = 0;
+ }
+ break;
+ }
+ case TraversalPacketConnectReady:
+ case TraversalPacketConnectFailed:
+ {
+ if (!m_PendingConnect || packet->connectReady.requestId != m_ConnectRequestId)
+ break;
+
+ m_PendingConnect = false;
+
+ if (!m_Client)
+ break;
+
+ if (packet->type == TraversalPacketConnectReady)
+ m_Client->OnConnectReady(MakeENetAddress(&packet->connectReady.address));
+ else
+ m_Client->OnConnectFailed(packet->connectFailed.reason);
+ break;
+ }
+ default:
+ WARN_LOG(NETPLAY, "Received unknown packet with type %d", packet->type);
+ break;
+ }
+ if (packet->type != TraversalPacketAck)
+ {
+ TraversalPacket ack = {};
+ ack.type = TraversalPacketAck;
+ ack.requestId = packet->requestId;
+ ack.ack.ok = ok;
+
+ ENetBuffer buf;
+ buf.data = &ack;
+ buf.dataLength = sizeof(ack);
+ if (enet_socket_send(m_NetHost->socket, &m_ServerAddress, &buf, 1) == -1)
+ OnFailure(SocketSendError);
+ }
+}
+
+void TraversalClient::OnFailure(int reason)
+{
+ m_State = Failure;
+ m_FailureReason = reason;
+
+ switch (reason)
+ {
+ case TraversalClient::BadHost:
+ {
+ auto server = "dolphin-emu.org";
+ PanicAlertT("Couldn't look up central server %s", server);
+ break;
+ }
+ case TraversalClient::VersionTooOld:
+ PanicAlertT("Dolphin too old for traversal server");
+ break;
+ case TraversalClient::ServerForgotAboutUs:
+ PanicAlertT("Disconnected from traversal server");
+ break;
+ case TraversalClient::SocketSendError:
+ PanicAlertT("Socket error sending to traversal server");
+ break;
+ case TraversalClient::ResendTimeout:
+ PanicAlertT("Timeout connecting to traversal server");
+ break;
+ default:
+ PanicAlertT("Unknown error %x", reason);
+ break;
+ }
+
+ if (m_Client)
+ m_Client->OnTraversalStateChanged();
+}
+
+void TraversalClient::ResendPacket(OutgoingTraversalPacketInfo* info)
+{
+ info->sendTime = enet_time_get();
+ info->tries++;
+ ENetBuffer buf;
+ buf.data = &info->packet;
+ buf.dataLength = sizeof(info->packet);
+ if (enet_socket_send(m_NetHost->socket, &m_ServerAddress, &buf, 1) == -1)
+ OnFailure(SocketSendError);
+}
+
+void TraversalClient::HandleResends()
+{
+ enet_uint32 now = enet_time_get();
+ for (auto& tpi : m_OutgoingTraversalPackets)
+ {
+ if (now - tpi.sendTime >= (u32) (300 * tpi.tries))
+ {
+ if (tpi.tries >= 5)
+ {
+ OnFailure(ResendTimeout);
+ m_OutgoingTraversalPackets.clear();
+ break;
+ }
+ else
+ {
+ ResendPacket(&tpi);
+ }
+ }
+ }
+ HandlePing();
+}
+
+void TraversalClient::HandlePing()
+{
+ enet_uint32 now = enet_time_get();
+ if (m_State == Connected && now - m_PingTime >= 500)
+ {
+ TraversalPacket ping = {0};
+ ping.type = TraversalPacketPing;
+ ping.ping.hostId = m_HostId;
+ SendTraversalPacket(ping);
+ m_PingTime = now;
+ }
+}
+
+TraversalRequestId TraversalClient::SendTraversalPacket(const TraversalPacket& packet)
+{
+ OutgoingTraversalPacketInfo info;
+ info.packet = packet;
+ GetRandomishBytes((u8*) &info.packet.requestId, sizeof(info.packet.requestId));
+ info.tries = 0;
+ m_OutgoingTraversalPackets.push_back(info);
+ ResendPacket(&m_OutgoingTraversalPackets.back());
+ return info.packet.requestId;
+}
+
+void TraversalClient::Reset()
+{
+ m_PendingConnect = false;
+ m_Client = nullptr;
+}
+
+int ENET_CALLBACK TraversalClient::InterceptCallback(ENetHost* host, ENetEvent* event)
+{
+ auto traversalClient = g_TraversalClient.get();
+ if (traversalClient->TestPacket(host->receivedData, host->receivedDataLength, &host->receivedAddress))
+ {
+ event->type = (ENetEventType)42;
+ return 1;
+ }
+ return 0;
+}
+
+std::unique_ptr g_TraversalClient;
+std::unique_ptr g_MainNetHost;
+
+// The settings at the previous TraversalClient reset - notably, we
+// need to know not just what port it's on, but whether it was
+// explicitly requested.
+static std::string g_OldServer;
+static u16 g_OldPort;
+
+bool EnsureTraversalClient(const std::string& server, u16 port)
+{
+ if (!g_MainNetHost || !g_TraversalClient || server != g_OldServer || port != g_OldPort)
+ {
+ g_OldServer = server;
+ g_OldPort = port;
+
+ ENetAddress addr = { ENET_HOST_ANY, port };
+ ENetHost* host = enet_host_create(
+ &addr, // address
+ 50, // peerCount
+ 1, // channelLimit
+ 0, // incomingBandwidth
+ 0); // outgoingBandwidth
+ if (!host)
+ {
+ g_MainNetHost.reset();
+ return false;
+ }
+ g_MainNetHost.reset(host);
+
+ g_TraversalClient.reset(new TraversalClient(g_MainNetHost.get(), server));
+
+ }
+ return true;
+}
+
+void ReleaseTraversalClient()
+{
+ if (!g_TraversalClient)
+ return;
+
+ if (g_OldPort != 0)
+ {
+ // If we were listening at a specific port, kill the
+ // TraversalClient to avoid hanging on to the port.
+ g_TraversalClient.reset();
+ g_MainNetHost.reset();
+ }
+ else
+ {
+ // Reset any pending connection attempts.
+ g_TraversalClient->Reset();
+ }
+}
diff --git a/Source/Core/Common/TraversalClient.h b/Source/Core/Common/TraversalClient.h
new file mode 100644
index 0000000000..838ea6448a
--- /dev/null
+++ b/Source/Core/Common/TraversalClient.h
@@ -0,0 +1,88 @@
+// This file is public domain, in case it's useful to anyone. -comex
+
+#pragma once
+#include
+#include
+#include
+#include "Common/Common.h"
+#include "Common/Thread.h"
+#include "Common/TraversalProto.h"
+
+#include "enet/enet.h"
+
+class TraversalClientClient
+{
+public:
+ virtual ~TraversalClientClient(){};
+ virtual void OnTraversalStateChanged()=0;
+ virtual void OnConnectReady(ENetAddress addr)=0;
+ virtual void OnConnectFailed(u8 reason)=0;
+};
+
+class TraversalClient
+{
+public:
+ enum State
+ {
+ Connecting,
+ Connected,
+ Failure
+ };
+
+ enum FailureReason
+ {
+ BadHost = 0x300,
+ VersionTooOld,
+ ServerForgotAboutUs,
+ SocketSendError,
+ ResendTimeout,
+ ConnectFailedError = 0x400,
+ };
+
+ TraversalClient(ENetHost* netHost, const std::string& server);
+ ~TraversalClient();
+ void Reset();
+ void ConnectToClient(const std::string& host);
+ void ReconnectToServer();
+ void Update();
+
+ // called from NetHost
+ bool TestPacket(u8* data, size_t size, ENetAddress* from);
+ void HandleResends();
+
+ ENetHost* m_NetHost;
+ TraversalClientClient* m_Client;
+ TraversalHostId m_HostId;
+ State m_State;
+ int m_FailureReason;
+
+private:
+ struct OutgoingTraversalPacketInfo
+ {
+ TraversalPacket packet;
+ int tries;
+ enet_uint32 sendTime;
+ };
+
+ void HandleServerPacket(TraversalPacket* packet);
+ void ResendPacket(OutgoingTraversalPacketInfo* info);
+ TraversalRequestId SendTraversalPacket(const TraversalPacket& packet);
+ void OnFailure(int reason);
+ void HandlePing();
+ static int ENET_CALLBACK InterceptCallback(ENetHost* host, ENetEvent* event);
+
+ TraversalRequestId m_ConnectRequestId;
+ bool m_PendingConnect;
+ std::list m_OutgoingTraversalPackets;
+ ENetAddress m_ServerAddress;
+ std::string m_Server;
+ enet_uint32 m_PingTime;
+};
+
+extern std::unique_ptr g_TraversalClient;
+// the NetHost connected to the TraversalClient.
+extern std::unique_ptr g_MainNetHost;
+
+// Create g_TraversalClient and g_MainNetHost if necessary.
+bool EnsureTraversalClient(const std::string& server, u16 port);
+void ReleaseTraversalClient();
diff --git a/Source/Core/Common/TraversalProto.h b/Source/Core/Common/TraversalProto.h
new file mode 100644
index 0000000000..32891beac6
--- /dev/null
+++ b/Source/Core/Common/TraversalProto.h
@@ -0,0 +1,96 @@
+// This file is public domain, in case it's useful to anyone. -comex
+
+#pragma once
+#include
+#include "Common/CommonTypes.h"
+
+
+typedef std::array TraversalHostId;
+typedef u64 TraversalRequestId;
+
+enum TraversalPacketType
+{
+ // [*->*]
+ TraversalPacketAck = 0,
+ // [c->s]
+ TraversalPacketPing = 1,
+ // [c->s]
+ TraversalPacketHelloFromClient = 2,
+ // [s->c]
+ TraversalPacketHelloFromServer = 3,
+ // [c->s] When connecting, first the client asks the central server...
+ TraversalPacketConnectPlease = 4,
+ // [s->c] ...who asks the game host to send a UDP packet to the
+ // client... (an ack implies success)
+ TraversalPacketPleaseSendPacket = 5,
+ // [s->c] ...which the central server relays back to the client.
+ TraversalPacketConnectReady = 6,
+ // [s->c] Alternately, the server might not have heard of this host.
+ TraversalPacketConnectFailed = 7
+};
+
+enum
+{
+ TraversalProtoVersion = 0
+};
+
+enum TraversalConnectFailedReason
+{
+ TraversalConnectFailedClientDidntRespond = 0,
+ TraversalConnectFailedClientFailure,
+ TraversalConnectFailedNoSuchClient
+};
+
+#pragma pack(push, 1)
+struct TraversalInetAddress
+{
+ u8 isIPV6;
+ u32 address[4];
+ u16 port;
+};
+struct TraversalPacket
+{
+ u8 type;
+ TraversalRequestId requestId;
+ union
+ {
+ struct
+ {
+ u8 ok;
+ } ack;
+ struct
+ {
+ TraversalHostId hostId;
+ } ping;
+ struct
+ {
+ u8 protoVersion;
+ } helloFromClient;
+ struct
+ {
+ u8 ok;
+ TraversalHostId yourHostId;
+ TraversalInetAddress yourAddress; // currently unused
+ } helloFromServer;
+ struct
+ {
+ TraversalHostId hostId;
+ } connectPlease;
+ struct
+ {
+ TraversalInetAddress address;
+ } pleaseSendPacket;
+ struct
+ {
+ TraversalRequestId requestId;
+ TraversalInetAddress address;
+ } connectReady;
+ struct
+ {
+ TraversalRequestId requestId;
+ u8 reason;
+ } connectFailed;
+ };
+};
+#pragma pack(pop)
+
diff --git a/Source/Core/Core/NetPlayClient.cpp b/Source/Core/Core/NetPlayClient.cpp
index 8bb26feeb8..77c9a06e71 100644
--- a/Source/Core/Core/NetPlayClient.cpp
+++ b/Source/Core/Core/NetPlayClient.cpp
@@ -38,10 +38,25 @@ NetPlayClient::~NetPlayClient()
Disconnect();
}
+ if (g_MainNetHost.get() == m_client)
+ {
+ g_MainNetHost.release();
+ }
+ if (m_client)
+ {
+ enet_host_destroy(m_client);
+ m_client = nullptr;
+ }
+
+ if (m_traversal_client)
+ {
+ ReleaseTraversalClient();
+ }
+
}
// called from ---GUI--- thread
-NetPlayClient::NetPlayClient(const std::string& address, const u16 port, NetPlayUI* dialog, const std::string& name)
+NetPlayClient::NetPlayClient(const std::string& address, const u16 port, NetPlayUI* dialog, const std::string& name, bool traversal)
: m_dialog(dialog)
, m_client(nullptr)
, m_server(nullptr)
@@ -53,6 +68,8 @@ NetPlayClient::NetPlayClient(const std::string& address, const u16 port, NetPlay
, m_is_recording(false)
, m_pid(0)
, m_connecting(false)
+ , m_traversal_client(nullptr)
+ , m_state(Failure)
{
m_target_buffer_size = 20;
ClearBuffers();
@@ -61,35 +78,81 @@ NetPlayClient::NetPlayClient(const std::string& address, const u16 port, NetPlay
m_player_name = name;
- //Direct Connection
- m_client = enet_host_create(nullptr, 1, 3, 0, 0);
-
- if (m_client == nullptr)
+ if (!traversal)
{
- PanicAlertT("Couldn't Create Client");
- }
+ //Direct Connection
+ m_client = enet_host_create(nullptr, 1, 3, 0, 0);
- ENetAddress addr;
- enet_address_set_host(&addr, address.c_str());
- addr.port = port;
+ if (m_client == nullptr)
+ {
+ PanicAlertT("Couldn't Create Client");
+ }
- m_server = enet_host_connect(m_client, &addr, 3, 0);
+ ENetAddress addr;
+ enet_address_set_host(&addr, address.c_str());
+ addr.port = port;
- if (m_server == nullptr)
- {
- PanicAlertT("Couldn't create peer.");
- }
+ m_server = enet_host_connect(m_client, &addr, 3, 0);
+
+ if (m_server == nullptr)
+ {
+ PanicAlertT("Couldn't create peer.");
+ }
+
+ ENetEvent netEvent;
+ int net = enet_host_service(m_client, &netEvent, 5000);
+ if (net > 0 && netEvent.type == ENET_EVENT_TYPE_CONNECT)
+ {
+ if (Connect())
+ m_thread = std::thread(&NetPlayClient::ThreadFunc, this);
+ }
+ else
+ {
+ PanicAlertT("Failed to Connect!");
+ }
- ENetEvent netEvent;
- int net = enet_host_service(m_client, &netEvent, 5000);
- if (net > 0 && netEvent.type == ENET_EVENT_TYPE_CONNECT)
- {
- if (Connect())
- m_thread = std::thread(&NetPlayClient::ThreadFunc, this);
}
else
{
- PanicAlertT("Failed to Connect!");
+ //Traversal Server
+ if (!EnsureTraversalClient("dolphin-emu.org", 0))
+ return;
+ m_client = g_MainNetHost.get();
+
+ m_traversal_client = g_TraversalClient.get();
+
+ // If we were disconnected in the background, reconnect.
+ if (m_traversal_client->m_State == TraversalClient::Failure)
+ m_traversal_client->ReconnectToServer();
+ m_traversal_client->m_Client = this;
+ m_host_spec = address;
+ m_state = WaitingForTraversalClientConnection;
+ OnTraversalStateChanged();
+ m_connecting = true;
+
+ while (m_connecting)
+ {
+ ENetEvent netEvent;
+ if (m_traversal_client)
+ m_traversal_client->HandleResends();
+
+ while (enet_host_service(m_client, &netEvent, 4) > 0)
+ {
+ sf::Packet rpac;
+ switch (netEvent.type)
+ {
+ case ENET_EVENT_TYPE_CONNECT:
+ m_server = netEvent.peer;
+ if (Connect())
+ {
+ m_state = Connected;
+ m_thread = std::thread(&NetPlayClient::ThreadFunc, this);
+ }
+ return;
+ }
+ }
+ }
+ PanicAlertT("Failed To Connect!");
}
}
@@ -380,6 +443,7 @@ void NetPlayClient::Send(sf::Packet& packet)
void NetPlayClient::Disconnect()
{
ENetEvent netEvent;
+ m_state = Failure;
enet_peer_disconnect(m_server, 0);
while (enet_host_service(m_client, &netEvent, 3000) > 0)
{
@@ -407,6 +471,8 @@ void NetPlayClient::ThreadFunc()
int net;
{
std::lock_guard lks(m_crit.send);
+ if (m_traversal_client)
+ m_traversal_client->HandleResends();
net = enet_host_service(m_client, &netEvent, 4);
}
if (net > 0)
@@ -630,6 +696,55 @@ void NetPlayClient::ClearBuffers()
}
}
+// called from ---NETPLAY--- thread
+void NetPlayClient::OnTraversalStateChanged()
+{
+ if (m_state == WaitingForTraversalClientConnection &&
+ m_traversal_client->m_State == TraversalClient::Connected)
+ {
+ m_state = WaitingForTraversalClientConnectReady;
+ m_traversal_client->ConnectToClient(m_host_spec);
+ }
+ else if (m_state != Failure &&
+ m_traversal_client->m_State == TraversalClient::Failure)
+ {
+ Disconnect();
+ }
+}
+
+// called from ---NETPLAY--- thread
+void NetPlayClient::OnConnectReady(ENetAddress addr)
+{
+ if (m_state == WaitingForTraversalClientConnectReady)
+ {
+ m_state = Connecting;
+ enet_host_connect(m_client, &addr, 0, 0);
+ }
+}
+
+// called from ---NETPLAY--- thread
+void NetPlayClient::OnConnectFailed(u8 reason)
+{
+ m_connecting = false;
+ m_state = Failure;
+ int swtch = TraversalClient::ConnectFailedError + reason;
+ switch (swtch)
+ {
+ case TraversalClient::ConnectFailedError + TraversalConnectFailedClientDidntRespond:
+ PanicAlertT("Traversal server timed out connecting to the host");
+ break;
+ case TraversalClient::ConnectFailedError + TraversalConnectFailedClientFailure:
+ PanicAlertT("Server rejected traversal attempt");
+ break;
+ case TraversalClient::ConnectFailedError + TraversalConnectFailedNoSuchClient:
+ PanicAlertT("Invalid host");
+ break;
+ default:
+ PanicAlertT("Unknown error %x", swtch);
+ break;
+ }
+}
+
// called from ---CPU--- thread
bool NetPlayClient::GetNetPads(const u8 pad_nb, GCPadStatus* pad_status)
{
diff --git a/Source/Core/Core/NetPlayClient.h b/Source/Core/Core/NetPlayClient.h
index 6d923f0697..fab8e6548f 100644
--- a/Source/Core/Core/NetPlayClient.h
+++ b/Source/Core/Core/NetPlayClient.h
@@ -7,12 +7,11 @@
#include