IOS/USB: Move scan thread logic into a separate class

This moves the scan thread logic and variables into a separate
ScanThread class. By turning ScanThread instances into members of the
most derived class, this ensures that the scan thread is always
properly stopped when the most derived class is destructed and fixes
a race condition that could cause the scan thread to call virtual
member functions from a derived class whose members have already
been destructed.

A drawback of this approach is that the scan thread must be the last
member variable, so this commit also adds static assertions to ensure
that the assumption stays valid.
This commit is contained in:
Léo Lam 2020-08-17 14:41:35 +02:00
parent 6104018fe1
commit d1439a1fa9
No known key found for this signature in database
GPG Key ID: 0DF30F9081000741
10 changed files with 70 additions and 27 deletions

View File

@ -31,19 +31,16 @@ USBHost::USBHost(Kernel& ios, const std::string& device_name) : Device(ios, devi
{
}
USBHost::~USBHost()
{
StopThreads();
}
USBHost::~USBHost() = default;
IPCCommandResult USBHost::Open(const OpenRequest& request)
{
if (!m_has_initialised && !Core::WantsDeterminism())
{
StartThreads();
GetScanThread().Start();
// Force a device scan to complete, because some games (including Your Shape) only care
// about the initial device list (in the first GETDEVICECHANGE reply).
m_first_scan_complete_event.Wait();
GetScanThread().WaitForFirstScan();
m_has_initialised = true;
}
return GetDefaultReply(IPC_SUCCESS);
@ -52,9 +49,9 @@ IPCCommandResult USBHost::Open(const OpenRequest& request)
void USBHost::UpdateWantDeterminism(const bool new_want_determinism)
{
if (new_want_determinism)
StopThreads();
GetScanThread().Stop();
else if (IsOpened())
StartThreads();
GetScanThread().Start();
}
void USBHost::DoState(PointerWrap& p)
@ -112,7 +109,6 @@ bool USBHost::UpdateDevices(const bool always_add_hooks)
return false;
DetectRemovedDevices(plugged_devices, hooks);
DispatchHooks(hooks);
m_first_scan_complete_event.Set();
return true;
}
@ -177,33 +173,44 @@ void USBHost::DispatchHooks(const DeviceChangeHooks& hooks)
OnDeviceChangeEnd();
}
void USBHost::StartThreads()
USBHost::ScanThread::~ScanThread()
{
Stop();
}
void USBHost::ScanThread::WaitForFirstScan()
{
m_first_scan_complete_event.Wait();
}
void USBHost::ScanThread::Start()
{
if (Core::WantsDeterminism())
return;
if (m_scan_thread_running.TestAndSet())
if (m_thread_running.TestAndSet())
{
m_scan_thread = std::thread([this] {
m_thread = std::thread([this] {
Common::SetCurrentThreadName("USB Scan Thread");
while (m_scan_thread_running.IsSet())
while (m_thread_running.IsSet())
{
UpdateDevices();
if (m_host->UpdateDevices())
m_first_scan_complete_event.Set();
Common::SleepCurrentThread(50);
}
});
}
}
void USBHost::StopThreads()
void USBHost::ScanThread::Stop()
{
if (m_scan_thread_running.TestAndClear())
m_scan_thread.join();
if (m_thread_running.TestAndClear())
m_thread.join();
// Clear all devices and dispatch removal hooks.
DeviceChangeHooks hooks;
DetectRemovedDevices(std::set<u64>(), hooks);
DispatchHooks(hooks);
m_host->DetectRemovedDevices(std::set<u64>(), hooks);
m_host->DispatchHooks(hooks);
}
IPCCommandResult USBHost::HandleTransfer(std::shared_ptr<USB::Device> device, u32 request,

View File

@ -46,6 +46,23 @@ protected:
};
using DeviceChangeHooks = std::map<std::shared_ptr<USB::Device>, ChangeEvent>;
class ScanThread final
{
public:
explicit ScanThread(USBHost* host) : m_host(host) {}
~ScanThread();
void Start();
void Stop();
void WaitForFirstScan();
private:
USBHost* m_host = nullptr;
Common::Flag m_thread_running;
std::thread m_thread;
Common::Event m_first_scan_complete_event;
Common::Flag m_is_initialized;
};
std::map<u64, std::shared_ptr<USB::Device>> m_devices;
mutable std::mutex m_devices_mutex;
@ -53,25 +70,18 @@ protected:
virtual void OnDeviceChange(ChangeEvent event, std::shared_ptr<USB::Device> changed_device);
virtual void OnDeviceChangeEnd();
virtual bool ShouldAddDevice(const USB::Device& device) const;
virtual ScanThread& GetScanThread() = 0;
IPCCommandResult HandleTransfer(std::shared_ptr<USB::Device> device, u32 request,
std::function<s32()> submit) const;
private:
void StartThreads();
void StopThreads();
bool AddDevice(std::unique_ptr<USB::Device> device);
bool UpdateDevices(bool always_add_hooks = false);
bool AddNewDevices(std::set<u64>& new_devices, DeviceChangeHooks& hooks, bool always_add_hooks);
void DetectRemovedDevices(const std::set<u64>& plugged_devices, DeviceChangeHooks& hooks);
void DispatchHooks(const DeviceChangeHooks& hooks);
// Device scanning thread
Common::Flag m_scan_thread_running;
std::thread m_scan_thread;
Common::Event m_first_scan_complete_event;
bool m_has_initialised = false;
LibusbUtils::Context m_context;
};

View File

@ -24,6 +24,8 @@ namespace IOS::HLE::Device
{
OH0::OH0(Kernel& ios, const std::string& device_name) : USBHost(ios, device_name)
{
static_assert(offsetof(OH0, m_scan_thread) == sizeof(OH0) - sizeof(ScanThread),
"ScanThread must be the last data member");
}
IPCCommandResult OH0::Open(const OpenRequest& request)

View File

@ -66,6 +66,8 @@ private:
template <typename T>
void TriggerHook(std::map<T, u32>& hooks, T value, ReturnCode return_value);
ScanThread& GetScanThread() override { return m_scan_thread; }
struct DeviceEntry
{
u32 unknown;
@ -79,6 +81,8 @@ private:
std::map<u64, u32> m_removal_hooks;
std::set<u64> m_opened_devices;
std::mutex m_hooks_mutex;
ScanThread m_scan_thread{this};
};
} // namespace Device
} // namespace IOS::HLE

