// SPDX-License-Identifier: CC0-1.0 // The central server implementation. #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #ifdef HAVE_LIBSYSTEMD #include #endif #include "Common/Random.h" #include "Common/TraversalProto.h" #define DEBUG 0 #define NUMBER_OF_TRIES 5 #define PORT 6262 #define PORT_ALT 6226 static u64 currentTime; struct OutgoingPacketInfo { Common::TraversalPacket packet; Common::TraversalRequestId misc; bool fromAlt; sockaddr_in6 dest; int tries; u64 sendTime; }; template struct EvictEntry { u64 updateTime; T value; }; template struct EvictFindResult { bool found; V* value; }; template EvictFindResult EvictFind(std::unordered_map>& map, const K& key, bool refresh = false) { retry: const u64 expiryTime = 30 * 1000000; // 30s EvictFindResult result; if (map.bucket_count()) { auto bucket = map.bucket(key); auto it = map.begin(bucket); for (; it != map.end(bucket); ++it) { if (currentTime - it->second.updateTime > expiryTime) { map.erase(it->first); goto retry; } if (it->first == key) { if (refresh) it->second.updateTime = currentTime; result.found = true; result.value = &it->second.value; return result; } } } #if DEBUG fmt::print("failed to find key '"); for (size_t i = 0; i < sizeof(key); i++) { fmt::print("{:02x}", ((u8*)&key)[i]); } fmt::print("'\n"); #endif result.found = false; return result; } template V* EvictSet(std::unordered_map>& map, const K& key) { // can't use a local_iterator to emplace... auto& result = map[key]; result.updateTime = currentTime; return &result.value; } namespace std { template <> struct hash { size_t operator()(const Common::TraversalHostId& id) const noexcept { auto p = (u32*)id.data(); return p[0] ^ ((p[1] << 13) | (p[1] >> 19)); } }; } // namespace std using ConnectedClients = std::unordered_map>; using OutgoingPackets = std::unordered_map; static int sock; static int sockAlt; static OutgoingPackets outgoingPackets; static ConnectedClients connectedClients; static Common::TraversalInetAddress MakeInetAddress(const sockaddr_in6& addr) { if (addr.sin6_family != AF_INET6) { fmt::print(stderr, "bad sockaddr_in6\n"); exit(1); } u32* words = (u32*)addr.sin6_addr.s6_addr; Common::TraversalInetAddress result = {}; if (words[0] == 0 && words[1] == 0 && words[2] == 0xffff0000) { result.isIPV6 = false; result.address[0] = words[3]; } else { result.isIPV6 = true; memcpy(result.address, words, sizeof(result.address)); } result.port = addr.sin6_port; return result; } static sockaddr_in6 MakeSinAddr(const Common::TraversalInetAddress& addr) { sockaddr_in6 result; #ifdef SIN6_LEN result.sin6_len = sizeof(result); #endif result.sin6_family = AF_INET6; result.sin6_port = addr.port; result.sin6_flowinfo = 0; if (addr.isIPV6) { memcpy(&result.sin6_addr, addr.address, 16); } else { u32* words = (u32*)result.sin6_addr.s6_addr; words[0] = 0; words[1] = 0; words[2] = 0xffff0000; words[3] = addr.address[0]; } result.sin6_scope_id = 0; return result; } static void GetRandomHostId(Common::TraversalHostId* hostId) { char buf[9]{}; const u32 num = Common::Random::GenerateValue(); fmt::format_to_n(buf, sizeof(buf) - 1, "{:08x}", num); memcpy(hostId->data(), buf, 8); } static const char* SenderName(sockaddr_in6* addr) { static char buf[INET6_ADDRSTRLEN + 10]{}; inet_ntop(PF_INET6, &addr->sin6_addr, buf, sizeof(buf)); fmt::format_to(buf + strlen(buf), ":{}", ntohs(addr->sin6_port)); return buf; } static void TrySend(const void* buffer, size_t size, sockaddr_in6* addr, bool fromAlt) { #if DEBUG const auto* packet = static_cast(buffer); fmt::print("{}-> {} {} {}\n", fromAlt ? "alt " : "", static_cast(packet->type), static_cast(packet->requestId), SenderName(addr)); #endif if ((size_t)sendto(fromAlt ? sockAlt : sock, buffer, size, 0, (sockaddr*)addr, sizeof(*addr)) != size) { perror("sendto"); } } static Common::TraversalPacket* AllocPacket(const sockaddr_in6& dest, bool fromAlt, Common::TraversalRequestId misc = 0) { Common::TraversalRequestId requestId{}; Common::Random::Generate(&requestId, sizeof(requestId)); OutgoingPacketInfo* info = &outgoingPackets[requestId]; info->fromAlt = fromAlt; info->dest = dest; info->misc = misc; info->tries = 0; info->sendTime = currentTime; Common::TraversalPacket* result = &info->packet; memset(result, 0, sizeof(*result)); result->requestId = requestId; return result; } static void SendPacket(OutgoingPacketInfo* info) { info->tries++; info->sendTime = currentTime; TrySend(&info->packet, sizeof(info->packet), &info->dest, info->fromAlt); } static void ResendPackets() { std::vector> todoFailures; todoFailures.clear(); for (auto it = outgoingPackets.begin(); it != outgoingPackets.end();) { OutgoingPacketInfo* info = &it->second; if (currentTime - info->sendTime >= (u64)(300000 * info->tries)) { if (info->tries >= NUMBER_OF_TRIES) { if (info->packet.type == Common::TraversalPacketType::PleaseSendPacket) { todoFailures.push_back( std::make_tuple(info->packet.pleaseSendPacket.address, info->fromAlt, info->misc)); } it = outgoingPackets.erase(it); continue; } else { SendPacket(info); } } ++it; } for (const auto& p : todoFailures) { Common::TraversalPacket* fail = AllocPacket(MakeSinAddr(std::get<0>(p)), std::get<1>(p)); fail->type = Common::TraversalPacketType::ConnectFailed; fail->connectFailed.requestId = std::get<2>(p); fail->connectFailed.reason = Common::TraversalConnectFailedReason::ClientDidntRespond; } } static void HandlePacket(Common::TraversalPacket* packet, sockaddr_in6* addr, bool toAlt) { #if DEBUG fmt::print("<- {} {} {}\n", static_cast(packet->type), static_cast(packet->requestId), SenderName(addr)); #endif bool packetOk = true; switch (packet->type) { case Common::TraversalPacketType::Ack: { auto it = outgoingPackets.find(packet->requestId); if (it == outgoingPackets.end()) break; OutgoingPacketInfo* info = &it->second; if (info->packet.type == Common::TraversalPacketType::PleaseSendPacket) { auto* ready = AllocPacket(MakeSinAddr(info->packet.pleaseSendPacket.address), toAlt); if (packet->ack.ok) { ready->type = Common::TraversalPacketType::ConnectReady; ready->connectReady.requestId = info->misc; ready->connectReady.address = MakeInetAddress(info->dest); } else { ready->type = Common::TraversalPacketType::ConnectFailed; ready->connectFailed.requestId = info->misc; ready->connectFailed.reason = Common::TraversalConnectFailedReason::ClientFailure; } } outgoingPackets.erase(it); break; } case Common::TraversalPacketType::Ping: { auto r = EvictFind(connectedClients, packet->ping.hostId, true); packetOk = r.found; break; } case Common::TraversalPacketType::HelloFromClient: { u8 ok = packet->helloFromClient.protoVersion <= Common::TraversalProtoVersion; Common::TraversalPacket* reply = AllocPacket(*addr, toAlt); reply->type = Common::TraversalPacketType::HelloFromServer; reply->helloFromServer.ok = ok; if (ok) { Common::TraversalHostId hostId{}; Common::TraversalInetAddress* iaddr{}; // not that there is any significant change of // duplication, but... while (true) { GetRandomHostId(&hostId); auto r = EvictFind(connectedClients, hostId); if (!r.found) { iaddr = EvictSet(connectedClients, hostId); break; } } *iaddr = MakeInetAddress(*addr); reply->helloFromServer.yourAddress = *iaddr; reply->helloFromServer.yourHostId = hostId; } break; } case Common::TraversalPacketType::ConnectPlease: { Common::TraversalHostId& hostId = packet->connectPlease.hostId; auto r = EvictFind(connectedClients, hostId); if (!r.found) { Common::TraversalPacket* reply = AllocPacket(*addr, toAlt); reply->type = Common::TraversalPacketType::ConnectFailed; reply->connectFailed.requestId = packet->requestId; reply->connectFailed.reason = Common::TraversalConnectFailedReason::NoSuchClient; } else { Common::TraversalPacket* please = AllocPacket(MakeSinAddr(*r.value), toAlt, packet->requestId); please->type = Common::TraversalPacketType::PleaseSendPacket; please->pleaseSendPacket.address = MakeInetAddress(*addr); } break; } case Common::TraversalPacketType::TestPlease: { Common::TraversalHostId& hostId = packet->testPlease.hostId; auto r = EvictFind(connectedClients, hostId); if (r.found) { Common::TraversalPacket ack = {}; ack.type = Common::TraversalPacketType::Ack; ack.requestId = packet->requestId; ack.ack.ok = true; sockaddr_in6 mainAddr = MakeSinAddr(*r.value); TrySend(&ack, sizeof(ack), &mainAddr, toAlt); } break; } default: fmt::print(stderr, "received unknown packet type {} from {}\n", static_cast(packet->type), SenderName(addr)); break; } if (packet->type != Common::TraversalPacketType::Ack) { Common::TraversalPacket ack = {}; ack.type = Common::TraversalPacketType::Ack; ack.requestId = packet->requestId; ack.ack.ok = packetOk; TrySend(&ack, sizeof(ack), addr, packet->type != Common::TraversalPacketType::TestPlease ? toAlt : !toAlt); } } int main() { int rv; sock = socket(PF_INET6, SOCK_DGRAM, 0); if (sock == -1) { perror("socket"); return 1; } sockAlt = socket(PF_INET6, SOCK_DGRAM, 0); if (sockAlt == -1) { perror("socket alt"); return 1; } int no = 0; rv = setsockopt(sock, IPPROTO_IPV6, IPV6_V6ONLY, &no, sizeof(no)); if (rv < 0) { perror("setsockopt IPV6_V6ONLY"); return 1; } rv = setsockopt(sockAlt, IPPROTO_IPV6, IPV6_V6ONLY, &no, sizeof(no)); if (rv < 0) { perror("setsockopt IPV6_V6ONLY alt"); return 1; } in6_addr any = IN6ADDR_ANY_INIT; sockaddr_in6 addr; #ifdef SIN6_LEN addr.sin6_len = sizeof(addr); #endif addr.sin6_family = AF_INET6; addr.sin6_port = htons(PORT); addr.sin6_flowinfo = 0; addr.sin6_addr = any; addr.sin6_scope_id = 0; rv = bind(sock, (sockaddr*)&addr, sizeof(addr)); if (rv < 0) { perror("bind"); return 1; } addr.sin6_port = htons(PORT_ALT); rv = bind(sockAlt, (sockaddr*)&addr, sizeof(addr)); if (rv < 0) { perror("bind alt"); return 1; } timeval tv; tv.tv_sec = 0; tv.tv_usec = 300000; rv = setsockopt(sock, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv)); if (rv < 0) { perror("setsockopt SO_RCVTIMEO"); return 1; } rv = setsockopt(sockAlt, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv)); if (rv < 0) { perror("setsockopt SO_RCVTIMEO alt"); return 1; } #ifdef HAVE_LIBSYSTEMD sd_notifyf(0, "READY=1\nSTATUS=Listening on port %d (alt port: %d)", PORT, PORT_ALT); #endif while (true) { tv.tv_sec = 0; tv.tv_usec = 300000; fd_set readSet; FD_ZERO(&readSet); FD_SET(sock, &readSet); FD_SET(sockAlt, &readSet); rv = select(std::max(sock, sockAlt) + 1, &readSet, nullptr, nullptr, &tv); if (rv < 0) { if (errno != EINTR && errno != EAGAIN) { perror("recvfrom"); return 1; } } int recvsock; if (FD_ISSET(sock, &readSet)) { recvsock = sock; } else if (FD_ISSET(sockAlt, &readSet)) { recvsock = sockAlt; } else { ResendPackets(); continue; } sockaddr_in6 raddr; socklen_t addrLen = sizeof(raddr); Common::TraversalPacket packet{}; // note: switch to recvmmsg (yes, mmsg) if this becomes // expensive rv = recvfrom(recvsock, &packet, sizeof(packet), 0, (sockaddr*)&raddr, &addrLen); currentTime = std::chrono::duration_cast( std::chrono::system_clock::now().time_since_epoch()) .count(); if (rv < 0) { if (errno != EINTR && errno != EAGAIN) { perror("recvfrom"); return 1; } } else if ((size_t)rv < sizeof(packet)) { fmt::print(stderr, "received short packet from {}\n", SenderName(&raddr)); } else { HandlePacket(&packet, &raddr, recvsock == sockAlt); } ResendPackets(); #ifdef HAVE_LIBSYSTEMD sd_notify(0, "WATCHDOG=1"); #endif } }