344 lines
8.5 KiB
C++
344 lines
8.5 KiB
C++
// SPDX-License-Identifier: CC0-1.0
|
|
|
|
#include "Common/TraversalClient.h"
|
|
|
|
#include <cstddef>
|
|
#include <cstring>
|
|
#include <string>
|
|
|
|
#include "Common/CommonTypes.h"
|
|
#include "Common/Logging/Log.h"
|
|
#include "Common/MsgHandler.h"
|
|
#include "Common/Random.h"
|
|
#include "Core/NetPlayProto.h"
|
|
|
|
// Attaches the client to an existing ENet host, hooks the host's intercept callback, clears
// the per-session state and immediately starts the handshake with the traversal server.
TraversalClient::TraversalClient(ENetHost* netHost, const std::string& server, const u16 port)
    : m_NetHost(netHost), m_Server(server), m_port(port)
{
  // Every datagram the host receives is offered to InterceptCallback first.
  m_NetHost->intercept = TraversalClient::InterceptCallback;

  Reset();

  ReconnectToServer();
}
|
|
|
|
// Defaulted: all cleanup is left to the members' own destructors.
TraversalClient::~TraversalClient() = default;
|
|
|
|
// Returns the host ID currently stored for this client (filled in from the server's
// HelloFromServer reply).
TraversalHostId TraversalClient::GetHostID() const
{
  return m_HostId;
}
|
|
|
|
// Returns the current connection state of the client.
TraversalClient::State TraversalClient::GetState() const
{
  return m_State;
}
|
|
|
|
// Returns the reason recorded by the most recent OnFailure() call.
TraversalClient::FailureReason TraversalClient::GetFailureReason() const
{
  return m_FailureReason;
}
|
|
|
|
void TraversalClient::ReconnectToServer()
|
|
{
|
|
if (enet_address_set_host(&m_ServerAddress, m_Server.c_str()))
|
|
{
|
|
OnFailure(FailureReason::BadHost);
|
|
return;
|
|
}
|
|
m_ServerAddress.port = m_port;
|
|
|
|
m_State = State::Connecting;
|
|
|
|
TraversalPacket hello = {};
|
|
hello.type = TraversalPacketType::HelloFromClient;
|
|
hello.helloFromClient.protoVersion = TraversalProtoVersion;
|
|
SendTraversalPacket(hello);
|
|
if (m_Client)
|
|
m_Client->OnTraversalStateChanged();
|
|
}
|
|
|
|
// Converts a traversal-protocol address into an ENetAddress. IPv6 addresses are not
// supported: they yield an address whose port is 0, which callers treat as invalid.
static ENetAddress MakeENetAddress(const TraversalInetAddress& address)
{
  ENetAddress result{};

  if (!address.isIPV6)
  {
    // The traversal packet carries the port in network byte order.
    result.host = address.address[0];
    result.port = ntohs(address.port);
    return result;
  }

  result.port = 0;  // no support yet :(
  return result;
}
|
|
|
|
void TraversalClient::ConnectToClient(std::string_view host)
|
|
{
|
|
if (host.size() > sizeof(TraversalHostId))
|
|
{
|
|
PanicAlertFmt("Host too long");
|
|
return;
|
|
}
|
|
TraversalPacket packet = {};
|
|
packet.type = TraversalPacketType::ConnectPlease;
|
|
memcpy(packet.connectPlease.hostId.data(), host.data(), host.size());
|
|
m_ConnectRequestId = SendTraversalPacket(packet);
|
|
m_PendingConnect = true;
|
|
}
|
|
|
|
// Checks whether a received datagram came from the traversal server and, if so, handles it.
// Returns true only when the datagram was a full-sized traversal packet from the server;
// too-short packets from the server are logged and reported as not handled.
bool TraversalClient::TestPacket(u8* data, size_t size, ENetAddress* from)
{
  const bool sent_by_server =
      from->host == m_ServerAddress.host && from->port == m_ServerAddress.port;
  if (!sent_by_server)
    return false;

  if (size < sizeof(TraversalPacket))
  {
    ERROR_LOG_FMT(NETPLAY, "Received too-short traversal packet.");
    return false;
  }

  HandleServerPacket(reinterpret_cast<TraversalPacket*>(data));
  return true;
}
|
|
|
|
// Temporary until more of the old netplay branch is moved over.
|
|
void TraversalClient::Update()
|
|
{
|
|
ENetEvent netEvent;
|
|
if (enet_host_service(m_NetHost, &netEvent, 4) > 0)
|
|
{
|
|
switch (netEvent.type)
|
|
{
|
|
case ENET_EVENT_TYPE_RECEIVE:
|
|
TestPacket(netEvent.packet->data, netEvent.packet->dataLength, &netEvent.peer->address);
|
|
|
|
enet_packet_destroy(netEvent.packet);
|
|
break;
|
|
default:
|
|
break;
|
|
}
|
|
}
|
|
HandleResends();
|
|
}
|
|
|
|
// Processes one packet received from the traversal server. Every packet type other than
// Ack is answered with an Ack that echoes the packet's request ID.
void TraversalClient::HandleServerPacket(TraversalPacket* packet)
{
  // Status byte sent back in our Ack; cleared only when a PleaseSendPacket target is unusable.
  u8 ack_status = 1;

  switch (packet->type)
  {
  case TraversalPacketType::Ack:
  {
    if (!packet->ack.ok)
    {
      OnFailure(FailureReason::ServerForgotAboutUs);
      break;
    }

    // Drop the first queued packet this Ack answers so it is no longer resent.
    auto pending = m_OutgoingTraversalPackets.begin();
    const auto queue_end = m_OutgoingTraversalPackets.end();
    while (pending != queue_end && pending->packet.requestId != packet->requestId)
      ++pending;
    if (pending != queue_end)
      m_OutgoingTraversalPackets.erase(pending);
    break;
  }

  case TraversalPacketType::HelloFromServer:
    // Only meaningful while the handshake is in progress.
    if (IsConnecting())
    {
      if (packet->helloFromServer.ok)
      {
        m_HostId = packet->helloFromServer.yourHostId;
        m_State = State::Connected;
        if (m_Client != nullptr)
          m_Client->OnTraversalStateChanged();
      }
      else
      {
        OnFailure(FailureReason::VersionTooOld);
      }
    }
    break;

  case TraversalPacketType::PleaseSendPacket:
  {
    // security is overrated.
    ENetAddress target = MakeENetAddress(packet->pleaseSendPacket.address);
    if (target.port == 0)
    {
      // invalid IPV6
      ack_status = 0;
      break;
    }

    // Send a raw datagram (without the string's terminating NUL) to the requested address.
    char greeting[] = "Hello from Dolphin Netplay...";
    ENetBuffer payload;
    payload.data = greeting;
    payload.dataLength = sizeof(greeting) - 1;
    enet_socket_send(m_NetHost->socket, &target, &payload, 1);
    break;
  }

  case TraversalPacketType::ConnectReady:
  case TraversalPacketType::ConnectFailed:
  {
    // Ignore replies that do not belong to the connect request currently in flight.
    const bool answers_our_request =
        m_PendingConnect && packet->connectReady.requestId == m_ConnectRequestId;
    if (!answers_our_request)
      break;

    m_PendingConnect = false;

    if (m_Client == nullptr)
      break;

    if (packet->type == TraversalPacketType::ConnectReady)
      m_Client->OnConnectReady(MakeENetAddress(packet->connectReady.address));
    else
      m_Client->OnConnectFailed(packet->connectFailed.reason);
    break;
  }

  default:
    WARN_LOG_FMT(NETPLAY, "Received unknown packet with type {}", static_cast<int>(packet->type));
    break;
  }

  // Acks themselves are never acknowledged.
  if (packet->type == TraversalPacketType::Ack)
    return;

  TraversalPacket reply = {};
  reply.type = TraversalPacketType::Ack;
  reply.requestId = packet->requestId;
  reply.ack.ok = ack_status;

  ENetBuffer reply_buf;
  reply_buf.data = &reply;
  reply_buf.dataLength = sizeof(reply);
  const int sent = enet_socket_send(m_NetHost->socket, &m_ServerAddress, &reply_buf, 1);
  if (sent == -1)
    OnFailure(FailureReason::SocketSendError);
}
|
|
|
|
// Records the failure reason, moves the client into the Failure state and notifies the
// attached listener, if there is one.
void TraversalClient::OnFailure(FailureReason reason)
{
  m_FailureReason = reason;
  m_State = State::Failure;

  if (m_Client != nullptr)
    m_Client->OnTraversalStateChanged();
}
|
|
|
|
// Sends (or re-sends) one queued packet to the traversal server, stamping it with the
// current ENet time and bumping its attempt counter. A socket error is reported as
// SocketSendError.
void TraversalClient::ResendPacket(OutgoingTraversalPacketInfo* info)
{
  info->sendTime = enet_time_get();
  ++info->tries;

  ENetBuffer wire;
  wire.data = &info->packet;
  wire.dataLength = sizeof(info->packet);

  const int sent = enet_socket_send(m_NetHost->socket, &m_ServerAddress, &wire, 1);
  if (sent == -1)
    OnFailure(FailureReason::SocketSendError);
}
|
|
|
|
void TraversalClient::HandleResends()
|
|
{
|
|
const u32 now = enet_time_get();
|
|
for (auto& tpi : m_OutgoingTraversalPackets)
|
|
{
|
|
if (now - tpi.sendTime >= (u32)(300 * tpi.tries))
|
|
{
|
|
if (tpi.tries >= 5)
|
|
{
|
|
OnFailure(FailureReason::ResendTimeout);
|
|
m_OutgoingTraversalPackets.clear();
|
|
break;
|
|
}
|
|
else
|
|
{
|
|
ResendPacket(&tpi);
|
|
}
|
|
}
|
|
}
|
|
HandlePing();
|
|
}
|
|
|
|
void TraversalClient::HandlePing()
|
|
{
|
|
const u32 now = enet_time_get();
|
|
if (IsConnected() && now - m_PingTime >= 500)
|
|
{
|
|
TraversalPacket ping = {};
|
|
ping.type = TraversalPacketType::Ping;
|
|
ping.ping.hostId = m_HostId;
|
|
SendTraversalPacket(ping);
|
|
m_PingTime = now;
|
|
}
|
|
}
|
|
|
|
// Queues a copy of `packet` for reliable delivery to the traversal server, tags it with a
// freshly generated random request ID, performs the first send, and returns that ID.
TraversalRequestId TraversalClient::SendTraversalPacket(const TraversalPacket& packet)
{
  const TraversalRequestId request_id = Common::Random::GenerateValue<TraversalRequestId>();

  OutgoingTraversalPacketInfo entry;
  entry.packet = packet;
  entry.packet.requestId = request_id;
  entry.tries = 0;
  m_OutgoingTraversalPackets.push_back(entry);

  // Send the queued copy: ResendPacket updates its send time and attempt count in place.
  ResendPacket(&m_OutgoingTraversalPackets.back());

  return request_id;
}
|
|
|
|
void TraversalClient::Reset()
|
|
{
|
|
m_PendingConnect = false;
|
|
m_Client = nullptr;
|
|
}
|
|
|
|
// ENet intercept hook installed on the host by the TraversalClient constructor.
//
// Returns 1 (and overwrites event->type) when the datagram was consumed here, either
// because the traversal client handled it as a server packet or because it is a single
// zero byte. Returns 0 otherwise.
//
// The global client pointer may be empty while this hook is still installed on a host, so
// it is checked before use instead of being dereferenced unconditionally.
int ENET_CALLBACK TraversalClient::InterceptCallback(ENetHost* host, ENetEvent* event)
{
  TraversalClient* const client = g_TraversalClient.get();

  const bool handled_by_client =
      client != nullptr &&
      client->TestPacket(host->receivedData, host->receivedDataLength, &host->receivedAddress);

  if (handled_by_client || (host->receivedDataLength == 1 && host->receivedData[0] == 0))
  {
    // NOTE(review): 42 is presumably chosen to lie outside ENet's own event types - confirm.
    event->type = static_cast<ENetEventType>(42);
    return 1;
  }
  return 0;
}
|
|
|
|
std::unique_ptr<TraversalClient> g_TraversalClient;
|
|
std::unique_ptr<ENetHost> g_MainNetHost;
|
|
|
|
// The settings at the previous TraversalClient reset - notably, we
|
|
// need to know not just what port it's on, but whether it was
|
|
// explicitly requested.
|
|
static std::string g_OldServer;
|
|
static u16 g_OldServerPort;
|
|
static u16 g_OldListenPort;
|
|
|
|
bool EnsureTraversalClient(const std::string& server, u16 server_port, u16 listen_port)
|
|
{
|
|
if (!g_MainNetHost || !g_TraversalClient || server != g_OldServer ||
|
|
server_port != g_OldServerPort || listen_port != g_OldListenPort)
|
|
{
|
|
g_OldServer = server;
|
|
g_OldServerPort = server_port;
|
|
g_OldListenPort = listen_port;
|
|
|
|
ENetAddress addr = {ENET_HOST_ANY, listen_port};
|
|
ENetHost* host = enet_host_create(&addr, // address
|
|
50, // peerCount
|
|
NetPlay::CHANNEL_COUNT, // channelLimit
|
|
0, // incomingBandwidth
|
|
0); // outgoingBandwidth
|
|
if (!host)
|
|
{
|
|
g_MainNetHost.reset();
|
|
return false;
|
|
}
|
|
g_MainNetHost.reset(host);
|
|
g_TraversalClient.reset(new TraversalClient(g_MainNetHost.get(), server, server_port));
|
|
}
|
|
return true;
|
|
}
|
|
|
|
void ReleaseTraversalClient()
|
|
{
|
|
if (!g_TraversalClient)
|
|
return;
|
|
|
|
g_TraversalClient.reset();
|
|
g_MainNetHost.reset();
|
|
}
|