diff --git a/src/frontend-common/CMakeLists.txt b/src/frontend-common/CMakeLists.txt index 48a5f6c7d..d4d323aa0 100644 --- a/src/frontend-common/CMakeLists.txt +++ b/src/frontend-common/CMakeLists.txt @@ -62,6 +62,8 @@ if(NOT BUILD_LIBRETRO_CORE) if(WIN32) target_sources(frontend-common PRIVATE + dinput_controller_interface.cpp + dinput_controller_interface.h xinput_controller_interface.cpp xinput_controller_interface.h ) diff --git a/src/frontend-common/controller_interface.cpp b/src/frontend-common/controller_interface.cpp index c27f093fc..9a0730457 100644 --- a/src/frontend-common/controller_interface.cpp +++ b/src/frontend-common/controller_interface.cpp @@ -89,6 +89,7 @@ static constexpr std::array(ControllerInterface::B #endif #ifdef WIN32 TRANSLATABLE("ControllerInterface", "XInput"), + TRANSLATABLE("ControllerInterface", "DInput"), #endif }}; @@ -124,6 +125,7 @@ ControllerInterface::Backend ControllerInterface::GetDefaultBackend() #include "sdl_controller_interface.h" #endif #ifdef WIN32 +#include "dinput_controller_interface.h" #include "xinput_controller_interface.h" #endif @@ -136,6 +138,8 @@ std::unique_ptr ControllerInterface::Create(Backend type) #ifdef WIN32 if (type == Backend::XInput) return std::make_unique(); + if (type == Backend::DInput) + return std::make_unique(); #endif return {}; diff --git a/src/frontend-common/controller_interface.h b/src/frontend-common/controller_interface.h index d40c167e5..5b012b264 100644 --- a/src/frontend-common/controller_interface.h +++ b/src/frontend-common/controller_interface.h @@ -22,6 +22,7 @@ public: #endif #ifdef WIN32 XInput, + DInput, #endif Count }; @@ -29,7 +30,12 @@ public: enum : int { MAX_NUM_AXISES = 7, - MAX_NUM_BUTTONS = 15 + MAX_NUM_BUTTONS = 15, + NUM_HAT_DIRECTIONS = 4, + HAT_DIRECTION_UP = 0, + HAT_DIRECTION_DOWN = 1, + HAT_DIRECTION_LEFT = 2, + HAT_DIRECTION_RIGHT = 3, }; enum AxisSide diff --git a/src/frontend-common/dinput_controller_interface.cpp b/src/frontend-common/dinput_controller_interface.cpp new file mode 100644 index 000000000..292f04c13 --- /dev/null +++ b/src/frontend-common/dinput_controller_interface.cpp @@ -0,0 +1,448 @@ +#define INITGUID + +#include "dinput_controller_interface.h" +#include "common/assert.h" +#include "common/file_system.h" +#include "common/log.h" +#include "common/string_util.h" +#include "core/controller.h" +#include "core/host_interface.h" +#include "core/system.h" +#include +#include +Log_SetChannel(DInputControllerInterface); + +using PFNDIRECTINPUT8CREATE = HRESULT(WINAPI*)(HINSTANCE hinst, DWORD dwVersion, REFIID riidltf, LPVOID* ppvOut, + LPUNKNOWN punkOuter); +using PFNGETDFDIJOYSTICK = LPCDIDATAFORMAT(WINAPI*)(); + +DInputControllerInterface::DInputControllerInterface() = default; + +DInputControllerInterface::~DInputControllerInterface() +{ + if (m_dinput_module) + FreeLibrary(m_dinput_module); +} + +ControllerInterface::Backend DInputControllerInterface::GetBackend() const +{ + return ControllerInterface::Backend::XInput; +} + +bool DInputControllerInterface::Initialize(CommonHostInterface* host_interface) +{ + m_dinput_module = LoadLibraryW(L"dinput8"); + if (!m_dinput_module) + { + Log_ErrorPrintf("Failed to load DInput module."); + return false; + } + + PFNDIRECTINPUT8CREATE create = + reinterpret_cast(GetProcAddress(m_dinput_module, "DirectInput8Create")); + PFNGETDFDIJOYSTICK get_joystick_data_format = + reinterpret_cast(GetProcAddress(m_dinput_module, "GetdfDIJoystick")); + if (!create || !get_joystick_data_format) + { + Log_ErrorPrintf("Failed to get DInput function pointers."); + return false; + } + + if (!ControllerInterface::Initialize(host_interface)) + return false; + + HRESULT hr = create(GetModuleHandleA(nullptr), DIRECTINPUT_VERSION, IID_IDirectInput8A, + reinterpret_cast(m_dinput.GetAddressOf()), nullptr); + m_joystick_data_format = get_joystick_data_format(); + if (FAILED(hr) || !m_joystick_data_format) + { + Log_ErrorPrintf("DirectInput8Create() failed: %08X", hr); + return false; + } + + AddDevices(); + + return true; +} + +void DInputControllerInterface::Shutdown() +{ + ControllerInterface::Shutdown(); +} + +static BOOL CALLBACK EnumCallback(LPCDIDEVICEINSTANCE lpddi, LPVOID pvRef) +{ + static_cast*>(pvRef)->push_back(*lpddi); + return DIENUM_CONTINUE; +} + +void DInputControllerInterface::AddDevices() +{ + std::vector devices; + m_dinput->EnumDevices(DI8DEVCLASS_GAMECTRL, EnumCallback, &devices, DIEDFL_ATTACHEDONLY); + + Log_InfoPrintf("Enumerated %zud evices", devices.size()); + + for (DIDEVICEINSTANCE inst : devices) + { + ControllerData cd; + HRESULT hr = m_dinput->CreateDevice(inst.guidInstance, cd.device.GetAddressOf(), nullptr); + if (FAILED(hr)) + { + Log_WarningPrintf("Failed to create instance of device [%s, %s]", inst.tszProductName, inst.tszInstanceName); + continue; + } + + if (AddDevice(cd, inst.tszProductName)) + m_controllers.push_back(std::move(cd)); + } +} + +bool DInputControllerInterface::AddDevice(ControllerData& cd, const char* name) +{ + HRESULT hr = cd.device->SetCooperativeLevel(static_cast(m_host_interface->GetTopLevelWindowHandle()), + DISCL_BACKGROUND | DISCL_EXCLUSIVE); + if (FAILED(hr)) + { + hr = cd.device->SetCooperativeLevel(static_cast(m_host_interface->GetTopLevelWindowHandle()), + DISCL_BACKGROUND | DISCL_NONEXCLUSIVE); + if (FAILED(hr)) + { + Log_ErrorPrintf("Failed to set cooperative level for '%s'", name); + return false; + } + + Log_WarningPrintf("Failed to set exclusive mode for '%s'", name); + } + + hr = cd.device->SetDataFormat(m_joystick_data_format); + if (FAILED(hr)) + { + Log_ErrorPrintf("Failed to set data format for '%s'", name); + return false; + } + + hr = cd.device->Acquire(); + if (FAILED(hr)) + { + Log_ErrorPrintf("Failed to acquire device '%s'", name); + return false; + } + + DIDEVCAPS caps = {}; + caps.dwSize = sizeof(caps); + hr = cd.device->GetCapabilities(&caps); + if (FAILED(hr)) + { + Log_ErrorPrintf("Failed to get capabilities for '%s'", name); + return false; + } + + cd.num_buttons = caps.dwButtons; + if (cd.num_buttons > NUM_BUTTONS) + { + Log_WarningPrintf("Device '%s' has too many buttons (%u), using %u instead.", name, cd.num_buttons, NUM_BUTTONS); + cd.num_buttons = NUM_BUTTONS; + } + + static constexpr std::array axis_offsets = { + {offsetof(DIJOYSTATE, lX), offsetof(DIJOYSTATE, lY), offsetof(DIJOYSTATE, lZ), offsetof(DIJOYSTATE, lRz), + offsetof(DIJOYSTATE, lRx), offsetof(DIJOYSTATE, lRy), offsetof(DIJOYSTATE, rglSlider[0]), + offsetof(DIJOYSTATE, rglSlider[1])}}; + for (u32 i = 0; i < NUM_AXISES; i++) + { + // ask for 16 bits of axis range + DIPROPRANGE range = {}; + range.diph.dwSize = sizeof(range); + range.diph.dwHeaderSize = sizeof(range.diph); + range.diph.dwHow = DIPH_BYOFFSET; + range.diph.dwObj = axis_offsets[i]; + range.lMin = std::numeric_limits::min(); + range.lMax = std::numeric_limits::max(); + hr = cd.device->SetProperty(DIPROP_RANGE, &range.diph); + + // did it apply? + if (SUCCEEDED(cd.device->GetProperty(DIPROP_RANGE, &range.diph))) + { + cd.axis_offsets[cd.num_axes] = axis_offsets[i]; + cd.num_axes++; + } + } + + cd.has_hat = (caps.dwPOVs > 0); + + hr = cd.device->Poll(); + if (hr == DI_NOEFFECT) + cd.needs_poll = false; + else if (hr != DI_OK) + Log_WarningPrintf("Polling device '%s' failed: %08X", name, hr); + + hr = cd.device->GetDeviceState(sizeof(cd.last_state), &cd.last_state); + if (hr != DI_OK) + Log_WarningPrintf("GetDeviceState() for '%s' failed: %08X", name, hr); + + Log_InfoPrintf("%s has %u buttons, %u axes%s", name, cd.num_buttons, cd.num_axes, cd.has_hat ? ", and a hat" : ""); + + return (cd.num_buttons > 0 || cd.num_axes > 0 || cd.has_hat); +} + +void DInputControllerInterface::PollEvents() +{ + for (u32 i = 0; i < static_cast(m_controllers.size()); i++) + { + ControllerData& cd = m_controllers[i]; + if (!cd.device) + continue; + + if (cd.needs_poll) + cd.device->Poll(); + + DIJOYSTATE js; + HRESULT hr = cd.device->GetDeviceState(sizeof(js), &js); + if (hr == DIERR_INPUTLOST || hr == DIERR_NOTACQUIRED) + { + hr = cd.device->Acquire(); + if (hr == DI_OK) + hr = cd.device->GetDeviceState(sizeof(js), &js); + + if (hr != DI_OK) + { + cd = {}; + OnControllerDisconnected(static_cast(i)); + continue; + } + } + else if (hr != DI_OK) + { + Log_WarningPrintf("GetDeviceState() failed: %08X", hr); + continue; + } + + CheckForStateChanges(i, js); + } +} + +u32 DInputControllerInterface::GetHatDirection(DWORD hat) +{ + const WORD hv = LOWORD(hat); + if (hv == 0xFFFF) + return NUM_HAT_DIRECTIONS; + else if (hv < 9000) + return HAT_DIRECTION_UP; + else if (hv < 18000) + return HAT_DIRECTION_RIGHT; + else if (hv < 27000) + return HAT_DIRECTION_DOWN; + else + return HAT_DIRECTION_LEFT; +} + +void DInputControllerInterface::CheckForStateChanges(u32 index, const DIJOYSTATE& new_state) +{ + ControllerData& cd = m_controllers[index]; + DIJOYSTATE& last_state = cd.last_state; + + for (u32 i = 0; i < cd.num_axes; i++) + { + LONG new_value; + LONG old_value; + std::memcpy(&old_value, reinterpret_cast(&cd.last_state) + cd.axis_offsets[i], sizeof(old_value)); + std::memcpy(&new_value, reinterpret_cast(&new_state) + cd.axis_offsets[i], sizeof(new_value)); + if (old_value != new_value) + { + HandleAxisEvent(index, i, new_value); + std::memcpy(reinterpret_cast(&cd.last_state) + cd.axis_offsets[i], &new_value, sizeof(new_value)); + } + } + + for (u32 i = 0; i < cd.num_buttons; i++) + { + if (last_state.rgbButtons[i] != new_state.rgbButtons[i]) + { + HandleButtonEvent(index, i, new_state.rgbButtons[i] != 0); + last_state.rgbButtons[i] = new_state.rgbButtons[i]; + } + } + + if (cd.has_hat) + { + if (last_state.rgdwPOV[0] != new_state.rgdwPOV[0]) + { + // map hats to the last buttons + const u32 old_direction = GetHatDirection(last_state.rgdwPOV[0]); + if (old_direction != NUM_HAT_DIRECTIONS) + HandleButtonEvent(index, cd.num_buttons + old_direction, false); + + const u32 new_direction = GetHatDirection(new_state.rgdwPOV[0]); + if (new_direction != NUM_HAT_DIRECTIONS) + HandleButtonEvent(index, cd.num_buttons + new_direction, true); + last_state.rgdwPOV[0] = new_state.rgdwPOV[0]; + } + } +} + +void DInputControllerInterface::ClearBindings() +{ + for (ControllerData& cd : m_controllers) + { + cd.axis_mapping.fill({}); + cd.button_mapping.fill({}); + cd.axis_button_mapping.fill({}); + cd.button_axis_mapping.fill({}); + } +} + +bool DInputControllerInterface::BindControllerAxis(int controller_index, int axis_number, AxisSide axis_side, + AxisCallback callback) +{ + if (static_cast(controller_index) >= m_controllers.size()) + return false; + + if (axis_number < 0 || axis_number >= NUM_AXISES) + return false; + + m_controllers[controller_index].axis_mapping[axis_number][axis_side] = std::move(callback); + return true; +} + +bool DInputControllerInterface::BindControllerButton(int controller_index, int button_number, ButtonCallback callback) +{ + if (static_cast(controller_index) >= m_controllers.size()) + return false; + + if (button_number < 0 || button_number >= TOTAL_NUM_BUTTONS) + return false; + + m_controllers[controller_index].button_mapping[button_number] = std::move(callback); + return true; +} + +bool DInputControllerInterface::BindControllerAxisToButton(int controller_index, int axis_number, bool direction, + ButtonCallback callback) +{ + if (static_cast(controller_index) >= m_controllers.size()) + return false; + + if (axis_number < 0 || axis_number >= NUM_AXISES) + return false; + + m_controllers[controller_index].axis_button_mapping[axis_number][BoolToUInt8(direction)] = std::move(callback); + return true; +} + +bool DInputControllerInterface::BindControllerHatToButton(int controller_index, int hat_number, + std::string_view hat_position, ButtonCallback callback) +{ + // Hats don't exist in XInput + return false; +} + +bool DInputControllerInterface::BindControllerButtonToAxis(int controller_index, int button_number, + AxisCallback callback) +{ + if (static_cast(controller_index) >= m_controllers.size()) + return false; + + if (button_number < 0 || button_number >= TOTAL_NUM_BUTTONS) + return false; + + m_controllers[controller_index].button_axis_mapping[button_number] = std::move(callback); + return true; +} + +bool DInputControllerInterface::HandleAxisEvent(u32 index, u32 axis, s32 value) +{ + const float f_value = static_cast(value) / (value < 0 ? 32768.0f : 32767.0f); + Log_DevPrintf("controller %u axis %u %d %f", index, static_cast(axis), value, f_value); + DebugAssert(index < m_controllers.size()); + + if (DoEventHook(Hook::Type::Axis, index, static_cast(axis), f_value)) + return true; + + const AxisCallback& cb = m_controllers[index].axis_mapping[static_cast(axis)][AxisSide::Full]; + if (cb) + { + // Extend triggers from a 0 - 1 range to a -1 - 1 range for consistency with other inputs + if (axis == 4 || axis == 5) + { + cb((f_value * 2.0f) - 1.0f); + } + else + { + cb(f_value); + } + return true; + } + + // set the other direction to false so large movements don't leave the opposite on + const bool outside_deadzone = (std::abs(f_value) >= m_controllers[index].deadzone); + const bool positive = (f_value >= 0.0f); + const ButtonCallback& other_button_cb = + m_controllers[index].axis_button_mapping[static_cast(axis)][BoolToUInt8(!positive)]; + const ButtonCallback& button_cb = + m_controllers[index].axis_button_mapping[static_cast(axis)][BoolToUInt8(positive)]; + if (button_cb) + { + button_cb(outside_deadzone); + if (other_button_cb) + other_button_cb(false); + return true; + } + else if (other_button_cb) + { + other_button_cb(false); + return true; + } + else + { + return false; + } +} + +bool DInputControllerInterface::HandleButtonEvent(u32 index, u32 button, bool pressed) +{ + Log_DevPrintf("controller %u button %u %s", index, button, pressed ? "pressed" : "released"); + DebugAssert(index < m_controllers.size()); + + if (DoEventHook(Hook::Type::Button, index, button, pressed ? 1.0f : 0.0f)) + return true; + + const ButtonCallback& cb = m_controllers[index].button_mapping[button]; + if (cb) + { + cb(pressed); + return true; + } + + const AxisCallback& axis_cb = m_controllers[index].button_axis_mapping[button]; + if (axis_cb) + { + axis_cb(pressed ? 1.0f : -1.0f); + } + return true; +} + +u32 DInputControllerInterface::GetControllerRumbleMotorCount(int controller_index) +{ + if (static_cast(controller_index) >= m_controllers.size()) + return 0; + + return 0; +} + +void DInputControllerInterface::SetControllerRumbleStrength(int controller_index, const float* strengths, + u32 num_motors) +{ + DebugAssert(static_cast(controller_index) < m_controllers.size()); +} + +bool DInputControllerInterface::SetControllerDeadzone(int controller_index, float size /* = 0.25f */) +{ + if (static_cast(controller_index) >= m_controllers.size()) + return false; + + m_controllers[static_cast(controller_index)].deadzone = std::clamp(std::abs(size), 0.01f, 0.99f); + Log_InfoPrintf("Controller %d deadzone size set to %f", controller_index, + m_controllers[static_cast(controller_index)].deadzone); + return true; +} diff --git a/src/frontend-common/dinput_controller_interface.h b/src/frontend-common/dinput_controller_interface.h new file mode 100644 index 000000000..fa8183b39 --- /dev/null +++ b/src/frontend-common/dinput_controller_interface.h @@ -0,0 +1,96 @@ +#pragma once +#define DIRECTINPUT_VERSION 0x0800 +#include "common/windows_headers.h" +#include "controller_interface.h" +#include "core/types.h" +#include +#include +#include +#include +#include +#include + +class DInputControllerInterface final : public ControllerInterface +{ +public: + DInputControllerInterface(); + ~DInputControllerInterface() override; + + Backend GetBackend() const override; + bool Initialize(CommonHostInterface* host_interface) override; + void Shutdown() override; + + // Removes all bindings. Call before setting new bindings. + void ClearBindings() override; + + // Binding to events. If a binding for this axis/button already exists, returns false. + bool BindControllerAxis(int controller_index, int axis_number, AxisSide axis_side, AxisCallback callback) override; + bool BindControllerButton(int controller_index, int button_number, ButtonCallback callback) override; + bool BindControllerAxisToButton(int controller_index, int axis_number, bool direction, + ButtonCallback callback) override; + bool BindControllerHatToButton(int controller_index, int hat_number, std::string_view hat_position, + ButtonCallback callback) override; + bool BindControllerButtonToAxis(int controller_index, int button_number, AxisCallback callback) override; + + // Changing rumble strength. + u32 GetControllerRumbleMotorCount(int controller_index) override; + void SetControllerRumbleStrength(int controller_index, const float* strengths, u32 num_motors) override; + + // Set deadzone that will be applied on axis-to-button mappings + bool SetControllerDeadzone(int controller_index, float size = 0.25f) override; + + void PollEvents() override; + +private: + template + using ComPtr = Microsoft::WRL::ComPtr; + + enum : u32 + { + NUM_AXISES = 8, + NUM_BUTTONS = 16, + NUM_HATS = 1, + + TOTAL_NUM_BUTTONS = NUM_BUTTONS + (NUM_HATS * NUM_HAT_DIRECTIONS), + }; + + struct ControllerData + { + ComPtr device; + DIJOYSTATE last_state = {}; + u32 num_buttons = 0; + u32 num_axes = 0; + + float deadzone = 0.25f; + + std::array axis_offsets; + + std::array, NUM_AXISES> axis_mapping; + std::array button_mapping; + std::array, NUM_AXISES> axis_button_mapping; + std::array button_axis_mapping; + + bool has_hat = false; + bool needs_poll = true; + }; + + using ControllerDataArray = std::vector; + + void AddDevices(); + bool AddDevice(ControllerData& cd, const char* name); + + static u32 GetHatDirection(DWORD hat); + + void CheckForStateChanges(u32 index, const DIJOYSTATE& new_state); + + bool HandleAxisEvent(u32 index, u32 axis, s32 value); + bool HandleButtonEvent(u32 index, u32 button, bool pressed); + + ControllerDataArray m_controllers; + + HMODULE m_dinput_module{}; + LPCDIDATAFORMAT m_joystick_data_format{}; + ComPtr m_dinput; + std::mutex m_event_intercept_mutex; + Hook::Callback m_event_intercept_callback; +}; diff --git a/src/frontend-common/frontend-common.vcxproj b/src/frontend-common/frontend-common.vcxproj index c1ca84874..9b5b26760 100644 --- a/src/frontend-common/frontend-common.vcxproj +++ b/src/frontend-common/frontend-common.vcxproj @@ -87,6 +87,7 @@ + @@ -111,6 +112,7 @@ + diff --git a/src/frontend-common/frontend-common.vcxproj.filters b/src/frontend-common/frontend-common.vcxproj.filters index 7a5eeadfa..42615ac11 100644 --- a/src/frontend-common/frontend-common.vcxproj.filters +++ b/src/frontend-common/frontend-common.vcxproj.filters @@ -23,6 +23,7 @@ + @@ -47,6 +48,7 @@ +