From 20e5c43bc8bec8140cdf6d7a31f25218ef29f7eb Mon Sep 17 00:00:00 2001
From: YoshiRulz <OSSYoshiRulz@gmail.com>
Date: Tue, 18 Apr 2023 19:15:02 +1000
Subject: [PATCH] Add `--socket_udp` CLI flag

---
 src/BizHawk.Client.Common/Api/SocketServer.cs | 13 ++++++++++---
 src/BizHawk.Client.Common/ArgParser.cs        |  7 +++++++
 src/BizHawk.Client.Common/ParsedCLIFlags.cs   |  5 +++++
 src/BizHawk.Client.EmuHawk/MainForm.cs        |  2 +-
 4 files changed, 23 insertions(+), 4 deletions(-)

diff --git a/src/BizHawk.Client.Common/Api/SocketServer.cs b/src/BizHawk.Client.Common/Api/SocketServer.cs
index fbac3c706a..ba8e7b7e64 100644
--- a/src/BizHawk.Client.Common/Api/SocketServer.cs
+++ b/src/BizHawk.Client.Common/Api/SocketServer.cs
@@ -16,9 +16,11 @@ namespace BizHawk.Client.Common
 			=> Encoding.ASCII.GetBytes(payload.Length.ToString()).Concat(LENGTH_PREFIX_SEPARATOR).ToArray()
 				.ConcatArray(payload);
 
+		private readonly ProtocolType _protocol;
+
 		private IPEndPoint _remoteEp;
 
-		private Socket _soc = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
+		private Socket _soc;
 
 		private readonly Func<byte[]> _takeScreenshotCallback;
 
@@ -64,16 +66,21 @@ namespace BizHawk.Client.Common
 
 		public bool Successful { get; private set; }
 
-		public SocketServer(Func<byte[]> takeScreenshotCallback, string ip, int port)
+		public SocketServer(Func<byte[]> takeScreenshotCallback, ProtocolType protocol, string ip, int port)
 		{
+			_protocol = protocol;
+			ReinitSocket(out _soc);
 			_takeScreenshotCallback = takeScreenshotCallback;
 			TargetAddress = (ip, port);
 		}
 
+		private void ReinitSocket(out Socket socket)
+			=> socket = new(AddressFamily.InterNetwork, SocketType.Stream, _protocol);
+
 		private void Connect()
 		{
 			_remoteEp = new IPEndPoint(IPAddress.Parse(_targetAddr.HostIP), _targetAddr.Port);
-			_soc = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
+			ReinitSocket(out _soc);
 			_soc.Connect(_remoteEp);
 			Connected = true;
 		}
diff --git a/src/BizHawk.Client.Common/ArgParser.cs b/src/BizHawk.Client.Common/ArgParser.cs
index 29afd16609..3b8c61d652 100644
--- a/src/BizHawk.Client.Common/ArgParser.cs
+++ b/src/BizHawk.Client.Common/ArgParser.cs
@@ -4,6 +4,7 @@ using System;
 using System.Collections.Generic;
 using System.Linq;
 using System.IO;
+using System.Net.Sockets;
 
 using BizHawk.Common.CollectionExtensions;
 
@@ -42,6 +43,7 @@ namespace BizHawk.Client.Common
 			string? urlPost = null;
 			bool? audiosync = null;
 			string? openExtToolDll = null;
+			var socketProtocol = ProtocolType.Tcp;
 			List<(string Key, string Value)>? userdataUnparsedPairs = null;
 			string? cmdRom = null;
 
@@ -135,6 +137,10 @@ namespace BizHawk.Client.Common
 				{
 					socketIP = argDowncased.Substring(argDowncased.IndexOf('=') + 1);
 				}
+				else if (argDowncased.StartsWith("--socket_udp"))
+				{
+					socketProtocol = ProtocolType.Udp;
+				}
 				else if (argDowncased.StartsWith("--mmf="))
 				{
 					mmfFilename = arg.Substring(arg.IndexOf('=') + 1);
@@ -211,6 +217,7 @@ namespace BizHawk.Client.Common
 				httpAddresses: httpAddresses,
 				audiosync: audiosync,
 				openExtToolDll: openExtToolDll,
+				socketProtocol: socketProtocol,
 				userdataUnparsedPairs: userdataUnparsedPairs,
 				cmdRom: cmdRom
 			);
diff --git a/src/BizHawk.Client.Common/ParsedCLIFlags.cs b/src/BizHawk.Client.Common/ParsedCLIFlags.cs
index 50def138eb..d2f0c0338f 100644
--- a/src/BizHawk.Client.Common/ParsedCLIFlags.cs
+++ b/src/BizHawk.Client.Common/ParsedCLIFlags.cs
@@ -1,6 +1,7 @@
 #nullable enable
 
 using System.Collections.Generic;
+using System.Net.Sockets;
 
 namespace BizHawk.Client.Common
 {
@@ -36,6 +37,8 @@ namespace BizHawk.Client.Common
 
 		public readonly (string IP, int Port)? SocketAddress;
 
+		public readonly ProtocolType SocketProtocol;
+
 		public readonly IReadOnlyList<(string Key, string Value)>? UserdataUnparsedPairs;
 
 		public readonly string? MMFFilename;
@@ -68,6 +71,7 @@ namespace BizHawk.Client.Common
 			(string? UrlGet, string? UrlPost)? httpAddresses,
 			bool? audiosync,
 			string? openExtToolDll,
+			ProtocolType socketProtocol,
 			IReadOnlyList<(string Key, string Value)>? userdataUnparsedPairs,
 			string? cmdRom)
 		{
@@ -90,6 +94,7 @@ namespace BizHawk.Client.Common
 			HTTPAddresses = httpAddresses;
 			this.audiosync = audiosync;
 			this.openExtToolDll = openExtToolDll;
+			SocketProtocol = socketProtocol;
 			UserdataUnparsedPairs = userdataUnparsedPairs;
 			this.cmdRom = cmdRom;
 		}
diff --git a/src/BizHawk.Client.EmuHawk/MainForm.cs b/src/BizHawk.Client.EmuHawk/MainForm.cs
index 5190f6cb54..6ea67c3e4c 100644
--- a/src/BizHawk.Client.EmuHawk/MainForm.cs
+++ b/src/BizHawk.Client.EmuHawk/MainForm.cs
@@ -466,7 +466,7 @@ namespace BizHawk.Client.EmuHawk
 					: null,
 				new MemoryMappedFiles(NetworkingTakeScreenshot, _argParser.MMFFilename),
 				_argParser.SocketAddress is var (socketIP, socketPort)
-					? new SocketServer(NetworkingTakeScreenshot, socketIP, socketPort)
+					? new SocketServer(NetworkingTakeScreenshot, _argParser.SocketProtocol, socketIP, socketPort)
 					: null
 			);