diff --git a/Source/Core/Core/IOS/Network/Socket.cpp b/Source/Core/Core/IOS/Network/Socket.cpp index ec0ae1641c..29af6231d0 100644 --- a/Source/Core/Core/IOS/Network/Socket.cpp +++ b/Source/Core/Core/IOS/Network/Socket.cpp @@ -19,6 +19,8 @@ #include "Common/BitUtils.h" #include "Common/FileUtil.h" #include "Common/IOFile.h" +#include "Common/Network.h" +#include "Common/ScopeGuard.h" #include "Core/Config/MainSettings.h" #include "Core/Core.h" #include "Core/IOS/Device.h" @@ -225,6 +227,7 @@ s32 WiiSocket::CloseFd() GetIOS()->EnqueueIPCReply(it->request, -SO_ENOTCONN); it = pending_sockops.erase(it); } + connecting_state = ConnectingState::None; return ReturnValue; } @@ -297,6 +300,7 @@ void WiiSocket::Update(bool read, bool write, bool except) int ret = connect(fd, (sockaddr*)&local_name, sizeof(local_name)); ReturnValue = WiiSockMan::GetNetErrorCode(ret, "SO_CONNECT", false); + UpdateConnectingState(ReturnValue); INFO_LOG_FMT(IOS_NET, "IOCTL_SO_CONNECT ({:08x}, {}:{}) = {}", wii_fd, inet_ntoa(local_name.sin_addr), Common::swap16(local_name.sin_port), ret); @@ -342,10 +346,12 @@ void WiiSocket::Update(bool read, bool write, bool except) { ReturnValue = -SO_ENETUNREACH; ResetTimeout(); + connecting_state = ConnectingState::Error; } break; case -SO_EISCONN: ReturnValue = SO_SUCCESS; + connecting_state = ConnectingState::Connected; [[fallthrough]]; default: ResetTimeout(); @@ -393,6 +399,24 @@ void WiiSocket::Update(bool read, bool write, bool except) { case IOCTLV_NET_SSL_DOHANDSHAKE: { + // The Wii allows a socket with an in-progress connection to + // perform the SSL handshake. MbedTLS doesn't support it so + // we have to check it manually. + connecting_state = GetConnectingState(); + if (connecting_state == ConnectingState::Connecting) + { + WriteReturnValue(SSL_ERR_RAGAIN, BufferIn); + ReturnValue = SSL_ERR_RAGAIN; + break; + } + else if (connecting_state == ConnectingState::None || + connecting_state == ConnectingState::Error) + { + WriteReturnValue(SSL_ERR_SYSCALL, BufferIn); + ReturnValue = SSL_ERR_SYSCALL; + break; + } + mbedtls_ssl_context* ctx = &NetSSLDevice::_SSL[sslID].ctx; const int ret = mbedtls_ssl_handshake(ctx); if (ret != 0) @@ -673,6 +697,100 @@ void WiiSocket::Update(bool read, bool write, bool except) } } +void WiiSocket::UpdateConnectingState(s32 connect_rv) +{ + if (connect_rv == -SO_EAGAIN || connect_rv == -SO_EALREADY || connect_rv == -SO_EINPROGRESS) + { + connecting_state = ConnectingState::Connecting; + } + else if (connect_rv >= 0) + { + connecting_state = ConnectingState::Connected; + } + else + { + connecting_state = ConnectingState::Error; + } +} + +WiiSocket::ConnectingState WiiSocket::GetConnectingState() const +{ + const auto state = Common::SaveNetworkErrorState(); + Common::ScopeGuard guard([&state] { Common::RestoreNetworkErrorState(state); }); + +#ifdef _WIN32 + constexpr int (*get_errno)() = &WSAGetLastError; +#else + constexpr int (*get_errno)() = []() { return errno; }; +#endif + + switch (connecting_state) + { + case ConnectingState::Error: + case ConnectingState::Connected: + case ConnectingState::None: + break; + case ConnectingState::Connecting: + { + const s32 nfds = fd + 1; + fd_set read_fds; + fd_set write_fds; + fd_set except_fds; + struct timeval t = {0, 0}; + FD_ZERO(&read_fds); + FD_ZERO(&write_fds); + FD_ZERO(&except_fds); + FD_SET(fd, &write_fds); + FD_SET(fd, &except_fds); + + auto& sm = WiiSockMan::GetInstance(); + if (select(nfds, &read_fds, &write_fds, &except_fds, &t) < 0) + { + const s32 error = get_errno(); + ERROR_LOG_FMT(IOS_SSL, "Failed to get socket (fd={}) connection state (err={}): {}", wii_fd, + error, sm.DecodeError(error)); + return ConnectingState::Error; + } + + if (FD_ISSET(fd, &write_fds) == 0 && FD_ISSET(fd, &except_fds) == 0) + break; + + s32 error = 0; + socklen_t len = sizeof(error); + if (getsockopt(fd, SOL_SOCKET, SO_ERROR, reinterpret_cast(&error), &len) != 0) + { + error = get_errno(); + ERROR_LOG_FMT(IOS_SSL, "Failed to get socket (fd={}) error state (err={}): {}", wii_fd, error, + sm.DecodeError(error)); + return ConnectingState::Error; + } + + if (error != 0) + { + ERROR_LOG_FMT(IOS_SSL, "Non-blocking connect (fd={}) failed (err={}): {}", wii_fd, error, + sm.DecodeError(error)); + return ConnectingState::Error; + } + + // Get peername to ensure the socket is connected + sockaddr_in peer; + socklen_t peer_len = sizeof(peer); + if (getpeername(fd, reinterpret_cast(&peer), &peer_len) != 0) + { + error = get_errno(); + ERROR_LOG_FMT(IOS_SSL, "Non-blocking connect (fd={}) failed to get peername (err={}): {}", + wii_fd, error, sm.DecodeError(error)); + return ConnectingState::Error; + } + + INFO_LOG_FMT(IOS_SSL, "Non-blocking connect (fd={}) succeeded", wii_fd); + return ConnectingState::Connected; + } + } + + return connecting_state; +} + const WiiSocket::Timeout& WiiSocket::GetTimeout() { if (!timeout.has_value()) diff --git a/Source/Core/Core/IOS/Network/Socket.h b/Source/Core/Core/IOS/Network/Socket.h index 8f6e02a8bb..9e2509fcc0 100644 --- a/Source/Core/Core/IOS/Network/Socket.h +++ b/Source/Core/Core/IOS/Network/Socket.h @@ -199,6 +199,14 @@ private: void Abort(s32 value); }; + enum class ConnectingState + { + None, + Connecting, + Connected, + Error + }; + friend class WiiSockMan; void SetFd(s32 s); void SetWiiFd(s32 s); @@ -212,11 +220,14 @@ private: void DoSock(Request request, NET_IOCTL type); void DoSock(Request request, SSL_IOCTL type); void Update(bool read, bool write, bool except); + void UpdateConnectingState(s32 connect_rv); + ConnectingState GetConnectingState() const; bool IsValid() const { return fd >= 0; } s32 fd = -1; s32 wii_fd = -1; bool nonBlock = false; + ConnectingState connecting_state = ConnectingState::None; std::list pending_sockops; std::optional timeout;