Merge pull request #10700 from sepalani/ssl-handshake
Socket: Fix some non-blocking connect edge cases
This commit is contained in:
commit
e50e45f400
|
@ -186,4 +186,22 @@ u16 ComputeNetworkChecksum(const void* data, u16 length, u32 initial_value)
|
|||
checksum = (checksum >> 16) + (checksum & 0xFFFF);
|
||||
return ~static_cast<u16>(checksum);
|
||||
}
|
||||
|
||||
NetworkErrorState SaveNetworkErrorState()
|
||||
{
|
||||
return {
|
||||
errno,
|
||||
#ifdef _WIN32
|
||||
WSAGetLastError(),
|
||||
#endif
|
||||
};
|
||||
}
|
||||
|
||||
void RestoreNetworkErrorState(const NetworkErrorState& state)
|
||||
{
|
||||
errno = state.error;
|
||||
#ifdef _WIN32
|
||||
WSASetLastError(state.wsa_error);
|
||||
#endif
|
||||
}
|
||||
} // namespace Common
|
||||
|
|
|
@ -99,8 +99,18 @@ struct UDPHeader
|
|||
};
|
||||
static_assert(sizeof(UDPHeader) == UDPHeader::SIZE);
|
||||
|
||||
struct NetworkErrorState
|
||||
{
|
||||
int error;
|
||||
#ifdef _WIN32
|
||||
int wsa_error;
|
||||
#endif
|
||||
};
|
||||
|
||||
MACAddress GenerateMacAddress(MACConsumer type);
|
||||
std::string MacAddressToString(const MACAddress& mac);
|
||||
std::optional<MACAddress> StringToMacAddress(std::string_view mac_string);
|
||||
u16 ComputeNetworkChecksum(const void* data, u16 length, u32 initial_value = 0);
|
||||
NetworkErrorState SaveNetworkErrorState();
|
||||
void RestoreNetworkErrorState(const NetworkErrorState& state);
|
||||
} // namespace Common
|
||||
|
|
|
@ -16,8 +16,11 @@
|
|||
#include <sys/select.h>
|
||||
#endif
|
||||
|
||||
#include "Common/BitUtils.h"
|
||||
#include "Common/FileUtil.h"
|
||||
#include "Common/IOFile.h"
|
||||
#include "Common/Network.h"
|
||||
#include "Common/ScopeGuard.h"
|
||||
#include "Core/Config/MainSettings.h"
|
||||
#include "Core/Core.h"
|
||||
#include "Core/IOS/Device.h"
|
||||
|
@ -224,6 +227,7 @@ s32 WiiSocket::CloseFd()
|
|||
GetIOS()->EnqueueIPCReply(it->request, -SO_ENOTCONN);
|
||||
it = pending_sockops.erase(it);
|
||||
}
|
||||
connecting_state = ConnectingState::None;
|
||||
return ReturnValue;
|
||||
}
|
||||
|
||||
|
@ -278,8 +282,8 @@ void WiiSocket::Update(bool read, bool write, bool except)
|
|||
case IOCTL_SO_BIND:
|
||||
{
|
||||
sockaddr_in local_name;
|
||||
WiiSockAddrIn* wii_name = (WiiSockAddrIn*)Memory::GetPointer(ioctl.buffer_in + 8);
|
||||
WiiSockMan::Convert(*wii_name, local_name);
|
||||
const u8* addr = Memory::GetPointer(ioctl.buffer_in + 8);
|
||||
WiiSockMan::ToNativeAddrIn(addr, &local_name);
|
||||
|
||||
int ret = bind(fd, (sockaddr*)&local_name, sizeof(local_name));
|
||||
ReturnValue = WiiSockMan::GetNetErrorCode(ret, "SO_BIND", false);
|
||||
|
@ -291,11 +295,12 @@ void WiiSocket::Update(bool read, bool write, bool except)
|
|||
case IOCTL_SO_CONNECT:
|
||||
{
|
||||
sockaddr_in local_name;
|
||||
WiiSockAddrIn* wii_name = (WiiSockAddrIn*)Memory::GetPointer(ioctl.buffer_in + 8);
|
||||
WiiSockMan::Convert(*wii_name, local_name);
|
||||
const u8* addr = Memory::GetPointer(ioctl.buffer_in + 8);
|
||||
WiiSockMan::ToNativeAddrIn(addr, &local_name);
|
||||
|
||||
int ret = connect(fd, (sockaddr*)&local_name, sizeof(local_name));
|
||||
ReturnValue = WiiSockMan::GetNetErrorCode(ret, "SO_CONNECT", false);
|
||||
UpdateConnectingState(ReturnValue);
|
||||
|
||||
INFO_LOG_FMT(IOS_NET, "IOCTL_SO_CONNECT ({:08x}, {}:{}) = {}", wii_fd,
|
||||
inet_ntoa(local_name.sin_addr), Common::swap16(local_name.sin_port), ret);
|
||||
|
@ -307,13 +312,13 @@ void WiiSocket::Update(bool read, bool write, bool except)
|
|||
if (ioctl.buffer_out_size > 0)
|
||||
{
|
||||
sockaddr_in local_name;
|
||||
WiiSockAddrIn* wii_name = (WiiSockAddrIn*)Memory::GetPointer(ioctl.buffer_out);
|
||||
WiiSockMan::Convert(*wii_name, local_name);
|
||||
u8* addr = Memory::GetPointer(ioctl.buffer_out);
|
||||
WiiSockMan::ToNativeAddrIn(addr, &local_name);
|
||||
|
||||
socklen_t addrlen = sizeof(sockaddr_in);
|
||||
ret = static_cast<s32>(accept(fd, (sockaddr*)&local_name, &addrlen));
|
||||
|
||||
WiiSockMan::Convert(local_name, *wii_name, addrlen);
|
||||
WiiSockMan::ToWiiAddrIn(local_name, addr, addrlen);
|
||||
}
|
||||
else
|
||||
{
|
||||
|
@ -341,10 +346,12 @@ void WiiSocket::Update(bool read, bool write, bool except)
|
|||
{
|
||||
ReturnValue = -SO_ENETUNREACH;
|
||||
ResetTimeout();
|
||||
connecting_state = ConnectingState::Error;
|
||||
}
|
||||
break;
|
||||
case -SO_EISCONN:
|
||||
ReturnValue = SO_SUCCESS;
|
||||
connecting_state = ConnectingState::Connected;
|
||||
[[fallthrough]];
|
||||
default:
|
||||
ResetTimeout();
|
||||
|
@ -392,6 +399,24 @@ void WiiSocket::Update(bool read, bool write, bool except)
|
|||
{
|
||||
case IOCTLV_NET_SSL_DOHANDSHAKE:
|
||||
{
|
||||
// The Wii allows a socket with an in-progress connection to
|
||||
// perform the SSL handshake. MbedTLS doesn't support it so
|
||||
// we have to check it manually.
|
||||
connecting_state = GetConnectingState();
|
||||
if (connecting_state == ConnectingState::Connecting)
|
||||
{
|
||||
WriteReturnValue(SSL_ERR_RAGAIN, BufferIn);
|
||||
ReturnValue = SSL_ERR_RAGAIN;
|
||||
break;
|
||||
}
|
||||
else if (connecting_state == ConnectingState::None ||
|
||||
connecting_state == ConnectingState::Error)
|
||||
{
|
||||
WriteReturnValue(SSL_ERR_SYSCALL, BufferIn);
|
||||
ReturnValue = SSL_ERR_SYSCALL;
|
||||
break;
|
||||
}
|
||||
|
||||
mbedtls_ssl_context* ctx = &NetSSLDevice::_SSL[sslID].ctx;
|
||||
const int ret = mbedtls_ssl_handshake(ctx);
|
||||
if (ret != 0)
|
||||
|
@ -550,6 +575,16 @@ void WiiSocket::Update(bool read, bool write, bool except)
|
|||
{
|
||||
case IOCTLV_SO_SENDTO:
|
||||
{
|
||||
// The Wii allows a socket with a connection in progress to use
|
||||
// sendto(). This might not be supported by the operating system.
|
||||
// We have to enforce it manually.
|
||||
connecting_state = GetConnectingState();
|
||||
if (nonBlock && IsTCP() && connecting_state == ConnectingState::Connecting)
|
||||
{
|
||||
ReturnValue = -SO_EAGAIN;
|
||||
break;
|
||||
}
|
||||
|
||||
u32 flags = Memory::Read_U32(BufferIn2 + 0x04);
|
||||
u32 has_destaddr = Memory::Read_U32(BufferIn2 + 0x08);
|
||||
|
||||
|
@ -564,8 +599,8 @@ void WiiSocket::Update(bool read, bool write, bool except)
|
|||
sockaddr_in local_name = {0};
|
||||
if (has_destaddr)
|
||||
{
|
||||
WiiSockAddrIn* wii_name = (WiiSockAddrIn*)Memory::GetPointer(BufferIn2 + 0x0C);
|
||||
WiiSockMan::Convert(*wii_name, local_name);
|
||||
const u8* addr = Memory::GetPointer(BufferIn2 + 0x0C);
|
||||
WiiSockMan::ToNativeAddrIn(addr, &local_name);
|
||||
}
|
||||
|
||||
auto* to = has_destaddr ? reinterpret_cast<sockaddr*>(&local_name) : nullptr;
|
||||
|
@ -587,6 +622,16 @@ void WiiSocket::Update(bool read, bool write, bool except)
|
|||
}
|
||||
case IOCTLV_SO_RECVFROM:
|
||||
{
|
||||
// The Wii allows a socket with a connection in progress to use
|
||||
// recvfrom(). This might not be supported by the operating system.
|
||||
// We have to enforce it manually.
|
||||
connecting_state = GetConnectingState();
|
||||
if (nonBlock && IsTCP() && connecting_state == ConnectingState::Connecting)
|
||||
{
|
||||
ReturnValue = -SO_EAGAIN;
|
||||
break;
|
||||
}
|
||||
|
||||
u32 flags = Memory::Read_U32(BufferIn + 0x04);
|
||||
// Not a string, Windows requires a char* for recvfrom
|
||||
char* data = (char*)Memory::GetPointer(BufferOut);
|
||||
|
@ -597,8 +642,8 @@ void WiiSocket::Update(bool read, bool write, bool except)
|
|||
|
||||
if (BufferOutSize2 != 0)
|
||||
{
|
||||
WiiSockAddrIn* wii_name = (WiiSockAddrIn*)Memory::GetPointer(BufferOut2);
|
||||
WiiSockMan::Convert(*wii_name, local_name);
|
||||
const u8* addr = Memory::GetPointer(BufferOut2);
|
||||
WiiSockMan::ToNativeAddrIn(addr, &local_name);
|
||||
}
|
||||
|
||||
// Act as non blocking when SO_MSG_NONBLOCK is specified
|
||||
|
@ -634,8 +679,8 @@ void WiiSocket::Update(bool read, bool write, bool except)
|
|||
|
||||
if (BufferOutSize2 != 0)
|
||||
{
|
||||
WiiSockAddrIn* wii_name = (WiiSockAddrIn*)Memory::GetPointer(BufferOut2);
|
||||
WiiSockMan::Convert(local_name, *wii_name, addrlen);
|
||||
u8* addr = Memory::GetPointer(BufferOut2);
|
||||
WiiSockMan::ToWiiAddrIn(local_name, addr, addrlen);
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
@ -672,6 +717,112 @@ void WiiSocket::Update(bool read, bool write, bool except)
|
|||
}
|
||||
}
|
||||
|
||||
void WiiSocket::UpdateConnectingState(s32 connect_rv)
|
||||
{
|
||||
if (connect_rv == -SO_EAGAIN || connect_rv == -SO_EALREADY || connect_rv == -SO_EINPROGRESS)
|
||||
{
|
||||
connecting_state = ConnectingState::Connecting;
|
||||
}
|
||||
else if (connect_rv >= 0)
|
||||
{
|
||||
connecting_state = ConnectingState::Connected;
|
||||
}
|
||||
else
|
||||
{
|
||||
connecting_state = ConnectingState::Error;
|
||||
}
|
||||
}
|
||||
|
||||
WiiSocket::ConnectingState WiiSocket::GetConnectingState() const
|
||||
{
|
||||
const auto state = Common::SaveNetworkErrorState();
|
||||
Common::ScopeGuard guard([&state] { Common::RestoreNetworkErrorState(state); });
|
||||
|
||||
#ifdef _WIN32
|
||||
constexpr int (*get_errno)() = &WSAGetLastError;
|
||||
#else
|
||||
constexpr int (*get_errno)() = []() { return errno; };
|
||||
#endif
|
||||
|
||||
switch (connecting_state)
|
||||
{
|
||||
case ConnectingState::Error:
|
||||
case ConnectingState::Connected:
|
||||
case ConnectingState::None:
|
||||
break;
|
||||
case ConnectingState::Connecting:
|
||||
{
|
||||
const s32 nfds = fd + 1;
|
||||
fd_set read_fds;
|
||||
fd_set write_fds;
|
||||
fd_set except_fds;
|
||||
struct timeval t = {0, 0};
|
||||
FD_ZERO(&read_fds);
|
||||
FD_ZERO(&write_fds);
|
||||
FD_ZERO(&except_fds);
|
||||
FD_SET(fd, &write_fds);
|
||||
FD_SET(fd, &except_fds);
|
||||
|
||||
auto& sm = WiiSockMan::GetInstance();
|
||||
if (select(nfds, &read_fds, &write_fds, &except_fds, &t) < 0)
|
||||
{
|
||||
const s32 error = get_errno();
|
||||
ERROR_LOG_FMT(IOS_SSL, "Failed to get socket (fd={}) connection state (err={}): {}", wii_fd,
|
||||
error, sm.DecodeError(error));
|
||||
return ConnectingState::Error;
|
||||
}
|
||||
|
||||
if (FD_ISSET(fd, &write_fds) == 0 && FD_ISSET(fd, &except_fds) == 0)
|
||||
break;
|
||||
|
||||
s32 error = 0;
|
||||
socklen_t len = sizeof(error);
|
||||
if (getsockopt(fd, SOL_SOCKET, SO_ERROR, reinterpret_cast<char*>(&error), &len) != 0)
|
||||
{
|
||||
error = get_errno();
|
||||
ERROR_LOG_FMT(IOS_SSL, "Failed to get socket (fd={}) error state (err={}): {}", wii_fd, error,
|
||||
sm.DecodeError(error));
|
||||
return ConnectingState::Error;
|
||||
}
|
||||
|
||||
if (error != 0)
|
||||
{
|
||||
ERROR_LOG_FMT(IOS_SSL, "Non-blocking connect (fd={}) failed (err={}): {}", wii_fd, error,
|
||||
sm.DecodeError(error));
|
||||
return ConnectingState::Error;
|
||||
}
|
||||
|
||||
// Get peername to ensure the socket is connected
|
||||
sockaddr_in peer;
|
||||
socklen_t peer_len = sizeof(peer);
|
||||
if (getpeername(fd, reinterpret_cast<sockaddr*>(&peer), &peer_len) != 0)
|
||||
{
|
||||
error = get_errno();
|
||||
ERROR_LOG_FMT(IOS_SSL, "Non-blocking connect (fd={}) failed to get peername (err={}): {}",
|
||||
wii_fd, error, sm.DecodeError(error));
|
||||
return ConnectingState::Error;
|
||||
}
|
||||
|
||||
INFO_LOG_FMT(IOS_SSL, "Non-blocking connect (fd={}) succeeded", wii_fd);
|
||||
return ConnectingState::Connected;
|
||||
}
|
||||
}
|
||||
|
||||
return connecting_state;
|
||||
}
|
||||
|
||||
bool WiiSocket::IsTCP() const
|
||||
{
|
||||
const auto state = Common::SaveNetworkErrorState();
|
||||
Common::ScopeGuard guard([&state] { Common::RestoreNetworkErrorState(state); });
|
||||
|
||||
int socket_type;
|
||||
socklen_t option_length = sizeof(socket_type);
|
||||
return getsockopt(fd, SOL_SOCKET, SO_TYPE, reinterpret_cast<char*>(&socket_type),
|
||||
&option_length) == 0 &&
|
||||
socket_type == SOCK_STREAM;
|
||||
}
|
||||
|
||||
const WiiSocket::Timeout& WiiSocket::GetTimeout()
|
||||
{
|
||||
if (!timeout.has_value())
|
||||
|
@ -937,11 +1088,12 @@ void WiiSockMan::UpdatePollCommands()
|
|||
pending_polls.end());
|
||||
}
|
||||
|
||||
void WiiSockMan::Convert(WiiSockAddrIn const& from, sockaddr_in& to)
|
||||
void WiiSockMan::ToNativeAddrIn(const u8* addr, sockaddr_in* to)
|
||||
{
|
||||
to.sin_addr.s_addr = from.addr.addr;
|
||||
to.sin_family = from.family;
|
||||
to.sin_port = from.port;
|
||||
const WiiSockAddrIn from = Common::BitCastPtr<WiiSockAddrIn>(addr);
|
||||
to->sin_addr.s_addr = from.addr.addr;
|
||||
to->sin_family = from.family;
|
||||
to->sin_port = from.port;
|
||||
}
|
||||
|
||||
s32 WiiSockMan::ConvertEvents(s32 events, ConvertDirection dir)
|
||||
|
@ -981,15 +1133,15 @@ s32 WiiSockMan::ConvertEvents(s32 events, ConvertDirection dir)
|
|||
return converted_events;
|
||||
}
|
||||
|
||||
void WiiSockMan::Convert(sockaddr_in const& from, WiiSockAddrIn& to, s32 addrlen)
|
||||
void WiiSockMan::ToWiiAddrIn(const sockaddr_in& from, u8* to, socklen_t addrlen)
|
||||
{
|
||||
to.addr.addr = from.sin_addr.s_addr;
|
||||
to.family = from.sin_family & 0xFF;
|
||||
to.port = from.sin_port;
|
||||
if (addrlen < 0 || addrlen > static_cast<s32>(sizeof(WiiSockAddrIn)))
|
||||
to.len = sizeof(WiiSockAddrIn);
|
||||
else
|
||||
to.len = addrlen;
|
||||
to[offsetof(WiiSockAddrIn, len)] =
|
||||
u8(addrlen > sizeof(WiiSockAddrIn) ? sizeof(WiiSockAddrIn) : addrlen);
|
||||
to[offsetof(WiiSockAddrIn, family)] = u8(from.sin_family & 0xFF);
|
||||
const u16& from_port = from.sin_port;
|
||||
memcpy(to + offsetof(WiiSockAddrIn, port), &from_port, sizeof(from_port));
|
||||
const u32& from_addr = from.sin_addr.s_addr;
|
||||
memcpy(to + offsetof(WiiSockAddrIn, addr.addr), &from_addr, sizeof(from_addr));
|
||||
}
|
||||
|
||||
void WiiSockMan::DoState(PointerWrap& p)
|
||||
|
|
|
@ -199,6 +199,14 @@ private:
|
|||
void Abort(s32 value);
|
||||
};
|
||||
|
||||
enum class ConnectingState
|
||||
{
|
||||
None,
|
||||
Connecting,
|
||||
Connected,
|
||||
Error
|
||||
};
|
||||
|
||||
friend class WiiSockMan;
|
||||
void SetFd(s32 s);
|
||||
void SetWiiFd(s32 s);
|
||||
|
@ -212,11 +220,15 @@ private:
|
|||
void DoSock(Request request, NET_IOCTL type);
|
||||
void DoSock(Request request, SSL_IOCTL type);
|
||||
void Update(bool read, bool write, bool except);
|
||||
void UpdateConnectingState(s32 connect_rv);
|
||||
ConnectingState GetConnectingState() const;
|
||||
bool IsValid() const { return fd >= 0; }
|
||||
bool IsTCP() const;
|
||||
|
||||
s32 fd = -1;
|
||||
s32 wii_fd = -1;
|
||||
bool nonBlock = false;
|
||||
ConnectingState connecting_state = ConnectingState::None;
|
||||
std::list<sockop> pending_sockops;
|
||||
|
||||
std::optional<Timeout> timeout;
|
||||
|
@ -248,8 +260,9 @@ public:
|
|||
return instance; // Instantiated on first use.
|
||||
}
|
||||
void Update();
|
||||
static void Convert(WiiSockAddrIn const& from, sockaddr_in& to);
|
||||
static void Convert(sockaddr_in const& from, WiiSockAddrIn& to, s32 addrlen = -1);
|
||||
static void ToNativeAddrIn(const u8* from, sockaddr_in* to);
|
||||
static void ToWiiAddrIn(const sockaddr_in& from, u8* to,
|
||||
socklen_t addrlen = sizeof(WiiSockAddrIn));
|
||||
static s32 ConvertEvents(s32 events, ConvertDirection dir);
|
||||
|
||||
void DoState(PointerWrap& p);
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
#include "Common/IOFile.h"
|
||||
#include "Common/Network.h"
|
||||
#include "Common/PcapFile.h"
|
||||
#include "Common/ScopeGuard.h"
|
||||
#include "Core/Config/MainSettings.h"
|
||||
#include "Core/ConfigManager.h"
|
||||
|
||||
|
@ -90,24 +91,6 @@ void PCAPSSLCaptureLogger::OnNewSocket(s32 socket)
|
|||
m_write_sequence_number[socket] = 0;
|
||||
}
|
||||
|
||||
PCAPSSLCaptureLogger::ErrorState PCAPSSLCaptureLogger::SaveState() const
|
||||
{
|
||||
return {
|
||||
errno,
|
||||
#ifdef _WIN32
|
||||
WSAGetLastError(),
|
||||
#endif
|
||||
};
|
||||
}
|
||||
|
||||
void PCAPSSLCaptureLogger::RestoreState(const PCAPSSLCaptureLogger::ErrorState& state) const
|
||||
{
|
||||
errno = state.error;
|
||||
#ifdef _WIN32
|
||||
WSASetLastError(state.wsa_error);
|
||||
#endif
|
||||
}
|
||||
|
||||
void PCAPSSLCaptureLogger::LogSSLRead(const void* data, std::size_t length, s32 socket)
|
||||
{
|
||||
if (!Config::Get(Config::MAIN_NETWORK_SSL_DUMP_READ))
|
||||
|
@ -135,7 +118,8 @@ void PCAPSSLCaptureLogger::LogWrite(const void* data, std::size_t length, s32 so
|
|||
void PCAPSSLCaptureLogger::Log(LogType log_type, const void* data, std::size_t length, s32 socket,
|
||||
sockaddr* other)
|
||||
{
|
||||
const auto state = SaveState();
|
||||
const auto state = Common::SaveNetworkErrorState();
|
||||
Common::ScopeGuard guard([&state] { Common::RestoreNetworkErrorState(state); });
|
||||
sockaddr_in sock;
|
||||
sockaddr_in peer;
|
||||
sockaddr_in* from;
|
||||
|
@ -144,16 +128,10 @@ void PCAPSSLCaptureLogger::Log(LogType log_type, const void* data, std::size_t l
|
|||
socklen_t peer_len = sizeof(sock);
|
||||
|
||||
if (getsockname(socket, reinterpret_cast<sockaddr*>(&sock), &sock_len) != 0)
|
||||
{
|
||||
RestoreState(state);
|
||||
return;
|
||||
}
|
||||
|
||||
if (other == nullptr && getpeername(socket, reinterpret_cast<sockaddr*>(&peer), &peer_len) != 0)
|
||||
{
|
||||
RestoreState(state);
|
||||
return;
|
||||
}
|
||||
|
||||
if (log_type == LogType::Read)
|
||||
{
|
||||
|
@ -168,7 +146,6 @@ void PCAPSSLCaptureLogger::Log(LogType log_type, const void* data, std::size_t l
|
|||
|
||||
LogIPv4(log_type, reinterpret_cast<const u8*>(data), static_cast<u16>(length), socket, *from,
|
||||
*to);
|
||||
RestoreState(state);
|
||||
}
|
||||
|
||||
void PCAPSSLCaptureLogger::LogIPv4(LogType log_type, const u8* data, u16 length, s32 socket,
|
||||
|
|
|
@ -99,15 +99,6 @@ private:
|
|||
Read,
|
||||
Write,
|
||||
};
|
||||
struct ErrorState
|
||||
{
|
||||
int error;
|
||||
#ifdef _WIN32
|
||||
int wsa_error;
|
||||
#endif
|
||||
};
|
||||
ErrorState SaveState() const;
|
||||
void RestoreState(const ErrorState& state) const;
|
||||
|
||||
void Log(LogType log_type, const void* data, std::size_t length, s32 socket, sockaddr* other);
|
||||
void LogIPv4(LogType log_type, const u8* data, u16 length, s32 socket, const sockaddr_in& from,
|
||||
|
|
Loading…
Reference in New Issue