View File

@ -24,6 +24,8 @@ namespace IOS::HLE::Device
{
USB_HIDv4::USB_HIDv4(Kernel& ios, const std::string& device_name) : USBHost(ios, device_name)
{
static_assert(offsetof(USB_HIDv4, m_scan_thread) == sizeof(USB_HIDv4) - sizeof(ScanThread),
"ScanThread must be the last data member");
}
IPCCommandResult USB_HIDv4::IOCtl(const IOCtlRequest& request)

View File

@ -39,6 +39,7 @@ private:
std::vector<u8> GetDeviceEntry(const USB::Device& device) const;
void OnDeviceChange(ChangeEvent, std::shared_ptr<USB::Device>) override;
bool ShouldAddDevice(const USB::Device& device) const override;
ScanThread& GetScanThread() override { return m_scan_thread; }
static constexpr u32 VERSION = 0x40001;
static constexpr u8 HID_CLASS = 0x03;
@ -51,5 +52,7 @@ private:
// IOS device IDs <=> USB device IDs
std::map<s32, u64> m_ios_ids;
std::map<u64, s32> m_device_ids;
ScanThread m_scan_thread{this};
};
} // namespace IOS::HLE::Device

View File

@ -22,6 +22,9 @@ USB_HIDv5::~USB_HIDv5() = default;
IPCCommandResult USB_HIDv5::IOCtl(const IOCtlRequest& request)
{
static_assert(offsetof(USB_HIDv5, m_scan_thread) == sizeof(USB_HIDv5) - sizeof(ScanThread),
"ScanThread must be the last data member");
request.Log(GetDeviceName(), Common::Log::IOS_USB);
switch (request.request)
{

View File

@ -27,11 +27,16 @@ private:
bool ShouldAddDevice(const USB::Device& device) const override;
bool HasInterfaceNumberInIDs() const override { return true; }
ScanThread& GetScanThread() override { return m_scan_thread; }
struct AdditionalDeviceData
{
u8 interrupt_in_endpoint = 0;
u8 interrupt_out_endpoint = 0;
};
std::array<AdditionalDeviceData, 32> m_additional_device_data{};
ScanThread m_scan_thread{this};
};
} // namespace IOS::HLE::Device

View File

@ -22,6 +22,9 @@ USB_VEN::~USB_VEN() = default;
IPCCommandResult USB_VEN::IOCtl(const IOCtlRequest& request)
{
static_assert(offsetof(USB_VEN, m_scan_thread) == sizeof(USB_VEN) - sizeof(ScanThread),
"ScanThread must be the last data member");
request.Log(GetDeviceName(), Common::Log::IOS_USB);
switch (request.request)
{

View File

@ -26,5 +26,9 @@ private:
s32 SubmitTransfer(USB::Device& device, const IOCtlVRequest& ioctlv);
bool HasInterfaceNumberInIDs() const override { return false; }
ScanThread& GetScanThread() override { return m_scan_thread; }
ScanThread m_scan_thread{this};
};
} // namespace IOS::HLE::Device