diff --git a/Source/Core/Common/Network.cpp b/Source/Core/Common/Network.cpp index 2c548347a3..faa2a33c05 100644 --- a/Source/Core/Common/Network.cpp +++ b/Source/Core/Common/Network.cpp @@ -186,4 +186,22 @@ u16 ComputeNetworkChecksum(const void* data, u16 length, u32 initial_value) checksum = (checksum >> 16) + (checksum & 0xFFFF); return ~static_cast(checksum); } + +NetworkErrorState SaveNetworkErrorState() +{ + return { + errno, +#ifdef _WIN32 + WSAGetLastError(), +#endif + }; +} + +void RestoreNetworkErrorState(const NetworkErrorState& state) +{ + errno = state.error; +#ifdef _WIN32 + WSASetLastError(state.wsa_error); +#endif +} } // namespace Common diff --git a/Source/Core/Common/Network.h b/Source/Core/Common/Network.h index a1fcc9ea42..0b56e09d34 100644 --- a/Source/Core/Common/Network.h +++ b/Source/Core/Common/Network.h @@ -99,8 +99,18 @@ struct UDPHeader }; static_assert(sizeof(UDPHeader) == UDPHeader::SIZE); +struct NetworkErrorState +{ + int error; +#ifdef _WIN32 + int wsa_error; +#endif +}; + MACAddress GenerateMacAddress(MACConsumer type); std::string MacAddressToString(const MACAddress& mac); std::optional StringToMacAddress(std::string_view mac_string); u16 ComputeNetworkChecksum(const void* data, u16 length, u32 initial_value = 0); +NetworkErrorState SaveNetworkErrorState(); +void RestoreNetworkErrorState(const NetworkErrorState& state); } // namespace Common diff --git a/Source/Core/Core/IOS/Network/Socket.cpp b/Source/Core/Core/IOS/Network/Socket.cpp index 9537c8cdff..dcccf5977f 100644 --- a/Source/Core/Core/IOS/Network/Socket.cpp +++ b/Source/Core/Core/IOS/Network/Socket.cpp @@ -16,8 +16,11 @@ #include #endif +#include "Common/BitUtils.h" #include "Common/FileUtil.h" #include "Common/IOFile.h" +#include "Common/Network.h" +#include "Common/ScopeGuard.h" #include "Core/Config/MainSettings.h" #include "Core/Core.h" #include "Core/IOS/Device.h" @@ -224,6 +227,7 @@ s32 WiiSocket::CloseFd() GetIOS()->EnqueueIPCReply(it->request, -SO_ENOTCONN); it = pending_sockops.erase(it); } + connecting_state = ConnectingState::None; return ReturnValue; } @@ -278,8 +282,8 @@ void WiiSocket::Update(bool read, bool write, bool except) case IOCTL_SO_BIND: { sockaddr_in local_name; - WiiSockAddrIn* wii_name = (WiiSockAddrIn*)Memory::GetPointer(ioctl.buffer_in + 8); - WiiSockMan::Convert(*wii_name, local_name); + const u8* addr = Memory::GetPointer(ioctl.buffer_in + 8); + WiiSockMan::ToNativeAddrIn(addr, &local_name); int ret = bind(fd, (sockaddr*)&local_name, sizeof(local_name)); ReturnValue = WiiSockMan::GetNetErrorCode(ret, "SO_BIND", false); @@ -291,11 +295,12 @@ void WiiSocket::Update(bool read, bool write, bool except) case IOCTL_SO_CONNECT: { sockaddr_in local_name; - WiiSockAddrIn* wii_name = (WiiSockAddrIn*)Memory::GetPointer(ioctl.buffer_in + 8); - WiiSockMan::Convert(*wii_name, local_name); + const u8* addr = Memory::GetPointer(ioctl.buffer_in + 8); + WiiSockMan::ToNativeAddrIn(addr, &local_name); int ret = connect(fd, (sockaddr*)&local_name, sizeof(local_name)); ReturnValue = WiiSockMan::GetNetErrorCode(ret, "SO_CONNECT", false); + UpdateConnectingState(ReturnValue); INFO_LOG_FMT(IOS_NET, "IOCTL_SO_CONNECT ({:08x}, {}:{}) = {}", wii_fd, inet_ntoa(local_name.sin_addr), Common::swap16(local_name.sin_port), ret); @@ -307,13 +312,13 @@ void WiiSocket::Update(bool read, bool write, bool except) if (ioctl.buffer_out_size > 0) { sockaddr_in local_name; - WiiSockAddrIn* wii_name = (WiiSockAddrIn*)Memory::GetPointer(ioctl.buffer_out); - WiiSockMan::Convert(*wii_name, local_name); + u8* addr = Memory::GetPointer(ioctl.buffer_out); + WiiSockMan::ToNativeAddrIn(addr, &local_name); socklen_t addrlen = sizeof(sockaddr_in); ret = static_cast(accept(fd, (sockaddr*)&local_name, &addrlen)); - WiiSockMan::Convert(local_name, *wii_name, addrlen); + WiiSockMan::ToWiiAddrIn(local_name, addr, addrlen); } else { @@ -341,10 +346,12 @@ void WiiSocket::Update(bool read, bool write, bool except) { ReturnValue = -SO_ENETUNREACH; ResetTimeout(); + connecting_state = ConnectingState::Error; } break; case -SO_EISCONN: ReturnValue = SO_SUCCESS; + connecting_state = ConnectingState::Connected; [[fallthrough]]; default: ResetTimeout(); @@ -392,6 +399,24 @@ void WiiSocket::Update(bool read, bool write, bool except) { case IOCTLV_NET_SSL_DOHANDSHAKE: { + // The Wii allows a socket with an in-progress connection to + // perform the SSL handshake. MbedTLS doesn't support it so + // we have to check it manually. + connecting_state = GetConnectingState(); + if (connecting_state == ConnectingState::Connecting) + { + WriteReturnValue(SSL_ERR_RAGAIN, BufferIn); + ReturnValue = SSL_ERR_RAGAIN; + break; + } + else if (connecting_state == ConnectingState::None || + connecting_state == ConnectingState::Error) + { + WriteReturnValue(SSL_ERR_SYSCALL, BufferIn); + ReturnValue = SSL_ERR_SYSCALL; + break; + } + mbedtls_ssl_context* ctx = &NetSSLDevice::_SSL[sslID].ctx; const int ret = mbedtls_ssl_handshake(ctx); if (ret != 0) @@ -550,6 +575,16 @@ void WiiSocket::Update(bool read, bool write, bool except) { case IOCTLV_SO_SENDTO: { + // The Wii allows a socket with a connection in progress to use + // sendto(). This might not be supported by the operating system. + // We have to enforce it manually. + connecting_state = GetConnectingState(); + if (nonBlock && IsTCP() && connecting_state == ConnectingState::Connecting) + { + ReturnValue = -SO_EAGAIN; + break; + } + u32 flags = Memory::Read_U32(BufferIn2 + 0x04); u32 has_destaddr = Memory::Read_U32(BufferIn2 + 0x08); @@ -564,8 +599,8 @@ void WiiSocket::Update(bool read, bool write, bool except) sockaddr_in local_name = {0}; if (has_destaddr) { - WiiSockAddrIn* wii_name = (WiiSockAddrIn*)Memory::GetPointer(BufferIn2 + 0x0C); - WiiSockMan::Convert(*wii_name, local_name); + const u8* addr = Memory::GetPointer(BufferIn2 + 0x0C); + WiiSockMan::ToNativeAddrIn(addr, &local_name); } auto* to = has_destaddr ? reinterpret_cast(&local_name) : nullptr; @@ -587,6 +622,16 @@ void WiiSocket::Update(bool read, bool write, bool except) } case IOCTLV_SO_RECVFROM: { + // The Wii allows a socket with a connection in progress to use + // recvfrom(). This might not be supported by the operating system. + // We have to enforce it manually. + connecting_state = GetConnectingState(); + if (nonBlock && IsTCP() && connecting_state == ConnectingState::Connecting) + { + ReturnValue = -SO_EAGAIN; + break; + } + u32 flags = Memory::Read_U32(BufferIn + 0x04); // Not a string, Windows requires a char* for recvfrom char* data = (char*)Memory::GetPointer(BufferOut); @@ -597,8 +642,8 @@ void WiiSocket::Update(bool read, bool write, bool except) if (BufferOutSize2 != 0) { - WiiSockAddrIn* wii_name = (WiiSockAddrIn*)Memory::GetPointer(BufferOut2); - WiiSockMan::Convert(*wii_name, local_name); + const u8* addr = Memory::GetPointer(BufferOut2); + WiiSockMan::ToNativeAddrIn(addr, &local_name); } // Act as non blocking when SO_MSG_NONBLOCK is specified @@ -634,8 +679,8 @@ void WiiSocket::Update(bool read, bool write, bool except) if (BufferOutSize2 != 0) { - WiiSockAddrIn* wii_name = (WiiSockAddrIn*)Memory::GetPointer(BufferOut2); - WiiSockMan::Convert(local_name, *wii_name, addrlen); + u8* addr = Memory::GetPointer(BufferOut2); + WiiSockMan::ToWiiAddrIn(local_name, addr, addrlen); } break; } @@ -672,6 +717,112 @@ void WiiSocket::Update(bool read, bool write, bool except) } } +void WiiSocket::UpdateConnectingState(s32 connect_rv) +{ + if (connect_rv == -SO_EAGAIN || connect_rv == -SO_EALREADY || connect_rv == -SO_EINPROGRESS) + { + connecting_state = ConnectingState::Connecting; + } + else if (connect_rv >= 0) + { + connecting_state = ConnectingState::Connected; + } + else + { + connecting_state = ConnectingState::Error; + } +} + +WiiSocket::ConnectingState WiiSocket::GetConnectingState() const +{ + const auto state = Common::SaveNetworkErrorState(); + Common::ScopeGuard guard([&state] { Common::RestoreNetworkErrorState(state); }); + +#ifdef _WIN32 + constexpr int (*get_errno)() = &WSAGetLastError; +#else + constexpr int (*get_errno)() = []() { return errno; }; +#endif + + switch (connecting_state) + { + case ConnectingState::Error: + case ConnectingState::Connected: + case ConnectingState::None: + break; + case ConnectingState::Connecting: + { + const s32 nfds = fd + 1; + fd_set read_fds; + fd_set write_fds; + fd_set except_fds; + struct timeval t = {0, 0}; + FD_ZERO(&read_fds); + FD_ZERO(&write_fds); + FD_ZERO(&except_fds); + FD_SET(fd, &write_fds); + FD_SET(fd, &except_fds); + + auto& sm = WiiSockMan::GetInstance(); + if (select(nfds, &read_fds, &write_fds, &except_fds, &t) < 0) + { + const s32 error = get_errno(); + ERROR_LOG_FMT(IOS_SSL, "Failed to get socket (fd={}) connection state (err={}): {}", wii_fd, + error, sm.DecodeError(error)); + return ConnectingState::Error; + } + + if (FD_ISSET(fd, &write_fds) == 0 && FD_ISSET(fd, &except_fds) == 0) + break; + + s32 error = 0; + socklen_t len = sizeof(error); + if (getsockopt(fd, SOL_SOCKET, SO_ERROR, reinterpret_cast(&error), &len) != 0) + { + error = get_errno(); + ERROR_LOG_FMT(IOS_SSL, "Failed to get socket (fd={}) error state (err={}): {}", wii_fd, error, + sm.DecodeError(error)); + return ConnectingState::Error; + } + + if (error != 0) + { + ERROR_LOG_FMT(IOS_SSL, "Non-blocking connect (fd={}) failed (err={}): {}", wii_fd, error, + sm.DecodeError(error)); + return ConnectingState::Error; + } + + // Get peername to ensure the socket is connected + sockaddr_in peer; + socklen_t peer_len = sizeof(peer); + if (getpeername(fd, reinterpret_cast(&peer), &peer_len) != 0) + { + error = get_errno(); + ERROR_LOG_FMT(IOS_SSL, "Non-blocking connect (fd={}) failed to get peername (err={}): {}", + wii_fd, error, sm.DecodeError(error)); + return ConnectingState::Error; + } + + INFO_LOG_FMT(IOS_SSL, "Non-blocking connect (fd={}) succeeded", wii_fd); + return ConnectingState::Connected; + } + } + + return connecting_state; +} + +bool WiiSocket::IsTCP() const +{ + const auto state = Common::SaveNetworkErrorState(); + Common::ScopeGuard guard([&state] { Common::RestoreNetworkErrorState(state); }); + + int socket_type; + socklen_t option_length = sizeof(socket_type); + return getsockopt(fd, SOL_SOCKET, SO_TYPE, reinterpret_cast(&socket_type), + &option_length) == 0 && + socket_type == SOCK_STREAM; +} + const WiiSocket::Timeout& WiiSocket::GetTimeout() { if (!timeout.has_value()) @@ -937,11 +1088,12 @@ void WiiSockMan::UpdatePollCommands() pending_polls.end()); } -void WiiSockMan::Convert(WiiSockAddrIn const& from, sockaddr_in& to) +void WiiSockMan::ToNativeAddrIn(const u8* addr, sockaddr_in* to) { - to.sin_addr.s_addr = from.addr.addr; - to.sin_family = from.family; - to.sin_port = from.port; + const WiiSockAddrIn from = Common::BitCastPtr(addr); + to->sin_addr.s_addr = from.addr.addr; + to->sin_family = from.family; + to->sin_port = from.port; } s32 WiiSockMan::ConvertEvents(s32 events, ConvertDirection dir) @@ -981,15 +1133,15 @@ s32 WiiSockMan::ConvertEvents(s32 events, ConvertDirection dir) return converted_events; } -void WiiSockMan::Convert(sockaddr_in const& from, WiiSockAddrIn& to, s32 addrlen) +void WiiSockMan::ToWiiAddrIn(const sockaddr_in& from, u8* to, socklen_t addrlen) { - to.addr.addr = from.sin_addr.s_addr; - to.family = from.sin_family & 0xFF; - to.port = from.sin_port; - if (addrlen < 0 || addrlen > static_cast(sizeof(WiiSockAddrIn))) - to.len = sizeof(WiiSockAddrIn); - else - to.len = addrlen; + to[offsetof(WiiSockAddrIn, len)] = + u8(addrlen > sizeof(WiiSockAddrIn) ? sizeof(WiiSockAddrIn) : addrlen); + to[offsetof(WiiSockAddrIn, family)] = u8(from.sin_family & 0xFF); + const u16& from_port = from.sin_port; + memcpy(to + offsetof(WiiSockAddrIn, port), &from_port, sizeof(from_port)); + const u32& from_addr = from.sin_addr.s_addr; + memcpy(to + offsetof(WiiSockAddrIn, addr.addr), &from_addr, sizeof(from_addr)); } void WiiSockMan::DoState(PointerWrap& p) diff --git a/Source/Core/Core/IOS/Network/Socket.h b/Source/Core/Core/IOS/Network/Socket.h index f9555da9e7..143391cd89 100644 --- a/Source/Core/Core/IOS/Network/Socket.h +++ b/Source/Core/Core/IOS/Network/Socket.h @@ -199,6 +199,14 @@ private: void Abort(s32 value); }; + enum class ConnectingState + { + None, + Connecting, + Connected, + Error + }; + friend class WiiSockMan; void SetFd(s32 s); void SetWiiFd(s32 s); @@ -212,11 +220,15 @@ private: void DoSock(Request request, NET_IOCTL type); void DoSock(Request request, SSL_IOCTL type); void Update(bool read, bool write, bool except); + void UpdateConnectingState(s32 connect_rv); + ConnectingState GetConnectingState() const; bool IsValid() const { return fd >= 0; } + bool IsTCP() const; s32 fd = -1; s32 wii_fd = -1; bool nonBlock = false; + ConnectingState connecting_state = ConnectingState::None; std::list pending_sockops; std::optional timeout; @@ -248,8 +260,9 @@ public: return instance; // Instantiated on first use. } void Update(); - static void Convert(WiiSockAddrIn const& from, sockaddr_in& to); - static void Convert(sockaddr_in const& from, WiiSockAddrIn& to, s32 addrlen = -1); + static void ToNativeAddrIn(const u8* from, sockaddr_in* to); + static void ToWiiAddrIn(const sockaddr_in& from, u8* to, + socklen_t addrlen = sizeof(WiiSockAddrIn)); static s32 ConvertEvents(s32 events, ConvertDirection dir); void DoState(PointerWrap& p); diff --git a/Source/Core/Core/NetworkCaptureLogger.cpp b/Source/Core/Core/NetworkCaptureLogger.cpp index 2e4242fd88..064b5ed25b 100644 --- a/Source/Core/Core/NetworkCaptureLogger.cpp +++ b/Source/Core/Core/NetworkCaptureLogger.cpp @@ -15,6 +15,7 @@ #include "Common/IOFile.h" #include "Common/Network.h" #include "Common/PcapFile.h" +#include "Common/ScopeGuard.h" #include "Core/Config/MainSettings.h" #include "Core/ConfigManager.h" @@ -90,24 +91,6 @@ void PCAPSSLCaptureLogger::OnNewSocket(s32 socket) m_write_sequence_number[socket] = 0; } -PCAPSSLCaptureLogger::ErrorState PCAPSSLCaptureLogger::SaveState() const -{ - return { - errno, -#ifdef _WIN32 - WSAGetLastError(), -#endif - }; -} - -void PCAPSSLCaptureLogger::RestoreState(const PCAPSSLCaptureLogger::ErrorState& state) const -{ - errno = state.error; -#ifdef _WIN32 - WSASetLastError(state.wsa_error); -#endif -} - void PCAPSSLCaptureLogger::LogSSLRead(const void* data, std::size_t length, s32 socket) { if (!Config::Get(Config::MAIN_NETWORK_SSL_DUMP_READ)) @@ -135,7 +118,8 @@ void PCAPSSLCaptureLogger::LogWrite(const void* data, std::size_t length, s32 so void PCAPSSLCaptureLogger::Log(LogType log_type, const void* data, std::size_t length, s32 socket, sockaddr* other) { - const auto state = SaveState(); + const auto state = Common::SaveNetworkErrorState(); + Common::ScopeGuard guard([&state] { Common::RestoreNetworkErrorState(state); }); sockaddr_in sock; sockaddr_in peer; sockaddr_in* from; @@ -144,16 +128,10 @@ void PCAPSSLCaptureLogger::Log(LogType log_type, const void* data, std::size_t l socklen_t peer_len = sizeof(sock); if (getsockname(socket, reinterpret_cast(&sock), &sock_len) != 0) - { - RestoreState(state); return; - } if (other == nullptr && getpeername(socket, reinterpret_cast(&peer), &peer_len) != 0) - { - RestoreState(state); return; - } if (log_type == LogType::Read) { @@ -168,7 +146,6 @@ void PCAPSSLCaptureLogger::Log(LogType log_type, const void* data, std::size_t l LogIPv4(log_type, reinterpret_cast(data), static_cast(length), socket, *from, *to); - RestoreState(state); } void PCAPSSLCaptureLogger::LogIPv4(LogType log_type, const u8* data, u16 length, s32 socket, diff --git a/Source/Core/Core/NetworkCaptureLogger.h b/Source/Core/Core/NetworkCaptureLogger.h index 032fde267c..d5067252d4 100644 --- a/Source/Core/Core/NetworkCaptureLogger.h +++ b/Source/Core/Core/NetworkCaptureLogger.h @@ -99,15 +99,6 @@ private: Read, Write, }; - struct ErrorState - { - int error; -#ifdef _WIN32 - int wsa_error; -#endif - }; - ErrorState SaveState() const; - void RestoreState(const ErrorState& state) const; void Log(LogType log_type, const void* data, std::size_t length, s32 socket, sockaddr* other); void LogIPv4(LogType log_type, const u8* data, u16 length, s32 socket, const sockaddr_in& from,