Socket: Keep track of the socket connection progress

Workaround for mbedtls handshake issues with partially connected sockets
This commit is contained in:
Sepalani 2022-05-27 01:57:00 +04:00
parent cbadc6e81a
commit c53a4c8c1a
2 changed files with 129 additions and 0 deletions

View File

@ -19,6 +19,8 @@
#include "Common/BitUtils.h"
#include "Common/FileUtil.h"
#include "Common/IOFile.h"
#include "Common/Network.h"
#include "Common/ScopeGuard.h"
#include "Core/Config/MainSettings.h"
#include "Core/Core.h"
#include "Core/IOS/Device.h"
@ -225,6 +227,7 @@ s32 WiiSocket::CloseFd()
GetIOS()->EnqueueIPCReply(it->request, -SO_ENOTCONN);
it = pending_sockops.erase(it);
}
connecting_state = ConnectingState::None;
return ReturnValue;
}
@ -297,6 +300,7 @@ void WiiSocket::Update(bool read, bool write, bool except)
int ret = connect(fd, (sockaddr*)&local_name, sizeof(local_name));
ReturnValue = WiiSockMan::GetNetErrorCode(ret, "SO_CONNECT", false);
UpdateConnectingState(ReturnValue);
INFO_LOG_FMT(IOS_NET, "IOCTL_SO_CONNECT ({:08x}, {}:{}) = {}", wii_fd,
inet_ntoa(local_name.sin_addr), Common::swap16(local_name.sin_port), ret);
@ -342,10 +346,12 @@ void WiiSocket::Update(bool read, bool write, bool except)
{
ReturnValue = -SO_ENETUNREACH;
ResetTimeout();
connecting_state = ConnectingState::Error;
}
break;
case -SO_EISCONN:
ReturnValue = SO_SUCCESS;
connecting_state = ConnectingState::Connected;
[[fallthrough]];
default:
ResetTimeout();
@ -393,6 +399,24 @@ void WiiSocket::Update(bool read, bool write, bool except)
{
case IOCTLV_NET_SSL_DOHANDSHAKE:
{
// The Wii allows a socket with an in-progress connection to
// perform the SSL handshake. MbedTLS doesn't support it so
// we have to check it manually.
connecting_state = GetConnectingState();
if (connecting_state == ConnectingState::Connecting)
{
WriteReturnValue(SSL_ERR_RAGAIN, BufferIn);
ReturnValue = SSL_ERR_RAGAIN;
break;
}
else if (connecting_state == ConnectingState::None ||
connecting_state == ConnectingState::Error)
{
WriteReturnValue(SSL_ERR_SYSCALL, BufferIn);
ReturnValue = SSL_ERR_SYSCALL;
break;
}
mbedtls_ssl_context* ctx = &NetSSLDevice::_SSL[sslID].ctx;
const int ret = mbedtls_ssl_handshake(ctx);
if (ret != 0)
@ -673,6 +697,100 @@ void WiiSocket::Update(bool read, bool write, bool except)
}
}
void WiiSocket::UpdateConnectingState(s32 connect_rv)
{
if (connect_rv == -SO_EAGAIN || connect_rv == -SO_EALREADY || connect_rv == -SO_EINPROGRESS)
{
connecting_state = ConnectingState::Connecting;
}
else if (connect_rv >= 0)
{
connecting_state = ConnectingState::Connected;
}
else
{
connecting_state = ConnectingState::Error;
}
}
WiiSocket::ConnectingState WiiSocket::GetConnectingState() const
{
const auto state = Common::SaveNetworkErrorState();
Common::ScopeGuard guard([&state] { Common::RestoreNetworkErrorState(state); });
#ifdef _WIN32
constexpr int (*get_errno)() = &WSAGetLastError;
#else
constexpr int (*get_errno)() = []() { return errno; };
#endif
switch (connecting_state)
{
case ConnectingState::Error:
case ConnectingState::Connected:
case ConnectingState::None:
break;
case ConnectingState::Connecting:
{
const s32 nfds = fd + 1;
fd_set read_fds;
fd_set write_fds;
fd_set except_fds;
struct timeval t = {0, 0};
FD_ZERO(&read_fds);
FD_ZERO(&write_fds);
FD_ZERO(&except_fds);
FD_SET(fd, &write_fds);
FD_SET(fd, &except_fds);
auto& sm = WiiSockMan::GetInstance();
if (select(nfds, &read_fds, &write_fds, &except_fds, &t) < 0)
{
const s32 error = get_errno();
ERROR_LOG_FMT(IOS_SSL, "Failed to get socket (fd={}) connection state (err={}): {}", wii_fd,
error, sm.DecodeError(error));
return ConnectingState::Error;
}
if (FD_ISSET(fd, &write_fds) == 0 && FD_ISSET(fd, &except_fds) == 0)
break;
s32 error = 0;
socklen_t len = sizeof(error);
if (getsockopt(fd, SOL_SOCKET, SO_ERROR, reinterpret_cast<char*>(&error), &len) != 0)
{
error = get_errno();
ERROR_LOG_FMT(IOS_SSL, "Failed to get socket (fd={}) error state (err={}): {}", wii_fd, error,
sm.DecodeError(error));
return ConnectingState::Error;
}
if (error != 0)
{
ERROR_LOG_FMT(IOS_SSL, "Non-blocking connect (fd={}) failed (err={}): {}", wii_fd, error,
sm.DecodeError(error));
return ConnectingState::Error;
}
// Get peername to ensure the socket is connected
sockaddr_in peer;
socklen_t peer_len = sizeof(peer);
if (getpeername(fd, reinterpret_cast<sockaddr*>(&peer), &peer_len) != 0)
{
error = get_errno();
ERROR_LOG_FMT(IOS_SSL, "Non-blocking connect (fd={}) failed to get peername (err={}): {}",
wii_fd, error, sm.DecodeError(error));
return ConnectingState::Error;
}
INFO_LOG_FMT(IOS_SSL, "Non-blocking connect (fd={}) succeeded", wii_fd);
return ConnectingState::Connected;
}
}
return connecting_state;
}
const WiiSocket::Timeout& WiiSocket::GetTimeout()
{
if (!timeout.has_value())

View File

@ -199,6 +199,14 @@ private:
void Abort(s32 value);
};
enum class ConnectingState
{
None,
Connecting,
Connected,
Error
};
friend class WiiSockMan;
void SetFd(s32 s);
void SetWiiFd(s32 s);
@ -212,11 +220,14 @@ private:
void DoSock(Request request, NET_IOCTL type);
void DoSock(Request request, SSL_IOCTL type);
void Update(bool read, bool write, bool except);
void UpdateConnectingState(s32 connect_rv);
ConnectingState GetConnectingState() const;
bool IsValid() const { return fd >= 0; }
s32 fd = -1;
s32 wii_fd = -1;
bool nonBlock = false;
ConnectingState connecting_state = ConnectingState::None;
std::list<sockop> pending_sockops;
std::optional<Timeout> timeout;