From f17a77c18e19530d9bbec241d5a03a760accc56c Mon Sep 17 00:00:00 2001 From: Connor McLaughlin Date: Mon, 18 Apr 2022 20:07:08 +1000 Subject: [PATCH] Common: Add HTTPDownloader implementation --- cmake/SearchForStuff.cmake | 4 + common/CMakeLists.txt | 15 ++ common/HTTPDownloader.cpp | 364 +++++++++++++++++++++++++++++++ common/HTTPDownloader.h | 106 +++++++++ common/HTTPDownloaderCurl.cpp | 185 ++++++++++++++++ common/HTTPDownloaderCurl.h | 54 +++++ common/HTTPDownloaderWinHTTP.cpp | 333 ++++++++++++++++++++++++++++ common/HTTPDownloaderWinHTTP.h | 53 +++++ common/ThreadPool.cpp | 137 ++++++++++++ common/ThreadPool.h | 255 ++++++++++++++++++++++ common/common.vcxproj | 17 +- common/common.vcxproj.filters | 23 +- 12 files changed, 1544 insertions(+), 2 deletions(-) create mode 100644 common/HTTPDownloader.cpp create mode 100644 common/HTTPDownloader.h create mode 100644 common/HTTPDownloaderCurl.cpp create mode 100644 common/HTTPDownloaderCurl.h create mode 100644 common/HTTPDownloaderWinHTTP.cpp create mode 100644 common/HTTPDownloaderWinHTTP.h create mode 100644 common/ThreadPool.cpp create mode 100644 common/ThreadPool.h diff --git a/cmake/SearchForStuff.cmake b/cmake/SearchForStuff.cmake index d44ec3ca0f..109d8ee951 100644 --- a/cmake/SearchForStuff.cmake +++ b/cmake/SearchForStuff.cmake @@ -236,6 +236,10 @@ if(QT_BUILD) find_optional_system_library(SDL2 3rdparty/sdl2 2.0.22) endif() +if(NOT WIN32 AND QT_BUILD) + find_package(CURL REQUIRED) +endif() + add_subdirectory(3rdparty/lzma EXCLUDE_FROM_ALL) add_subdirectory(3rdparty/libchdr EXCLUDE_FROM_ALL) diff --git a/common/CMakeLists.txt b/common/CMakeLists.txt index 0b535a3d4e..8d6e2b3c60 100644 --- a/common/CMakeLists.txt +++ b/common/CMakeLists.txt @@ -20,6 +20,7 @@ target_sources(common PRIVATE FastJmp.cpp FileSystem.cpp Image.cpp + HTTPDownloader.cpp Misc.cpp MD5Digest.cpp PrecompiledHeader.cpp @@ -29,6 +30,7 @@ target_sources(common PRIVATE SettingsWrapper.cpp StringUtil.cpp Timer.cpp + ThreadPool.cpp WindowInfo.cpp emitter/bmi.cpp emitter/cpudetect.cpp @@ -71,6 +73,7 @@ target_sources(common PRIVATE HashCombine.h Image.h LRUCache.h + HTTPDownloader.h MemcpyFast.h MemsetFast.inl MD5Digest.h @@ -87,6 +90,7 @@ target_sources(common PRIVATE StringUtil.h Timer.h Threading.h + ThreadPool.h TraceLog.h WindowInfo.h emitter/cpudetect_internal.h @@ -172,6 +176,8 @@ if(WIN32) CrashHandler.cpp CrashHandler.h FastJmp.asm + HTTPDownloaderWinHTTP.cpp + HTTPDownloaderWinHTTP.h StackWalker.cpp StackWalker.h D3D11/ShaderCache.cpp @@ -265,6 +271,15 @@ if (USE_GCC AND CMAKE_INTERPROCEDURAL_OPTIMIZATION) set_source_files_properties(FastJmp.cpp PROPERTIES COMPILE_FLAGS -fno-lto) endif() +if(NOT WIN32 AND (QT_BUILD OR NOGUI_BUILD)) + # libcurl-based HTTPDownloader + target_sources(common PRIVATE + HTTPDownloaderCurl.cpp + HTTPDownloaderCurl.h + ) + target_link_libraries(common PRIVATE CURL::libcurl) +endif() + target_link_libraries(common PRIVATE ${LIBC_LIBRARIES} PNG::PNG diff --git a/common/HTTPDownloader.cpp b/common/HTTPDownloader.cpp new file mode 100644 index 0000000000..846e98b620 --- /dev/null +++ b/common/HTTPDownloader.cpp @@ -0,0 +1,364 @@ +/* PCSX2 - PS2 Emulator for PCs + * Copyright (C) 2002-2022 PCSX2 Dev Team + * + * PCSX2 is free software: you can redistribute it and/or modify it under the terms + * of the GNU Lesser General Public License as published by the Free Software Found- + * ation, either version 3 of the License, or (at your option) any later version. + * + * PCSX2 is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; + * without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR + * PURPOSE. See the GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License along with PCSX2. + * If not, see . + */ + +#include "common/PrecompiledHeader.h" + +#include "common/HTTPDownloader.h" +#include "common/Assertions.h" +#include "common/Console.h" +#include "common/StringUtil.h" +#include "common/Timer.h" + +using namespace Common; + +static constexpr float DEFAULT_TIMEOUT_IN_SECONDS = 30; +static constexpr u32 DEFAULT_MAX_ACTIVE_REQUESTS = 4; + +const char HTTPDownloader::DEFAULT_USER_AGENT[] = + "Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:85.0) Gecko/20100101 Firefox/85.0"; + +HTTPDownloader::HTTPDownloader() + : m_timeout(DEFAULT_TIMEOUT_IN_SECONDS) + , m_max_active_requests(DEFAULT_MAX_ACTIVE_REQUESTS) +{ +} + +HTTPDownloader::~HTTPDownloader() = default; + +void HTTPDownloader::SetTimeout(float timeout) +{ + m_timeout = timeout; +} + +void HTTPDownloader::SetMaxActiveRequests(u32 max_active_requests) +{ + pxAssert(max_active_requests > 0); + m_max_active_requests = max_active_requests; +} + +void HTTPDownloader::CreateRequest(std::string url, Request::Callback callback) +{ + Request* req = InternalCreateRequest(); + req->parent = this; + req->type = Request::Type::Get; + req->url = std::move(url); + req->callback = std::move(callback); + req->start_time = Timer::GetCurrentValue(); + + std::unique_lock lock(m_pending_http_request_lock); + if (LockedGetActiveRequestCount() < m_max_active_requests) + { + if (!StartRequest(req)) + return; + } + + LockedAddRequest(req); +} + +void HTTPDownloader::CreatePostRequest(std::string url, std::string post_data, Request::Callback callback) +{ + Request* req = InternalCreateRequest(); + req->parent = this; + req->type = Request::Type::Post; + req->url = std::move(url); + req->post_data = std::move(post_data); + req->callback = std::move(callback); + req->start_time = Timer::GetCurrentValue(); + + std::unique_lock lock(m_pending_http_request_lock); + if (LockedGetActiveRequestCount() < m_max_active_requests) + { + if (!StartRequest(req)) + return; + } + + LockedAddRequest(req); +} + +void HTTPDownloader::LockedPollRequests(std::unique_lock& lock) +{ + if (m_pending_http_requests.empty()) + return; + + InternalPollRequests(); + + const Common::Timer::Value current_time = Timer::GetCurrentValue(); + u32 active_requests = 0; + u32 unstarted_requests = 0; + + for (size_t index = 0; index < m_pending_http_requests.size();) + { + Request* req = m_pending_http_requests[index]; + if (req->state == Request::State::Pending) + { + unstarted_requests++; + index++; + continue; + } + + if (req->state == Request::State::Started && current_time >= req->start_time && + Common::Timer::ConvertValueToSeconds(current_time - req->start_time) >= m_timeout) + { + // request timed out + Console.Error("Request for '%s' timed out", req->url.c_str()); + + req->state.store(Request::State::Cancelled); + m_pending_http_requests.erase(m_pending_http_requests.begin() + index); + lock.unlock(); + + req->callback(-1, std::string(), Request::Data()); + + CloseRequest(req); + + lock.lock(); + continue; + } + + if (req->state != Request::State::Complete) + { + active_requests++; + index++; + continue; + } + + // request complete + DevCon.WriteLn("Request for '%s' complete, returned status code %u and %zu bytes", req->url.c_str(), + req->status_code, req->data.size()); + m_pending_http_requests.erase(m_pending_http_requests.begin() + index); + + // run callback with lock unheld + lock.unlock(); + req->callback(req->status_code, std::move(req->content_type), std::move(req->data)); + CloseRequest(req); + lock.lock(); + } + + // start new requests when we finished some + if (unstarted_requests > 0 && active_requests < m_max_active_requests) + { + for (size_t index = 0; index < m_pending_http_requests.size();) + { + Request* req = m_pending_http_requests[index]; + if (req->state != Request::State::Pending) + { + index++; + continue; + } + + if (!StartRequest(req)) + { + m_pending_http_requests.erase(m_pending_http_requests.begin() + index); + continue; + } + + active_requests++; + index++; + + if (active_requests >= m_max_active_requests) + break; + } + } +} + +void HTTPDownloader::PollRequests() +{ + std::unique_lock lock(m_pending_http_request_lock); + LockedPollRequests(lock); +} + +void HTTPDownloader::WaitForAllRequests() +{ + std::unique_lock lock(m_pending_http_request_lock); + while (!m_pending_http_requests.empty()) + LockedPollRequests(lock); +} + +void HTTPDownloader::LockedAddRequest(Request* request) +{ + m_pending_http_requests.push_back(request); +} + +u32 HTTPDownloader::LockedGetActiveRequestCount() +{ + u32 count = 0; + for (Request* req : m_pending_http_requests) + { + if (req->state == Request::State::Started || req->state == Request::State::Receiving) + count++; + } + return count; +} + +bool HTTPDownloader::HasAnyRequests() +{ + std::unique_lock lock(m_pending_http_request_lock); + return !m_pending_http_requests.empty(); +} + +std::string HTTPDownloader::URLEncode(const std::string_view& str) +{ + std::string ret; + ret.reserve(str.length() + ((str.length() + 3) / 4) * 3); + + for (size_t i = 0, l = str.size(); i < l; i++) + { + const char c = str[i]; + if ((c >= '0' && c <= '9') || + (c >= 'a' && c <= 'z') || + (c >= 'A' && c <= 'Z') || + c == '-' || c == '_' || c == '.' || c == '!' || c == '~' || + c == '*' || c == '\'' || c == '(' || c == ')') + { + ret.push_back(c); + } + else + { + ret.push_back('%'); + + const unsigned char n1 = static_cast(c) >> 4; + const unsigned char n2 = static_cast(c) & 0x0F; + ret.push_back((n1 >= 10) ? ('a' + (n1 - 10)) : ('0' + n1)); + ret.push_back((n2 >= 10) ? ('a' + (n2 - 10)) : ('0' + n2)); + } + } + + return ret; +} + +std::string HTTPDownloader::URLDecode(const std::string_view& str) +{ + std::string ret; + ret.reserve(str.length()); + + for (size_t i = 0, l = str.size(); i < l; i++) + { + const char c = str[i]; + if (c == '+') + { + ret.push_back(c); + } + else if (c == '%') + { + if ((i + 2) >= str.length()) + break; + + const char clower = str[i + 1]; + const char cupper = str[i + 2]; + const unsigned char lower = (clower >= '0' && clower <= '9') ? static_cast(clower - '0') : ((clower >= 'a' && clower <= 'f') ? static_cast(clower - 'a') : ((clower >= 'A' && clower <= 'F') ? static_cast(clower - 'A') : 0)); + const unsigned char upper = (cupper >= '0' && cupper <= '9') ? static_cast(cupper - '0') : ((cupper >= 'a' && cupper <= 'f') ? static_cast(cupper - 'a') : ((cupper >= 'A' && cupper <= 'F') ? static_cast(cupper - 'A') : 0)); + const char dch = static_cast(lower | (upper << 4)); + ret.push_back(dch); + } + else + { + ret.push_back(c); + } + } + + return std::string(str); +} + +std::string HTTPDownloader::GetExtensionForContentType(const std::string& content_type) +{ + // Based on https://developer.mozilla.org/en-US/docs/Web/HTTP/Basics_of_HTTP/MIME_types/Common_types + static constexpr const char* table[][2] = { + {"audio/aac", "aac"}, + {"application/x-abiword", "abw"}, + {"application/x-freearc", "arc"}, + {"image/avif", "avif"}, + {"video/x-msvideo", "avi"}, + {"application/vnd.amazon.ebook", "azw"}, + {"application/octet-stream", "bin"}, + {"image/bmp", "bmp"}, + {"application/x-bzip", "bz"}, + {"application/x-bzip2", "bz2"}, + {"application/x-cdf", "cda"}, + {"application/x-csh", "csh"}, + {"text/css", "css"}, + {"text/csv", "csv"}, + {"application/msword", "doc"}, + {"application/vnd.openxmlformats-officedocument.wordprocessingml.document", "docx"}, + {"application/vnd.ms-fontobject", "eot"}, + {"application/epub+zip", "epub"}, + {"application/gzip", "gz"}, + {"image/gif", "gif"}, + {"text/html", "htm"}, + {"image/vnd.microsoft.icon", "ico"}, + {"text/calendar", "ics"}, + {"application/java-archive", "jar"}, + {"image/jpeg", "jpg"}, + {"text/javascript", "js"}, + {"application/json", "json"}, + {"application/ld+json", "jsonld"}, + {"audio/midi audio/x-midi", "mid"}, + {"text/javascript", "mjs"}, + {"audio/mpeg", "mp3"}, + {"video/mp4", "mp4"}, + {"video/mpeg", "mpeg"}, + {"application/vnd.apple.installer+xml", "mpkg"}, + {"application/vnd.oasis.opendocument.presentation", "odp"}, + {"application/vnd.oasis.opendocument.spreadsheet", "ods"}, + {"application/vnd.oasis.opendocument.text", "odt"}, + {"audio/ogg", "oga"}, + {"video/ogg", "ogv"}, + {"application/ogg", "ogx"}, + {"audio/opus", "opus"}, + {"font/otf", "otf"}, + {"image/png", "png"}, + {"application/pdf", "pdf"}, + {"application/x-httpd-php", "php"}, + {"application/vnd.ms-powerpoint", "ppt"}, + {"application/vnd.openxmlformats-officedocument.presentationml.presentation", "pptx"}, + {"application/vnd.rar", "rar"}, + {"application/rtf", "rtf"}, + {"application/x-sh", "sh"}, + {"image/svg+xml", "svg"}, + {"application/x-tar", "tar"}, + {"image/tiff", "tif"}, + {"video/mp2t", "ts"}, + {"font/ttf", "ttf"}, + {"text/plain", "txt"}, + {"application/vnd.visio", "vsd"}, + {"audio/wav", "wav"}, + {"audio/webm", "weba"}, + {"video/webm", "webm"}, + {"image/webp", "webp"}, + {"font/woff", "woff"}, + {"font/woff2", "woff2"}, + {"application/xhtml+xml", "xhtml"}, + {"application/vnd.ms-excel", "xls"}, + {"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", "xlsx"}, + {"application/xml", "xml"}, + {"text/xml", "xml"}, + {"application/vnd.mozilla.xul+xml", "xul"}, + {"application/zip", "zip"}, + {"video/3gpp", "3gp"}, + {"audio/3gpp", "3gp"}, + {"video/3gpp2", "3g2"}, + {"audio/3gpp2", "3g2"}, + {"application/x-7z-compressed", "7z"}, + }; + + std::string ret; + for (size_t i = 0; i < std::size(table); i++) + { + if (StringUtil::compareNoCase(table[i][0], content_type)) + { + ret = table[i][1]; + break; + } + } + return ret; +} diff --git a/common/HTTPDownloader.h b/common/HTTPDownloader.h new file mode 100644 index 0000000000..1d4e3586cc --- /dev/null +++ b/common/HTTPDownloader.h @@ -0,0 +1,106 @@ +/* PCSX2 - PS2 Emulator for PCs + * Copyright (C) 2002-2022 PCSX2 Dev Team + * + * PCSX2 is free software: you can redistribute it and/or modify it under the terms + * of the GNU Lesser General Public License as published by the Free Software Found- + * ation, either version 3 of the License, or (at your option) any later version. + * + * PCSX2 is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; + * without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR + * PURPOSE. See the GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License along with PCSX2. + * If not, see . + */ + +#pragma once +#include "common/Pcsx2Defs.h" +#include +#include +#include +#include +#include +#include +#include + +namespace Common +{ + class HTTPDownloader + { + public: + enum : s32 + { + HTTP_OK = 200 + }; + + struct Request + { + using Data = std::vector; + using Callback = std::function; + + enum class Type + { + Get, + Post, + }; + + enum class State + { + Pending, + Cancelled, + Started, + Receiving, + Complete, + }; + + HTTPDownloader* parent; + Callback callback; + std::string url; + std::string post_data; + std::string content_type; + Data data; + u64 start_time; + s32 status_code = 0; + u32 content_length = 0; + Type type = Type::Get; + std::atomic state{State::Pending}; + }; + + HTTPDownloader(); + virtual ~HTTPDownloader(); + + static std::unique_ptr Create(const char* user_agent = DEFAULT_USER_AGENT); + static std::string URLEncode(const std::string_view& str); + static std::string URLDecode(const std::string_view& str); + static std::string GetExtensionForContentType(const std::string& content_type); + + void SetTimeout(float timeout); + void SetMaxActiveRequests(u32 max_active_requests); + + void CreateRequest(std::string url, Request::Callback callback); + void CreatePostRequest(std::string url, std::string post_data, Request::Callback callback); + void PollRequests(); + void WaitForAllRequests(); + bool HasAnyRequests(); + + static const char DEFAULT_USER_AGENT[]; + + protected: + virtual Request* InternalCreateRequest() = 0; + virtual void InternalPollRequests() = 0; + + virtual bool StartRequest(Request* request) = 0; + virtual void CloseRequest(Request* request) = 0; + + void LockedAddRequest(Request* request); + u32 LockedGetActiveRequestCount(); + void LockedPollRequests(std::unique_lock& lock); + + float m_timeout; + u32 m_max_active_requests; + + std::mutex m_pending_http_request_lock; + std::vector m_pending_http_requests; + }; + +} // namespace Common \ No newline at end of file diff --git a/common/HTTPDownloaderCurl.cpp b/common/HTTPDownloaderCurl.cpp new file mode 100644 index 0000000000..8d20a3a36a --- /dev/null +++ b/common/HTTPDownloaderCurl.cpp @@ -0,0 +1,185 @@ +/* PCSX2 - PS2 Emulator for PCs + * Copyright (C) 2002-2022 PCSX2 Dev Team + * + * PCSX2 is free software: you can redistribute it and/or modify it under the terms + * of the GNU Lesser General Public License as published by the Free Software Found- + * ation, either version 3 of the License, or (at your option) any later version. + * + * PCSX2 is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; + * without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR + * PURPOSE. See the GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License along with PCSX2. + * If not, see . + */ + +#include "common/PrecompiledHeader.h" + +#include "common/HTTPDownloaderCurl.h" +#include "common/Assertions.h" +#include "common/Console.h" +#include "common/StringUtil.h" +#include "common/Timer.h" + +#include +#include +#include +#include + +using namespace Common; + +HTTPDownloaderCurl::HTTPDownloaderCurl() + : HTTPDownloader() +{ +} + +HTTPDownloaderCurl::~HTTPDownloaderCurl() = default; + +std::unique_ptr HTTPDownloader::Create(const char* user_agent) +{ + std::unique_ptr instance(std::make_unique()); + if (!instance->Initialize(user_agent)) + return {}; + + return instance; +} + +static bool s_curl_initialized = false; +static std::once_flag s_curl_initialized_once_flag; + +bool HTTPDownloaderCurl::Initialize(const char* user_agent) +{ + if (!s_curl_initialized) + { + std::call_once(s_curl_initialized_once_flag, []() { + s_curl_initialized = curl_global_init(CURL_GLOBAL_ALL) == CURLE_OK; + if (s_curl_initialized) + { + std::atexit([]() { + curl_global_cleanup(); + s_curl_initialized = false; + }); + } + }); + if (!s_curl_initialized) + { + Console.Error("curl_global_init() failed"); + return false; + } + } + + m_user_agent = user_agent; + m_thread_pool = std::make_unique(m_max_active_requests); + return true; +} + +size_t HTTPDownloaderCurl::WriteCallback(char* ptr, size_t size, size_t nmemb, void* userdata) +{ + Request* req = static_cast(userdata); + const size_t current_size = req->data.size(); + const size_t transfer_size = size * nmemb; + const size_t new_size = current_size + transfer_size; + req->data.resize(new_size); + std::memcpy(&req->data[current_size], ptr, transfer_size); + return nmemb; +} + +void HTTPDownloaderCurl::ProcessRequest(Request* req) +{ + std::unique_lock cancel_lock(m_cancel_mutex); + if (req->closed.load()) + return; + + cancel_lock.unlock(); + + // Apparently OpenSSL can fire SIGPIPE... + sigset_t old_block_mask = {}; + sigset_t new_block_mask = {}; + sigemptyset(&old_block_mask); + sigemptyset(&new_block_mask); + sigaddset(&new_block_mask, SIGPIPE); + if (pthread_sigmask(SIG_BLOCK, &new_block_mask, &old_block_mask) != 0) + Console.Warning("Failed to block SIGPIPE"); + + req->start_time = Common::Timer::GetCurrentValue(); + int ret = curl_easy_perform(req->handle); + if (ret == CURLE_OK) + { + long response_code = 0; + curl_easy_getinfo(req->handle, CURLINFO_RESPONSE_CODE, &response_code); + req->status_code = static_cast(response_code); + + char* content_type = nullptr; + if (!curl_easy_getinfo(req->handle, CURLINFO_CONTENT_TYPE, &content_type) && content_type) + req->content_type = content_type; + + DevCon.WriteLn("Request for '%s' returned status code %d and %zu bytes", req->url.c_str(), req->status_code, + req->data.size()); + } + else + { + Console.Error("Request for '%s' returned %d", req->url.c_str(), ret); + } + + curl_easy_cleanup(req->handle); + + if (pthread_sigmask(SIG_UNBLOCK, &new_block_mask, &old_block_mask) != 0) + Console.Warning("Failed to unblock SIGPIPE"); + + cancel_lock.lock(); + req->state = Request::State::Complete; + if (req->closed.load()) + delete req; + else + req->closed.store(true); +} + +HTTPDownloader::Request* HTTPDownloaderCurl::InternalCreateRequest() +{ + Request* req = new Request(); + req->handle = curl_easy_init(); + if (!req->handle) + { + delete req; + return nullptr; + } + + return req; +} + +void HTTPDownloaderCurl::InternalPollRequests() +{ + // noop - uses thread pool +} + +bool HTTPDownloaderCurl::StartRequest(HTTPDownloader::Request* request) +{ + Request* req = static_cast(request); + curl_easy_setopt(req->handle, CURLOPT_URL, request->url.c_str()); + curl_easy_setopt(req->handle, CURLOPT_USERAGENT, m_user_agent.c_str()); + curl_easy_setopt(req->handle, CURLOPT_WRITEFUNCTION, &HTTPDownloaderCurl::WriteCallback); + curl_easy_setopt(req->handle, CURLOPT_WRITEDATA, req); + curl_easy_setopt(req->handle, CURLOPT_NOSIGNAL, 1); + + if (request->type == Request::Type::Post) + { + curl_easy_setopt(req->handle, CURLOPT_POST, 1L); + curl_easy_setopt(req->handle, CURLOPT_POSTFIELDS, request->post_data.c_str()); + } + + DbgCon.WriteLn("Started HTTP request for '%s'", req->url.c_str()); + req->state = Request::State::Started; + req->start_time = Common::Timer::GetCurrentValue(); + m_thread_pool->Schedule(std::bind(&HTTPDownloaderCurl::ProcessRequest, this, req)); + return true; +} + +void HTTPDownloaderCurl::CloseRequest(HTTPDownloader::Request* request) +{ + std::unique_lock cancel_lock(m_cancel_mutex); + Request* req = static_cast(request); + if (req->closed.load()) + delete req; + else + req->closed.store(true); +} diff --git a/common/HTTPDownloaderCurl.h b/common/HTTPDownloaderCurl.h new file mode 100644 index 0000000000..79877c40b7 --- /dev/null +++ b/common/HTTPDownloaderCurl.h @@ -0,0 +1,54 @@ +/* PCSX2 - PS2 Emulator for PCs + * Copyright (C) 2002-2022 PCSX2 Dev Team + * + * PCSX2 is free software: you can redistribute it and/or modify it under the terms + * of the GNU Lesser General Public License as published by the Free Software Found- + * ation, either version 3 of the License, or (at your option) any later version. + * + * PCSX2 is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; + * without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR + * PURPOSE. See the GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License along with PCSX2. + * If not, see . + */ + +#pragma once +#include "common/HTTPDownloader.h" +#include "common/ThreadPool.h" +#include +#include +#include +#include + +namespace Common +{ + class HTTPDownloaderCurl final : public HTTPDownloader + { + public: + HTTPDownloaderCurl(); + ~HTTPDownloaderCurl() override; + + bool Initialize(const char* user_agent); + + protected: + Request* InternalCreateRequest() override; + void InternalPollRequests() override; + bool StartRequest(HTTPDownloader::Request* request) override; + void CloseRequest(HTTPDownloader::Request* request) override; + + private: + struct Request : HTTPDownloader::Request + { + CURL* handle = nullptr; + std::atomic_bool closed{false}; + }; + + static size_t WriteCallback(char* ptr, size_t size, size_t nmemb, void* userdata); + void ProcessRequest(Request* req); + + std::string m_user_agent; + std::unique_ptr m_thread_pool; + std::mutex m_cancel_mutex; + }; +} // namespace Common diff --git a/common/HTTPDownloaderWinHTTP.cpp b/common/HTTPDownloaderWinHTTP.cpp new file mode 100644 index 0000000000..cb78d63264 --- /dev/null +++ b/common/HTTPDownloaderWinHTTP.cpp @@ -0,0 +1,333 @@ +/* PCSX2 - PS2 Emulator for PCs + * Copyright (C) 2002-2022 PCSX2 Dev Team + * + * PCSX2 is free software: you can redistribute it and/or modify it under the terms + * of the GNU Lesser General Public License as published by the Free Software Found- + * ation, either version 3 of the License, or (at your option) any later version. + * + * PCSX2 is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; + * without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR + * PURPOSE. See the GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License along with PCSX2. + * If not, see . + */ + +#include "common/PrecompiledHeader.h" + +#include "common/HTTPDownloaderWinHTTP.h" +#include "common/Assertions.h" +#include "common/Console.h" +#include "common/StringUtil.h" +#include "common/Timer.h" +#include +#include + +#pragma comment(lib, "winhttp.lib") + +using namespace Common; + +HTTPDownloaderWinHttp::HTTPDownloaderWinHttp() + : HTTPDownloader() +{ +} + +HTTPDownloaderWinHttp::~HTTPDownloaderWinHttp() +{ + if (m_hSession) + { + WinHttpSetStatusCallback(m_hSession, nullptr, WINHTTP_CALLBACK_FLAG_ALL_NOTIFICATIONS, NULL); + WinHttpCloseHandle(m_hSession); + } +} + +std::unique_ptr HTTPDownloader::Create(const char* user_agent) +{ + std::unique_ptr instance(std::make_unique()); + if (!instance->Initialize(user_agent)) + return {}; + + return instance; +} + +bool HTTPDownloaderWinHttp::Initialize(const char* user_agent) +{ + // WINHTTP_ACCESS_TYPE_AUTOMATIC_PROXY is not supported before Win8.1. + const DWORD dwAccessType = + IsWindows8Point1OrGreater() ? WINHTTP_ACCESS_TYPE_AUTOMATIC_PROXY : WINHTTP_ACCESS_TYPE_DEFAULT_PROXY; + + m_hSession = WinHttpOpen(StringUtil::UTF8StringToWideString(user_agent).c_str(), dwAccessType, nullptr, nullptr, + WINHTTP_FLAG_ASYNC); + if (m_hSession == NULL) + { + Console.Error("WinHttpOpen() failed: %u", GetLastError()); + return false; + } + + const DWORD notification_flags = WINHTTP_CALLBACK_FLAG_ALL_COMPLETIONS | WINHTTP_CALLBACK_FLAG_REQUEST_ERROR | + WINHTTP_CALLBACK_FLAG_HANDLES | WINHTTP_CALLBACK_FLAG_SECURE_FAILURE; + if (WinHttpSetStatusCallback(m_hSession, HTTPStatusCallback, notification_flags, NULL) == + WINHTTP_INVALID_STATUS_CALLBACK) + { + Console.Error("WinHttpSetStatusCallback() failed: %u", GetLastError()); + return false; + } + + return true; +} + +void CALLBACK HTTPDownloaderWinHttp::HTTPStatusCallback(HINTERNET hRequest, DWORD_PTR dwContext, DWORD dwInternetStatus, + LPVOID lpvStatusInformation, DWORD dwStatusInformationLength) +{ + Request* req = reinterpret_cast(dwContext); + switch (dwInternetStatus) + { + case WINHTTP_CALLBACK_STATUS_HANDLE_CREATED: + return; + + case WINHTTP_CALLBACK_STATUS_HANDLE_CLOSING: + { + if (!req) + return; + + pxAssert(hRequest == req->hRequest); + + HTTPDownloaderWinHttp* parent = static_cast(req->parent); + std::unique_lock lock(parent->m_pending_http_request_lock); + pxAssertRel(std::none_of(parent->m_pending_http_requests.begin(), parent->m_pending_http_requests.end(), + [req](HTTPDownloader::Request* it) { return it == req; }), + "Request is not pending at close time"); + + // we can clean up the connection as well + pxAssert(req->hConnection != NULL); + WinHttpCloseHandle(req->hConnection); + delete req; + return; + } + + case WINHTTP_CALLBACK_STATUS_REQUEST_ERROR: + { + const WINHTTP_ASYNC_RESULT* res = reinterpret_cast(lpvStatusInformation); + Console.Error("WinHttp async function %p returned error %u", res->dwResult, res->dwError); + req->status_code = -1; + req->state.store(Request::State::Complete); + return; + } + case WINHTTP_CALLBACK_STATUS_SENDREQUEST_COMPLETE: + { + DbgCon.WriteLn("SendRequest complete"); + if (!WinHttpReceiveResponse(hRequest, nullptr)) + { + Console.Error("WinHttpReceiveResponse() failed: %u", GetLastError()); + req->status_code = -1; + req->state.store(Request::State::Complete); + } + + return; + } + case WINHTTP_CALLBACK_STATUS_HEADERS_AVAILABLE: + { + DbgCon.WriteLn("Headers available"); + + DWORD buffer_size = sizeof(req->status_code); + if (!WinHttpQueryHeaders(hRequest, WINHTTP_QUERY_STATUS_CODE | WINHTTP_QUERY_FLAG_NUMBER, + WINHTTP_HEADER_NAME_BY_INDEX, &req->status_code, &buffer_size, WINHTTP_NO_HEADER_INDEX)) + { + Console.Error("WinHttpQueryHeaders() for status code failed: %u", GetLastError()); + req->status_code = -1; + req->state.store(Request::State::Complete); + return; + } + + buffer_size = sizeof(req->content_length); + if (!WinHttpQueryHeaders(hRequest, WINHTTP_QUERY_CONTENT_LENGTH | WINHTTP_QUERY_FLAG_NUMBER, + WINHTTP_HEADER_NAME_BY_INDEX, &req->content_length, &buffer_size, + WINHTTP_NO_HEADER_INDEX)) + { + if (GetLastError() != ERROR_WINHTTP_HEADER_NOT_FOUND) + Console.Warning("WinHttpQueryHeaders() for content length failed: %u", GetLastError()); + + req->content_length = 0; + } + + DWORD content_type_length = 0; + if (!WinHttpQueryHeaders(hRequest, WINHTTP_QUERY_CONTENT_TYPE, WINHTTP_HEADER_NAME_BY_INDEX, + WINHTTP_NO_OUTPUT_BUFFER, &content_type_length, WINHTTP_NO_HEADER_INDEX) && + GetLastError() == ERROR_INSUFFICIENT_BUFFER && content_type_length >= sizeof(content_type_length)) + { + std::wstring content_type_wstring; + content_type_wstring.resize((content_type_length / sizeof(wchar_t)) - 1); + if (WinHttpQueryHeaders(hRequest, WINHTTP_QUERY_CONTENT_TYPE, WINHTTP_HEADER_NAME_BY_INDEX, + content_type_wstring.data(), &content_type_length, WINHTTP_NO_HEADER_INDEX)) + { + req->content_type = StringUtil::WideStringToUTF8String(content_type_wstring); + } + } + + DbgCon.WriteLn("Status code %d, content-length is %u, content-type is %s", req->status_code, req->content_length, + req->content_type.c_str()); + req->data.reserve(req->content_length); + req->state = Request::State::Receiving; + + // start reading + if (!WinHttpQueryDataAvailable(hRequest, nullptr) && GetLastError() != ERROR_IO_PENDING) + { + Console.Error("WinHttpQueryDataAvailable() failed: %u", GetLastError()); + req->status_code = -1; + req->state.store(Request::State::Complete); + } + + return; + } + case WINHTTP_CALLBACK_STATUS_DATA_AVAILABLE: + { + DWORD bytes_available; + std::memcpy(&bytes_available, lpvStatusInformation, sizeof(bytes_available)); + if (bytes_available == 0) + { + // end of request + DbgCon.WriteLn("End of request '%s', %zu bytes received", req->url.c_str(), req->data.size()); + req->state.store(Request::State::Complete); + return; + } + + // start the transfer + DbgCon.WriteLn("%u bytes available", bytes_available); + req->io_position = static_cast(req->data.size()); + req->data.resize(req->io_position + bytes_available); + if (!WinHttpReadData(hRequest, req->data.data() + req->io_position, bytes_available, nullptr) && + GetLastError() != ERROR_IO_PENDING) + { + Console.Error("WinHttpReadData() failed: %u", GetLastError()); + req->status_code = -1; + req->state.store(Request::State::Complete); + } + + return; + } + case WINHTTP_CALLBACK_STATUS_READ_COMPLETE: + { + DbgCon.WriteLn("Read of %u complete", dwStatusInformationLength); + + const u32 new_size = req->io_position + dwStatusInformationLength; + pxAssertRel(new_size <= req->data.size(), "HTTP overread occurred"); + req->data.resize(new_size); + req->start_time = Common::Timer::GetCurrentValue(); + + if (!WinHttpQueryDataAvailable(hRequest, nullptr) && GetLastError() != ERROR_IO_PENDING) + { + Console.Error("WinHttpQueryDataAvailable() failed: %u", GetLastError()); + req->status_code = -1; + req->state.store(Request::State::Complete); + } + + return; + } + default: + // unhandled, ignore + return; + } +} + +HTTPDownloader::Request* HTTPDownloaderWinHttp::InternalCreateRequest() +{ + Request* req = new Request(); + return req; +} + +void HTTPDownloaderWinHttp::InternalPollRequests() +{ + // noop - it uses windows's worker threads +} + +bool HTTPDownloaderWinHttp::StartRequest(HTTPDownloader::Request* request) +{ + Request* req = static_cast(request); + + std::wstring host_name; + host_name.resize(req->url.size()); + req->object_name.resize(req->url.size()); + + URL_COMPONENTSW uc = {}; + uc.dwStructSize = sizeof(uc); + uc.lpszHostName = host_name.data(); + uc.dwHostNameLength = static_cast(host_name.size()); + uc.lpszUrlPath = req->object_name.data(); + uc.dwUrlPathLength = static_cast(req->object_name.size()); + + const std::wstring url_wide(StringUtil::UTF8StringToWideString(req->url)); + if (!WinHttpCrackUrl(url_wide.c_str(), static_cast(url_wide.size()), 0, &uc)) + { + Console.Error("WinHttpCrackUrl() failed: %u", GetLastError()); + req->callback(-1, std::string(), Request::Data()); + delete req; + return false; + } + + host_name.resize(uc.dwHostNameLength); + req->object_name.resize(uc.dwUrlPathLength); + + req->hConnection = WinHttpConnect(m_hSession, host_name.c_str(), uc.nPort, 0); + if (!req->hConnection) + { + Console.Error("Failed to start HTTP request for '%s': %u", req->url.c_str(), GetLastError()); + req->callback(-1, std::string(), Request::Data()); + delete req; + return false; + } + + const DWORD request_flags = uc.nScheme == INTERNET_SCHEME_HTTPS ? WINHTTP_FLAG_SECURE : 0; + req->hRequest = + WinHttpOpenRequest(req->hConnection, (req->type == HTTPDownloader::Request::Type::Post) ? L"POST" : L"GET", + req->object_name.c_str(), NULL, NULL, NULL, request_flags); + if (!req->hRequest) + { + Console.Error("WinHttpOpenRequest() failed: %u", GetLastError()); + WinHttpCloseHandle(req->hConnection); + return false; + } + + BOOL result; + if (req->type == HTTPDownloader::Request::Type::Post) + { + const std::wstring_view additional_headers(L"Content-Type: application/x-www-form-urlencoded\r\n"); + result = WinHttpSendRequest(req->hRequest, additional_headers.data(), static_cast(additional_headers.size()), + req->post_data.data(), static_cast(req->post_data.size()), + static_cast(req->post_data.size()), reinterpret_cast(req)); + } + else + { + result = WinHttpSendRequest(req->hRequest, WINHTTP_NO_ADDITIONAL_HEADERS, 0, WINHTTP_NO_REQUEST_DATA, 0, 0, + reinterpret_cast(req)); + } + + if (!result && GetLastError() != ERROR_IO_PENDING) + { + Console.Error("WinHttpSendRequest() failed: %u", GetLastError()); + req->status_code = -1; + req->state.store(Request::State::Complete); + } + + DevCon.WriteLn("Started HTTP request for '%s'", req->url.c_str()); + req->state = Request::State::Started; + req->start_time = Common::Timer::GetCurrentValue(); + return true; +} + +void HTTPDownloaderWinHttp::CloseRequest(HTTPDownloader::Request* request) +{ + Request* req = static_cast(request); + + if (req->hRequest != NULL) + { + // req will be freed by the callback. + // the callback can fire immediately here if there's nothing running async, so don't touch req afterwards + WinHttpCloseHandle(req->hRequest); + return; + } + + if (req->hConnection != NULL) + WinHttpCloseHandle(req->hConnection); + + delete req; +} diff --git a/common/HTTPDownloaderWinHTTP.h b/common/HTTPDownloaderWinHTTP.h new file mode 100644 index 0000000000..5200cfea01 --- /dev/null +++ b/common/HTTPDownloaderWinHTTP.h @@ -0,0 +1,53 @@ +/* PCSX2 - PS2 Emulator for PCs + * Copyright (C) 2002-2022 PCSX2 Dev Team + * + * PCSX2 is free software: you can redistribute it and/or modify it under the terms + * of the GNU Lesser General Public License as published by the Free Software Found- + * ation, either version 3 of the License, or (at your option) any later version. + * + * PCSX2 is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; + * without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR + * PURPOSE. See the GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License along with PCSX2. + * If not, see . + */ + +#pragma once + +#include "common/HTTPDownloader.h" +#include "common/RedtapeWindows.h" + +#include + +namespace Common +{ + class HTTPDownloaderWinHttp final : public HTTPDownloader + { + public: + HTTPDownloaderWinHttp(); + ~HTTPDownloaderWinHttp() override; + + bool Initialize(const char* user_agent); + + protected: + Request* InternalCreateRequest() override; + void InternalPollRequests() override; + bool StartRequest(HTTPDownloader::Request* request) override; + void CloseRequest(HTTPDownloader::Request* request) override; + + private: + struct Request : HTTPDownloader::Request + { + std::wstring object_name; + HINTERNET hConnection = NULL; + HINTERNET hRequest = NULL; + u32 io_position = 0; + }; + + static void CALLBACK HTTPStatusCallback(HINTERNET hInternet, DWORD_PTR dwContext, DWORD dwInternetStatus, + LPVOID lpvStatusInformation, DWORD dwStatusInformationLength); + + HINTERNET m_hSession = NULL; + }; +} // namespace Common \ No newline at end of file diff --git a/common/ThreadPool.cpp b/common/ThreadPool.cpp new file mode 100644 index 0000000000..a5819a8ad9 --- /dev/null +++ b/common/ThreadPool.cpp @@ -0,0 +1,137 @@ +/* + * MIT License + * + * Copyright (c) 2022 Colion Braley + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +// From https://raw.githubusercontent.com/cbraley/threadpool/master/src/thread_pool.cc + +#include "common/PrecompiledHeader.h" + +#include "common/ThreadPool.h" + +#include + +namespace cb { + +// static +unsigned int ThreadPool::GetNumLogicalCores() { + // TODO(cbraley): Apparently this is broken in some older stdlib + // implementations? + const unsigned int dflt = std::thread::hardware_concurrency(); + if (dflt == 0) { + // TODO(cbraley): Return some error code instead. + return 16; + } else { + return dflt; + } +} + +ThreadPool::~ThreadPool() { + // TODO(cbraley): The current thread could help out to drain the work_ queue + // faster - for example, if there is work that hasn't yet been scheduled this + // thread could "pitch in" to help finish faster. + + { + std::lock_guard scoped_lock(mu_); + exit_ = true; + } + condvar_.notify_all(); // Tell *all* workers we are ready. + + for (std::thread& thread : workers_) { + thread.join(); + } +} + +void ThreadPool::Wait() { + std::unique_lock lock(mu_); + if (!work_.empty()) { + work_done_condvar_.wait(lock, [this] { return work_.empty(); }); + } +} + +ThreadPool::ThreadPool(int num_workers) + : num_workers_(num_workers), exit_(false) { + assert(num_workers_ > 0); + // TODO(cbraley): Handle thread construction exceptions. + workers_.reserve(num_workers_); + for (int i = 0; i < num_workers_; ++i) { + workers_.emplace_back(&ThreadPool::ThreadLoop, this); + } +} + +void ThreadPool::Schedule(std::function func) { + ScheduleAndGetFuture(std::move(func)); // We ignore the returned std::future. +} + +void ThreadPool::ThreadLoop() { + // Wait until the ThreadPool sends us work. + while (true) { + WorkItem work_item; + + int prev_work_size = -1; + { + std::unique_lock lock(mu_); + condvar_.wait(lock, [this] { return exit_ || (!work_.empty()); }); + // ...after the wait(), we hold the lock. + + // If all the work is done and exit_ is true, break out of the loop. + if (exit_ && work_.empty()) { + break; + } + + // Pop the work off of the queue - we are careful to execute the + // work_item.func callback only after we have released the lock. + prev_work_size = work_.size(); + work_item = std::move(work_.front()); + work_.pop(); + } + + // We are careful to do the work without the lock held! + // TODO(cbraley): Handle exceptions properly. + work_item.func(); // Do work. + + if (work_done_callback_) { + work_done_callback_(prev_work_size - 1); + } + + // Notify a condvar is all work is done. + { + std::unique_lock lock(mu_); + if (work_.empty() && prev_work_size == 1) { + work_done_condvar_.notify_all(); + } + } + } +} + +int ThreadPool::OutstandingWorkSize() const { + std::lock_guard scoped_lock(mu_); + return work_.size(); +} + +int ThreadPool::NumWorkers() const { return num_workers_; } + +void ThreadPool::SetWorkDoneCallback(std::function func) { + work_done_callback_ = std::move(func); +} + +} // namespace cb diff --git a/common/ThreadPool.h b/common/ThreadPool.h new file mode 100644 index 0000000000..83a0664d6a --- /dev/null +++ b/common/ThreadPool.h @@ -0,0 +1,255 @@ +/* + * MIT License + * + * Copyright (c) 2022 Colion Braley + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +// From https://raw.githubusercontent.com/cbraley/threadpool/master/src/thread_pool.h + +#pragma once + +// A simple thread pool class. +// Usage examples: +// +// { +// ThreadPool pool(16); // 16 worker threads. +// for (int i = 0; i < 100; ++i) { +// pool.Schedule([i]() { +// DoSlowExpensiveOperation(i); +// }); +// } +// +// // `pool` goes out of scope here - the code will block in the ~ThreadPool +// // destructor until all work is complete. +// } +// +// // TODO(cbraley): Add examples with std::future. + +#include +#include +#include +#include +#include +#include +#include + +// We want to use std::invoke if C++17 is available, and fallback to "hand +// crafted" code if std::invoke isn't available. +#if __cplusplus >= 201703L || defined(_MSC_VER) + #define INVOKE_MACRO(CALLABLE, ARGS_TYPE, ARGS) std::invoke(CALLABLE, std::forward(ARGS)...) +#elif __cplusplus >= 201103L + // Update this with http://www.open-std.org/jtc1/sc22/wg21/docs/papers/2014/n4169.html. + #define INVOKE_MACRO(CALLABLE, ARGS_TYPE, ARGS) CALLABLE(std::forward(ARGS)...) +#else + #error ("C++ version is too old! C++98 is not supported.") +#endif + +namespace cb { + +class ThreadPool { + public: + // Create a thread pool with `num_workers` dedicated worker threads. + explicit ThreadPool(int num_workers); + + // Default construction is disallowed. + ThreadPool() = delete; + + // Get the number of logical cores on the CPU. This is implemented using + // std::thread::hardware_concurrency(). + // https://en.cppreference.com/w/cpp/thread/thread/hardware_concurrency + static unsigned int GetNumLogicalCores(); + + // The `ThreadPool` destructor blocks until all outstanding work is complete. + ~ThreadPool(); + + // No copying, assigning, or std::move-ing. + ThreadPool& operator=(const ThreadPool&) = delete; + ThreadPool(const ThreadPool&) = delete; + ThreadPool(ThreadPool&&) = delete; + ThreadPool& operator=(ThreadPool&&) = delete; + + // Add the function `func` to the thread pool. `func` will be executed at some + // point in the future on an arbitrary thread. + void Schedule(std::function func); + + // Add `func` to the thread pool, and return a std::future that can be used to + // access the function's return value. + // + // *** Usage example *** + // Don't be alarmed by this function's tricky looking signature - this is + // very easy to use. Here's an example: + // + // int ComputeSum(std::vector& values) { + // int sum = 0; + // for (const int& v : values) { + // sum += v; + // } + // return sum; + // } + // + // ThreadPool pool = ...; + // std::vector numbers = ...; + // + // std::future sum_future = ScheduleAndGetFuture( + // []() { + // return ComputeSum(numbers); + // }); + // + // // Do other work... + // + // std::cout << "The sum is " << sum_future.get() << std::endl; + // + // *** Details *** + // Given a callable `func` that returns a value of type `RetT`, this + // function returns a std::future that can be used to access + // `func`'s results. + template + auto ScheduleAndGetFuture(FuncT&& func, ArgsT&&... args) + -> std::future; + + // Wait for all outstanding work to be completed. + void Wait(); + + // Return the number of outstanding functions to be executed. + int OutstandingWorkSize() const; + + // Return the number of threads in the pool. + int NumWorkers() const; + + void SetWorkDoneCallback(std::function func); + + private: + void ThreadLoop(); + + // Number of worker threads - fixed at construction time. + int num_workers_; + + // The destructor sets `exit_` to true and then notifies all workers. `exit_` + // causes each thread to break out of their work loop. + bool exit_; + + mutable std::mutex mu_; + + // Work queue. Guarded by `mu_`. + struct WorkItem { + std::function func; + }; + std::queue work_; + + // Condition variable used to notify worker threads that new work is + // available. + std::condition_variable condvar_; + + // Worker threads. + std::vector workers_; + + // Condition variable used to notify that all work is complete - the work + // queue has "run dry". + std::condition_variable work_done_condvar_; + + // Whenever a work item is complete, we call this callback. If this is empty, + // nothing is done. + std::function work_done_callback_; +}; + +namespace impl { + +// This helper class simply returns a std::function that executes: +// ReturnT x = func(); +// promise->set_value(x); +// However, this is tricky in the case where T == void. The code above won't +// compile if ReturnT == void, and neither will +// promise->set_value(func()); +// To workaround this, we use a template specialization for the case where +// ReturnT is void. If the "regular void" proposal is accepted, this could be +// simpler: +// http://www.open-std.org/jtc1/sc22/wg21/docs/papers/2016/p0146r1.html. + +// The non-specialized `FuncWrapper` implementation handles callables that +// return a non-void value. +template +struct FuncWrapper { + template + std::function GetWrapped( + FuncT&& func, std::shared_ptr> promise, + ArgsT&&... args) { + // TODO(cbraley): Capturing by value is inefficient. It would be more + // efficient to move-capture everything, but we can't do this until C++14 + // generalized lambda capture is available. Can we use std::bind instead to + // make this more efficient and still use C++11? + return [promise, func, args...]() mutable { + promise->set_value(INVOKE_MACRO(func, ArgsT, args)); + }; + } +}; + +template +void InvokeVoidRet(FuncT&& func, std::shared_ptr> promise, + ArgsT&&... args) { + INVOKE_MACRO(func, ArgsT, args); + promise->set_value(); +} + +// This `FuncWrapper` specialization handles callables that return void. +template <> +struct FuncWrapper { + template + std::function GetWrapped(FuncT&& func, + std::shared_ptr> promise, + ArgsT&&... args) { + return [promise, func, args...]() mutable { + INVOKE_MACRO(func, ArgsT, args); + promise->set_value(); + }; + } +}; + +} // namespace impl + +template +auto ThreadPool::ScheduleAndGetFuture(FuncT&& func, ArgsT&&... args) + -> std::future { + using ReturnT = decltype(INVOKE_MACRO(func, ArgsT, args)); + + // We are only allocating this std::promise in a shared_ptr because + // std::promise is non-copyable. + std::shared_ptr> promise = + std::make_shared>(); + std::future ret_future = promise->get_future(); + + impl::FuncWrapper func_wrapper; + std::function wrapped_func = func_wrapper.GetWrapped( + std::move(func), std::move(promise), std::forward(args)...); + + // Acquire the lock, and then push the WorkItem onto the queue. + { + std::lock_guard scoped_lock(mu_); + WorkItem work; + work.func = std::move(wrapped_func); + work_.emplace(std::move(work)); + } + condvar_.notify_one(); // Tell one worker we are ready. + return ret_future; +} + +} // namespace cb + +#undef INVOKE_MACRO diff --git a/common/common.vcxproj b/common/common.vcxproj index 9f8972df5a..8360d43933 100644 --- a/common/common.vcxproj +++ b/common/common.vcxproj @@ -67,6 +67,13 @@ + + + true + true + true + + @@ -141,6 +148,13 @@ + + + true + true + true + + @@ -159,6 +173,7 @@ + @@ -217,4 +232,4 @@ - \ No newline at end of file + diff --git a/common/common.vcxproj.filters b/common/common.vcxproj.filters index ad01da23f2..4ade96eef2 100644 --- a/common/common.vcxproj.filters +++ b/common/common.vcxproj.filters @@ -187,6 +187,15 @@ Source Files + + Source Files + + + Source Files + + + Source Files + @@ -441,6 +450,18 @@ Header Files + + Header Files + + + Header Files + + + Header Files + + + Header Files + @@ -487,4 +508,4 @@ Source Files - \ No newline at end of file +