diff --git a/src/httplib/httplib.h b/src/httplib/httplib.h index ea600db72..04a552f0a 100644 --- a/src/httplib/httplib.h +++ b/src/httplib/httplib.h @@ -1,13 +1,15 @@ // // httplib.h // -// Copyright (c) 2021 Yuji Hirose. All rights reserved. +// Copyright (c) 2022 Yuji Hirose. All rights reserved. // MIT License // #ifndef CPPHTTPLIB_HTTPLIB_H #define CPPHTTPLIB_HTTPLIB_H +#define CPPHTTPLIB_VERSION "0.10.7" + /* * Configuration */ @@ -60,6 +62,10 @@ #define CPPHTTPLIB_REQUEST_URI_MAX_LENGTH 8192 #endif +#ifndef CPPHTTPLIB_HEADER_MAX_LENGTH +#define CPPHTTPLIB_HEADER_MAX_LENGTH 8192 +#endif + #ifndef CPPHTTPLIB_REDIRECT_MAX_COUNT #define CPPHTTPLIB_REDIRECT_MAX_COUNT 20 #endif @@ -95,6 +101,10 @@ #define CPPHTTPLIB_SEND_FLAGS 0 #endif +#ifndef CPPHTTPLIB_LISTEN_BACKLOG +#define CPPHTTPLIB_LISTEN_BACKLOG 5 +#endif + /* * Headers */ @@ -134,8 +144,6 @@ using ssize_t = int; #include #include - -#include #include #ifndef WSA_FLAG_NO_HANDLE_INHERIT @@ -144,8 +152,6 @@ using ssize_t = int; #ifdef _MSC_VER #pragma comment(lib, "ws2_32.lib") -#pragma comment(lib, "crypt32.lib") -#pragma comment(lib, "cryptui.lib") #endif #ifndef strcasecmp @@ -162,6 +168,7 @@ using socket_t = SOCKET; #include #include #include +#include #include #include #ifdef __linux__ @@ -209,8 +216,24 @@ using socket_t = int; #include #ifdef CPPHTTPLIB_OPENSSL_SUPPORT +#ifdef _WIN32 +#include + +// these are defined in wincrypt.h and it breaks compilation if BoringSSL is +// used +#undef X509_NAME +#undef X509_CERT_PAIR +#undef X509_EXTENSIONS +#undef PKCS7_SIGNER_INFO + +#ifdef _MSC_VER +#pragma comment(lib, "crypt32.lib") +#pragma comment(lib, "cryptui.lib") +#endif +#endif //_WIN32 + #include -#include +#include #include #include @@ -259,7 +282,7 @@ namespace detail { template typename std::enable_if::value, std::unique_ptr>::type -make_unique(Args &&... args) { +make_unique(Args &&...args) { return std::unique_ptr(new T(std::forward(args)...)); } @@ -492,7 +515,7 @@ public: virtual socket_t socket() const = 0; template - ssize_t write_format(const char *fmt, const Args &... args); + ssize_t write_format(const char *fmt, const Args &...args); ssize_t write(const char *ptr); ssize_t write(const std::string &s); }; @@ -505,7 +528,7 @@ public: virtual void enqueue(std::function fn) = 0; virtual void shutdown() = 0; - virtual void on_idle(){}; + virtual void on_idle() {} }; class ThreadPool : public TaskQueue { @@ -582,23 +605,7 @@ using Logger = std::function; using SocketOptions = std::function; -inline void default_socket_options(socket_t sock) { - int yes = 1; -#ifdef _WIN32 - setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, reinterpret_cast(&yes), - sizeof(yes)); - setsockopt(sock, SOL_SOCKET, SO_EXCLUSIVEADDRUSE, - reinterpret_cast(&yes), sizeof(yes)); -#else -#ifdef SO_REUSEPORT - setsockopt(sock, SOL_SOCKET, SO_REUSEPORT, reinterpret_cast(&yes), - sizeof(yes)); -#else - setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, reinterpret_cast(&yes), - sizeof(yes)); -#endif -#endif -} +void default_socket_options(socket_t sock); class Server { public: @@ -638,7 +645,7 @@ public: Server &Options(const std::string &pattern, Handler handler); bool set_base_dir(const std::string &dir, - const std::string &mount_point = nullptr); + const std::string &mount_point = std::string()); bool set_mount_point(const std::string &mount_point, const std::string &dir, Headers headers = Headers()); bool remove_mount_point(const std::string &mount_point); @@ -798,35 +805,12 @@ enum class Error { SSLServerVerification, UnsupportedMultipartBoundaryChars, Compression, + ConnectionTimeout, }; -inline std::string to_string(const Error error) { - switch (error) { - case Error::Success: return "Success"; - case Error::Connection: return "Connection"; - case Error::BindIPAddress: return "BindIPAddress"; - case Error::Read: return "Read"; - case Error::Write: return "Write"; - case Error::ExceedRedirectCount: return "ExceedRedirectCount"; - case Error::Canceled: return "Canceled"; - case Error::SSLConnection: return "SSLConnection"; - case Error::SSLLoadingCerts: return "SSLLoadingCerts"; - case Error::SSLServerVerification: return "SSLServerVerification"; - case Error::UnsupportedMultipartBoundaryChars: - return "UnsupportedMultipartBoundaryChars"; - case Error::Compression: return "Compression"; - case Error::Unknown: return "Unknown"; - default: break; - } +std::string to_string(const Error error); - return "Invalid"; -} - -inline std::ostream &operator<<(std::ostream &os, const Error &obj) { - os << to_string(obj); - os << " (" << static_cast::type>(obj) << ')'; - return os; -} +std::ostream &operator<<(std::ostream &os, const Error &obj); class Result { public: @@ -995,6 +979,8 @@ public: void stop(); + void set_hostname_addr_map(std::map addr_map); + void set_default_headers(Headers headers); void set_address_family(int family); @@ -1098,6 +1084,9 @@ protected: std::thread::id socket_requests_are_from_thread_ = std::thread::id(); bool socket_should_be_closed_when_request_is_done_ = false; + // Hostname-IP map + std::map addr_map_; + // Default headers Headers default_headers_; @@ -1201,6 +1190,8 @@ public: const std::string &client_cert_path, const std::string &client_key_path); + Client(Client &&) = default; + ~Client(); bool is_valid() const; @@ -1323,6 +1314,8 @@ public: void stop(); + void set_hostname_addr_map(std::map addr_map); + void set_default_headers(Headers headers); void set_address_family(int family); @@ -1397,15 +1390,21 @@ class SSLServer : public Server { public: SSLServer(const char *cert_path, const char *private_key_path, const char *client_ca_cert_file_path = nullptr, - const char *client_ca_cert_dir_path = nullptr); + const char *client_ca_cert_dir_path = nullptr, + const char *private_key_password = nullptr); SSLServer(X509 *cert, EVP_PKEY *private_key, X509_STORE *client_ca_cert_store = nullptr); + SSLServer( + const std::function &setup_ssl_ctx_callback); + ~SSLServer() override; bool is_valid() const override; + SSL_CTX *ssl_context() const; + private: bool process_and_close_socket(socket_t sock) override; @@ -1513,12 +1512,12 @@ inline T Response::get_header_value(const char *key, size_t id) const { } template -inline ssize_t Stream::write_format(const char *fmt, const Args &... args) { +inline ssize_t Stream::write_format(const char *fmt, const Args &...args) { const auto bufsiz = 2048; - std::array buf; + std::array buf{}; #if defined(_MSC_VER) && _MSC_VER < 1900 - auto sn = _snprintf_s(buf.data(), bufsiz - 1, buf.size() - 1, fmt, args...); + auto sn = _snprintf_s(buf.data(), bufsiz, _TRUNCATE, fmt, args...); #else auto sn = snprintf(buf.data(), buf.size() - 1, fmt, args...); #endif @@ -1546,6 +1545,24 @@ inline ssize_t Stream::write_format(const char *fmt, const Args &... args) { } } +inline void default_socket_options(socket_t sock) { + int yes = 1; +#ifdef _WIN32 + setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, reinterpret_cast(&yes), + sizeof(yes)); + setsockopt(sock, SOL_SOCKET, SO_EXCLUSIVEADDRUSE, + reinterpret_cast(&yes), sizeof(yes)); +#else +#ifdef SO_REUSEPORT + setsockopt(sock, SOL_SOCKET, SO_REUSEPORT, reinterpret_cast(&yes), + sizeof(yes)); +#else + setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, reinterpret_cast(&yes), + sizeof(yes)); +#endif +#endif +} + template inline Server & Server::set_read_timeout(const std::chrono::duration &duration) { @@ -1570,6 +1587,35 @@ Server::set_idle_interval(const std::chrono::duration &duration) { return *this; } +inline std::string to_string(const Error error) { + switch (error) { + case Error::Success: return "Success"; + case Error::Connection: return "Connection"; + case Error::BindIPAddress: return "BindIPAddress"; + case Error::Read: return "Read"; + case Error::Write: return "Write"; + case Error::ExceedRedirectCount: return "ExceedRedirectCount"; + case Error::Canceled: return "Canceled"; + case Error::SSLConnection: return "SSLConnection"; + case Error::SSLLoadingCerts: return "SSLLoadingCerts"; + case Error::SSLServerVerification: return "SSLServerVerification"; + case Error::UnsupportedMultipartBoundaryChars: + return "UnsupportedMultipartBoundaryChars"; + case Error::Compression: return "Compression"; + case Error::ConnectionTimeout: return "ConnectionTimeout"; + case Error::Unknown: return "Unknown"; + default: break; + } + + return "Invalid"; +} + +inline std::ostream &operator<<(std::ostream &os, const Error &obj) { + os << to_string(obj); + os << " (" << static_cast::type>(obj) << ')'; + return os; +} + template inline T Result::get_request_header_value(const char *key, size_t id) const { return detail::get_header_value(request_headers_, key, id, 0); @@ -1620,6 +1666,12 @@ Client::set_write_timeout(const std::chrono::duration &duration) { * .h + .cc. */ +std::string hosted_at(const char *hostname); + +void hosted_at(const char *hostname, std::vector &addrs); + +std::string append_query_params(const char *path, const Params ¶ms); + std::pair make_range_header(Ranges ranges); std::pair @@ -1631,6 +1683,8 @@ namespace detail { std::string encode_query_param(const std::string &value); +std::string decode_url(const std::string &s, bool convert_plus_to_space); + void read_file(const std::string &path, std::string &out); std::string trim_copy(const std::string &s); @@ -1643,14 +1697,12 @@ bool process_client_socket(socket_t sock, time_t read_timeout_sec, time_t write_timeout_usec, std::function callback); -socket_t create_client_socket(const char *host, int port, int address_family, - bool tcp_nodelay, SocketOptions socket_options, - time_t connection_timeout_sec, - time_t connection_timeout_usec, - time_t read_timeout_sec, time_t read_timeout_usec, - time_t write_timeout_sec, - time_t write_timeout_usec, - const std::string &intf, Error &error); +socket_t create_client_socket( + const char *host, const char *ip, int port, int address_family, + bool tcp_nodelay, SocketOptions socket_options, + time_t connection_timeout_sec, time_t connection_timeout_usec, + time_t read_timeout_sec, time_t read_timeout_usec, time_t write_timeout_sec, + time_t write_timeout_usec, const std::string &intf, Error &error); const char *get_header_value(const Headers &headers, const char *key, size_t id = 0, const char *def = nullptr); @@ -1663,6 +1715,10 @@ bool parse_range_header(const std::string &s, Ranges &ranges); int close_socket(socket_t sock); +ssize_t send_socket(socket_t sock, const void *ptr, size_t size, int flags); + +ssize_t read_socket(socket_t sock, void *ptr, size_t size, int flags); + enum class EncodingType { None = 0, Gzip, Brotli }; EncodingType encoding_type(const Request &req, const Response &res); @@ -1908,8 +1964,12 @@ inline std::string base64_encode(const std::string &in) { } inline bool is_file(const std::string &path) { +#ifdef _WIN32 + return _access_s(path.c_str(), 0) == 0; +#else struct stat st; return stat(path.c_str(), &st) >= 0 && S_ISREG(st.st_mode); +#endif } inline bool is_dir(const std::string &path) { @@ -2181,6 +2241,31 @@ template inline ssize_t handle_EINTR(T fn) { return res; } +inline ssize_t read_socket(socket_t sock, void *ptr, size_t size, int flags) { + return handle_EINTR([&]() { + return recv(sock, +#ifdef _WIN32 + static_cast(ptr), static_cast(size), +#else + ptr, size, +#endif + flags); + }); +} + +inline ssize_t send_socket(socket_t sock, const void *ptr, size_t size, + int flags) { + return handle_EINTR([&]() { + return send(sock, +#ifdef _WIN32 + static_cast(ptr), static_cast(size), +#else + ptr, size, +#endif + flags); + }); +} + inline ssize_t select_read(socket_t sock, time_t sec, time_t usec) { #ifdef CPPHTTPLIB_USE_POLL struct pollfd pfd_read; @@ -2237,7 +2322,8 @@ inline ssize_t select_write(socket_t sock, time_t sec, time_t usec) { #endif } -inline bool wait_until_socket_is_ready(socket_t sock, time_t sec, time_t usec) { +inline Error wait_until_socket_is_ready(socket_t sock, time_t sec, + time_t usec) { #ifdef CPPHTTPLIB_USE_POLL struct pollfd pfd_read; pfd_read.fd = sock; @@ -2247,17 +2333,21 @@ inline bool wait_until_socket_is_ready(socket_t sock, time_t sec, time_t usec) { auto poll_res = handle_EINTR([&]() { return poll(&pfd_read, 1, timeout); }); + if (poll_res == 0) { return Error::ConnectionTimeout; } + if (poll_res > 0 && pfd_read.revents & (POLLIN | POLLOUT)) { int error = 0; socklen_t len = sizeof(error); auto res = getsockopt(sock, SOL_SOCKET, SO_ERROR, reinterpret_cast(&error), &len); - return res >= 0 && !error; + auto successful = res >= 0 && !error; + return successful ? Error::Success : Error::Connection; } - return false; + + return Error::Connection; #else #ifndef _WIN32 - if (sock >= FD_SETSIZE) { return false; } + if (sock >= FD_SETSIZE) { return Error::Connection; } #endif fd_set fdsr; @@ -2275,17 +2365,31 @@ inline bool wait_until_socket_is_ready(socket_t sock, time_t sec, time_t usec) { return select(static_cast(sock + 1), &fdsr, &fdsw, &fdse, &tv); }); + if (ret == 0) { return Error::ConnectionTimeout; } + if (ret > 0 && (FD_ISSET(sock, &fdsr) || FD_ISSET(sock, &fdsw))) { int error = 0; socklen_t len = sizeof(error); - return getsockopt(sock, SOL_SOCKET, SO_ERROR, - reinterpret_cast(&error), &len) >= 0 && - !error; + auto res = getsockopt(sock, SOL_SOCKET, SO_ERROR, + reinterpret_cast(&error), &len); + auto successful = res >= 0 && !error; + return successful ? Error::Success : Error::Connection; } - return false; + return Error::Connection; #endif } +inline bool is_socket_alive(socket_t sock) { + const auto val = detail::select_read(sock, 0, 0); + if (val == 0) { + return true; + } else if (val < 0 && errno == EBADF) { + return false; + } + char buf[1]; + return detail::read_socket(sock, &buf[0], sizeof(buf), MSG_PEEK) > 0; +} + class SocketStream : public Stream { public: SocketStream(socket_t sock, time_t read_timeout_sec, time_t read_timeout_usec, @@ -2305,6 +2409,12 @@ private: time_t read_timeout_usec_; time_t write_timeout_sec_; time_t write_timeout_usec_; + + std::vector read_buff_; + size_t read_buff_off_ = 0; + size_t read_buff_content_size_ = 0; + + static const size_t read_buff_size_ = 1024 * 4; }; #ifdef CPPHTTPLIB_OPENSSL_SUPPORT @@ -2353,12 +2463,14 @@ inline bool keep_alive(socket_t sock, time_t keep_alive_timeout_sec) { template inline bool -process_server_socket_core(socket_t sock, size_t keep_alive_max_count, +process_server_socket_core(const std::atomic &svr_sock, socket_t sock, + size_t keep_alive_max_count, time_t keep_alive_timeout_sec, T callback) { assert(keep_alive_max_count > 0); auto ret = false; auto count = keep_alive_max_count; - while (count > 0 && keep_alive(sock, keep_alive_timeout_sec)) { + while (svr_sock != INVALID_SOCKET && count > 0 && + keep_alive(sock, keep_alive_timeout_sec)) { auto close_connection = count == 1; auto connection_closed = false; ret = callback(close_connection, connection_closed); @@ -2370,12 +2482,13 @@ process_server_socket_core(socket_t sock, size_t keep_alive_max_count, template inline bool -process_server_socket(socket_t sock, size_t keep_alive_max_count, +process_server_socket(const std::atomic &svr_sock, socket_t sock, + size_t keep_alive_max_count, time_t keep_alive_timeout_sec, time_t read_timeout_sec, time_t read_timeout_usec, time_t write_timeout_sec, time_t write_timeout_usec, T callback) { return process_server_socket_core( - sock, keep_alive_max_count, keep_alive_timeout_sec, + svr_sock, sock, keep_alive_max_count, keep_alive_timeout_sec, [&](bool close_connection, bool &connection_closed) { SocketStream strm(sock, read_timeout_sec, read_timeout_usec, write_timeout_sec, write_timeout_usec); @@ -2402,23 +2515,33 @@ inline int shutdown_socket(socket_t sock) { } template -socket_t create_socket(const char *host, int port, int address_family, - int socket_flags, bool tcp_nodelay, +socket_t create_socket(const char *host, const char *ip, int port, + int address_family, int socket_flags, bool tcp_nodelay, SocketOptions socket_options, BindOrConnect bind_or_connect) { // Get address info + const char *node = nullptr; struct addrinfo hints; struct addrinfo *result; memset(&hints, 0, sizeof(struct addrinfo)); - hints.ai_family = address_family; hints.ai_socktype = SOCK_STREAM; - hints.ai_flags = socket_flags; hints.ai_protocol = 0; + if (ip[0] != '\0') { + node = ip; + // Ask getaddrinfo to convert IP in c-string to address + hints.ai_family = AF_UNSPEC; + hints.ai_flags = AI_NUMERICHOST; + } else { + node = host; + hints.ai_family = address_family; + hints.ai_flags = socket_flags; + } + auto service = std::to_string(port); - if (getaddrinfo(host, service.c_str(), &hints, &result)) { + if (getaddrinfo(node, service.c_str(), &hints, &result)) { #if defined __linux__ && !defined __ANDROID__ res_init(); #endif @@ -2532,11 +2655,14 @@ inline bool bind_ip_address(socket_t sock, const char *host) { #endif #ifdef USE_IF2IP -inline std::string if2ip(const std::string &ifn) { +inline std::string if2ip(int address_family, const std::string &ifn) { struct ifaddrs *ifap; getifaddrs(&ifap); + std::string addr_candidate; for (auto ifa = ifap; ifa; ifa = ifa->ifa_next) { - if (ifa->ifa_addr && ifn == ifa->ifa_name) { + if (ifa->ifa_addr && ifn == ifa->ifa_name && + (AF_UNSPEC == address_family || + ifa->ifa_addr->sa_family == address_family)) { if (ifa->ifa_addr->sa_family == AF_INET) { auto sa = reinterpret_cast(ifa->ifa_addr); char buf[INET_ADDRSTRLEN]; @@ -2544,26 +2670,41 @@ inline std::string if2ip(const std::string &ifn) { freeifaddrs(ifap); return std::string(buf, INET_ADDRSTRLEN); } + } else if (ifa->ifa_addr->sa_family == AF_INET6) { + auto sa = reinterpret_cast(ifa->ifa_addr); + if (!IN6_IS_ADDR_LINKLOCAL(&sa->sin6_addr)) { + char buf[INET6_ADDRSTRLEN] = {}; + if (inet_ntop(AF_INET6, &sa->sin6_addr, buf, INET6_ADDRSTRLEN)) { + // equivalent to mac's IN6_IS_ADDR_UNIQUE_LOCAL + auto s6_addr_head = sa->sin6_addr.s6_addr[0]; + if (s6_addr_head == 0xfc || s6_addr_head == 0xfd) { + addr_candidate = std::string(buf, INET6_ADDRSTRLEN); + } else { + freeifaddrs(ifap); + return std::string(buf, INET6_ADDRSTRLEN); + } + } + } } } } freeifaddrs(ifap); - return std::string(); + return addr_candidate; } #endif inline socket_t create_client_socket( - const char *host, int port, int address_family, bool tcp_nodelay, - SocketOptions socket_options, time_t connection_timeout_sec, - time_t connection_timeout_usec, time_t read_timeout_sec, - time_t read_timeout_usec, time_t write_timeout_sec, + const char *host, const char *ip, int port, int address_family, + bool tcp_nodelay, SocketOptions socket_options, + time_t connection_timeout_sec, time_t connection_timeout_usec, + time_t read_timeout_sec, time_t read_timeout_usec, time_t write_timeout_sec, time_t write_timeout_usec, const std::string &intf, Error &error) { auto sock = create_socket( - host, port, address_family, 0, tcp_nodelay, std::move(socket_options), + host, ip, port, address_family, 0, tcp_nodelay, std::move(socket_options), [&](socket_t sock2, struct addrinfo &ai) -> bool { if (!intf.empty()) { #ifdef USE_IF2IP - auto ip = if2ip(intf); + auto ip = if2ip(address_family, intf); if (ip.empty()) { ip = intf; } if (!bind_ip_address(sock2, ip.c_str())) { error = Error::BindIPAddress; @@ -2578,27 +2719,43 @@ inline socket_t create_client_socket( ::connect(sock2, ai.ai_addr, static_cast(ai.ai_addrlen)); if (ret < 0) { - if (is_connection_error() || - !wait_until_socket_is_ready(sock2, connection_timeout_sec, - connection_timeout_usec)) { + if (is_connection_error()) { error = Error::Connection; return false; } + error = wait_until_socket_is_ready(sock2, connection_timeout_sec, + connection_timeout_usec); + if (error != Error::Success) { return false; } } set_nonblocking(sock2, false); { +#ifdef _WIN32 + auto timeout = static_cast(read_timeout_sec * 1000 + + read_timeout_usec / 1000); + setsockopt(sock2, SOL_SOCKET, SO_RCVTIMEO, (char *)&timeout, + sizeof(timeout)); +#else timeval tv; tv.tv_sec = static_cast(read_timeout_sec); tv.tv_usec = static_cast(read_timeout_usec); setsockopt(sock2, SOL_SOCKET, SO_RCVTIMEO, (char *)&tv, sizeof(tv)); +#endif } { + +#ifdef _WIN32 + auto timeout = static_cast(write_timeout_sec * 1000 + + write_timeout_usec / 1000); + setsockopt(sock2, SOL_SOCKET, SO_SNDTIMEO, (char *)&timeout, + sizeof(timeout)); +#else timeval tv; tv.tv_sec = static_cast(write_timeout_sec); tv.tv_usec = static_cast(write_timeout_usec); setsockopt(sock2, SOL_SOCKET, SO_SNDTIMEO, (char *)&tv, sizeof(tv)); +#endif } error = Error::Success; @@ -2614,7 +2771,7 @@ inline socket_t create_client_socket( return sock; } -inline void get_remote_ip_and_port(const struct sockaddr_storage &addr, +inline bool get_remote_ip_and_port(const struct sockaddr_storage &addr, socklen_t addr_len, std::string &ip, int &port) { if (addr.ss_family == AF_INET) { @@ -2622,14 +2779,19 @@ inline void get_remote_ip_and_port(const struct sockaddr_storage &addr, } else if (addr.ss_family == AF_INET6) { port = ntohs(reinterpret_cast(&addr)->sin6_port); + } else { + return false; } std::array ipstr{}; - if (!getnameinfo(reinterpret_cast(&addr), addr_len, - ipstr.data(), static_cast(ipstr.size()), nullptr, - 0, NI_NUMERICHOST)) { - ip = ipstr.data(); + if (getnameinfo(reinterpret_cast(&addr), addr_len, + ipstr.data(), static_cast(ipstr.size()), nullptr, + 0, NI_NUMERICHOST)) { + return false; } + + ip = ipstr.data(); + return true; } inline void get_remote_ip_and_port(socket_t sock, std::string &ip, int &port) { @@ -2675,10 +2837,12 @@ find_content_type(const std::string &path, default: return nullptr; case "css"_t: return "text/css"; case "csv"_t: return "text/csv"; - case "txt"_t: return "text/plain"; - case "vtt"_t: return "text/vtt"; case "htm"_t: case "html"_t: return "text/html"; + case "js"_t: + case "mjs"_t: return "text/javascript"; + case "txt"_t: return "text/plain"; + case "vtt"_t: return "text/vtt"; case "apng"_t: return "image/apng"; case "avif"_t: return "image/avif"; @@ -2710,8 +2874,6 @@ find_content_type(const std::string &path, case "7z"_t: return "application/x-7z-compressed"; case "atom"_t: return "application/atom+xml"; case "pdf"_t: return "application/pdf"; - case "js"_t: - case "mjs"_t: return "application/javascript"; case "json"_t: return "application/json"; case "rss"_t: return "application/rss+xml"; case "tar"_t: return "application/x-tar"; @@ -2796,12 +2958,21 @@ inline const char *status_message(int status) { } inline bool can_compress_content_type(const std::string &content_type) { - return (!content_type.find("text/") && content_type != "text/event-stream") || - content_type == "image/svg+xml" || - content_type == "application/javascript" || - content_type == "application/json" || - content_type == "application/xml" || - content_type == "application/xhtml+xml"; + using udl::operator""_t; + + auto tag = str2tag(content_type); + + switch (tag) { + case "image/svg+xml"_t: + case "application/javascript"_t: + case "application/json"_t: + case "application/xml"_t: + case "application/protobuf"_t: + case "application/xhtml+xml"_t: return true; + + default: + return !content_type.rfind("text/", 0) && tag != "text/event-stream"_t; + } } inline EncodingType encoding_type(const Request &req, const Response &res) { @@ -2852,10 +3023,10 @@ inline bool gzip_compressor::compress(const char *data, size_t data_length, do { constexpr size_t max_avail_in = - std::numeric_limits::max(); + (std::numeric_limits::max)(); strm_.avail_in = static_cast( - std::min(data_length, max_avail_in)); + (std::min)(data_length, max_avail_in)); strm_.next_in = const_cast(reinterpret_cast(data)); data_length -= strm_.avail_in; @@ -2880,7 +3051,6 @@ inline bool gzip_compressor::compress(const char *data, size_t data_length, assert((flush == Z_FINISH && ret == Z_STREAM_END) || (flush == Z_NO_FLUSH && ret == Z_OK)); assert(strm_.avail_in == 0); - } while (data_length > 0); return true; @@ -2911,10 +3081,10 @@ inline bool gzip_decompressor::decompress(const char *data, size_t data_length, do { constexpr size_t max_avail_in = - std::numeric_limits::max(); + (std::numeric_limits::max)(); strm_.avail_in = static_cast( - std::min(data_length, max_avail_in)); + (std::min)(data_length, max_avail_in)); strm_.next_in = const_cast(reinterpret_cast(data)); data_length -= strm_.avail_in; @@ -2925,7 +3095,12 @@ inline bool gzip_decompressor::decompress(const char *data, size_t data_length, strm_.avail_out = static_cast(buff.size()); strm_.next_out = reinterpret_cast(buff.data()); + auto prev_avail_in = strm_.avail_in; + ret = inflate(&strm_, Z_NO_FLUSH); + + if (prev_avail_in - strm_.avail_in == 0) { return false; } + assert(ret != Z_STREAM_ERROR); switch (ret) { case Z_NEED_DICT: @@ -3084,15 +3259,26 @@ inline bool read_headers(Stream &strm, Headers &headers) { if (!line_reader.getline()) { return false; } // Check if the line ends with CRLF. + auto line_terminator_len = 2; if (line_reader.end_with_crlf()) { // Blank line indicates end of headers. if (line_reader.size() == 2) { break; } +#ifdef CPPHTTPLIB_ALLOW_LF_AS_LINE_TERMINATOR + } else { + // Blank line indicates end of headers. + if (line_reader.size() == 1) { break; } + line_terminator_len = 1; + } +#else } else { continue; // Skip invalid line. } +#endif - // Exclude CRLF - auto end = line_reader.ptr() + line_reader.size() - 2; + if (line_reader.size() > CPPHTTPLIB_HEADER_MAX_LENGTH) { return false; } + + // Exclude line terminator + auto end = line_reader.ptr() + line_reader.size() - line_terminator_len; parse_header(line_reader.ptr(), end, [&](std::string &&key, std::string &&val) { @@ -3208,8 +3394,7 @@ bool prepare_content_receiver(T &x, int &status, std::string encoding = x.get_header_value("Content-Encoding"); std::unique_ptr decompressor; - if (encoding.find("gzip") != std::string::npos || - encoding.find("deflate") != std::string::npos) { + if (encoding == "gzip" || encoding == "deflate") { #ifdef CPPHTTPLIB_ZLIB_SUPPORT decompressor = detail::make_unique(); #else @@ -3277,7 +3462,7 @@ bool read_content(Stream &strm, T &x, size_t payload_max_length, int &status, if (!ret) { status = exceed_payload_max_length ? 413 : 400; } return ret; }); -} +} // namespace detail inline ssize_t write_headers(Stream &strm, const Headers &headers) { ssize_t write_len = 0; @@ -3503,14 +3688,6 @@ inline std::string params_to_query_str(const Params ¶ms) { return query; } -inline std::string append_query_params(const char *path, const Params ¶ms) { - std::string path_with_query = path; - const static std::regex re("[^?]+\\?.*"); - auto delm = std::regex_match(path, re) ? '&' : '?'; - path_with_query += delm + params_to_query_str(params); - return path_with_query; -} - inline void parse_query_text(const std::string &s, Params ¶ms) { std::set cache; split(s.data(), s.data() + s.size(), '&', [&](const char *b, const char *e) { @@ -3546,7 +3723,11 @@ inline bool parse_multipart_boundary(const std::string &content_type, return !boundary.empty(); } +#ifdef CPPHTTPLIB_NO_EXCEPTIONS +inline bool parse_range_header(const std::string &s, Ranges &ranges) { +#else inline bool parse_range_header(const std::string &s, Ranges &ranges) try { +#endif static auto re_first_range = std::regex(R"(bytes=(\d*-\d*(?:,\s*\d*-\d*)*))"); std::smatch m; if (std::regex_match(s, m, re_first_range)) { @@ -3578,7 +3759,11 @@ inline bool parse_range_header(const std::string &s, Ranges &ranges) try { return all_valid_ranges; } return false; +#ifdef CPPHTTPLIB_NO_EXCEPTIONS +} +#else } catch (...) { return false; } +#endif class MultipartFormDataParser { public: @@ -3591,24 +3776,23 @@ public: bool parse(const char *buf, size_t n, const ContentReceiver &content_callback, const MultipartContentHeader &header_callback) { + // TODO: support 'filename*' static const std::regex re_content_disposition( - "^Content-Disposition:\\s*form-data;\\s*name=\"(.*?)\"(?:;\\s*filename=" - "\"(.*?)\")?\\s*$", + R"~(^Content-Disposition:\s*form-data;\s*name="(.*?)"(?:;\s*filename="(.*?)")?(?:;\s*filename\*=\S+)?\s*$)~", std::regex_constants::icase); + static const std::string dash_ = "--"; static const std::string crlf_ = "\r\n"; - buf_.append(buf, n); // TODO: performance improvement + buf_append(buf, n); - while (!buf_.empty()) { + while (buf_size() > 0) { switch (state_) { case 0: { // Initial boundary auto pattern = dash_ + boundary_ + crlf_; - if (pattern.size() > buf_.size()) { return true; } - auto pos = buf_.find(pattern); - if (pos != 0) { return false; } - buf_.erase(0, pattern.size()); - off_ += pattern.size(); + if (pattern.size() > buf_size()) { return true; } + if (!buf_start_with(pattern)) { return false; } + buf_erase(pattern.size()); state_ = 1; break; } @@ -3618,22 +3802,22 @@ public: break; } case 2: { // Headers - auto pos = buf_.find(crlf_); - while (pos != std::string::npos) { + auto pos = buf_find(crlf_); + if (pos > CPPHTTPLIB_HEADER_MAX_LENGTH) { return false; } + while (pos < buf_size()) { // Empty line if (pos == 0) { if (!header_callback(file_)) { is_valid_ = false; return false; } - buf_.erase(0, crlf_.size()); - off_ += crlf_.size(); + buf_erase(crlf_.size()); state_ = 3; break; } static const std::string header_name = "content-type:"; - const auto header = buf_.substr(0, pos); + const auto header = buf_head(pos); if (start_with_case_ignore(header, header_name)) { file_.content_type = trim_copy(header.substr(header_name.size())); } else { @@ -3644,9 +3828,8 @@ public: } } - buf_.erase(0, pos + crlf_.size()); - off_ += pos + crlf_.size(); - pos = buf_.find(crlf_); + buf_erase(pos + crlf_.size()); + pos = buf_find(crlf_); } if (state_ != 3) { return true; } break; @@ -3654,56 +3837,51 @@ public: case 3: { // Body { auto pattern = crlf_ + dash_; - if (pattern.size() > buf_.size()) { return true; } + if (pattern.size() > buf_size()) { return true; } - auto pos = find_string(buf_, pattern); + auto pos = buf_find(pattern); - if (!content_callback(buf_.data(), pos)) { + if (!content_callback(buf_data(), pos)) { is_valid_ = false; return false; } - off_ += pos; - buf_.erase(0, pos); + buf_erase(pos); } { auto pattern = crlf_ + dash_ + boundary_; - if (pattern.size() > buf_.size()) { return true; } + if (pattern.size() > buf_size()) { return true; } - auto pos = buf_.find(pattern); - if (pos != std::string::npos) { - if (!content_callback(buf_.data(), pos)) { + auto pos = buf_find(pattern); + if (pos < buf_size()) { + if (!content_callback(buf_data(), pos)) { is_valid_ = false; return false; } - off_ += pos + pattern.size(); - buf_.erase(0, pos + pattern.size()); + buf_erase(pos + pattern.size()); state_ = 4; } else { - if (!content_callback(buf_.data(), pattern.size())) { + if (!content_callback(buf_data(), pattern.size())) { is_valid_ = false; return false; } - off_ += pattern.size(); - buf_.erase(0, pattern.size()); + buf_erase(pattern.size()); } } break; } case 4: { // Boundary - if (crlf_.size() > buf_.size()) { return true; } - if (buf_.compare(0, crlf_.size(), crlf_) == 0) { - buf_.erase(0, crlf_.size()); - off_ += crlf_.size(); + if (crlf_.size() > buf_size()) { return true; } + if (buf_start_with(crlf_)) { + buf_erase(crlf_.size()); state_ = 1; } else { auto pattern = dash_ + crlf_; - if (pattern.size() > buf_.size()) { return true; } - if (buf_.compare(0, pattern.size(), pattern) == 0) { - buf_.erase(0, pattern.size()); - off_ += pattern.size(); + if (pattern.size() > buf_size()) { return true; } + if (buf_start_with(pattern)) { + buf_erase(pattern.size()); is_valid_ = true; state_ = 5; } else { @@ -3738,41 +3916,78 @@ private: return true; } - bool start_with(const std::string &a, size_t off, + std::string boundary_; + + size_t state_ = 0; + bool is_valid_ = false; + MultipartFormData file_; + + // Buffer + bool start_with(const std::string &a, size_t spos, size_t epos, const std::string &b) const { - if (a.size() - off < b.size()) { return false; } + if (epos - spos < b.size()) { return false; } for (size_t i = 0; i < b.size(); i++) { - if (a[i + off] != b[i]) { return false; } + if (a[i + spos] != b[i]) { return false; } } return true; } - size_t find_string(const std::string &s, const std::string &pattern) const { - auto c = pattern.front(); + size_t buf_size() const { return buf_epos_ - buf_spos_; } - size_t off = 0; - while (off < s.size()) { - auto pos = s.find(c, off); - if (pos == std::string::npos) { return s.size(); } + const char *buf_data() const { return &buf_[buf_spos_]; } - auto rem = s.size() - pos; - if (pattern.size() > rem) { return pos; } + std::string buf_head(size_t l) const { return buf_.substr(buf_spos_, l); } - if (start_with(s, pos, pattern)) { return pos; } + bool buf_start_with(const std::string &s) const { + return start_with(buf_, buf_spos_, buf_epos_, s); + } + + size_t buf_find(const std::string &s) const { + auto c = s.front(); + + size_t off = buf_spos_; + while (off < buf_epos_) { + auto pos = off; + while (true) { + if (pos == buf_epos_) { return buf_size(); } + if (buf_[pos] == c) { break; } + pos++; + } + + auto remaining_size = buf_epos_ - pos; + if (s.size() > remaining_size) { return buf_size(); } + + if (start_with(buf_, pos, buf_epos_, s)) { return pos - buf_spos_; } off = pos + 1; } - return s.size(); + return buf_size(); } - std::string boundary_; + void buf_append(const char *data, size_t n) { + auto remaining_size = buf_size(); + if (remaining_size > 0 && buf_spos_ > 0) { + for (size_t i = 0; i < remaining_size; i++) { + buf_[i] = buf_[buf_spos_ + i]; + } + } + buf_spos_ = 0; + buf_epos_ = remaining_size; + + if (remaining_size + n > buf_.size()) { buf_.resize(remaining_size + n); } + + for (size_t i = 0; i < n; i++) { + buf_[buf_epos_ + i] = data[i]; + } + buf_epos_ += n; + } + + void buf_erase(size_t size) { buf_spos_ += size; } std::string buf_; - size_t state_ = 0; - bool is_valid_ = false; - size_t off_ = 0; - MultipartFormData file_; + size_t buf_spos_ = 0; + size_t buf_epos_ = 0; }; inline std::string to_lower(const char *beg, const char *end) { @@ -3793,6 +4008,7 @@ inline std::string make_multipart_data_boundary() { // platforms, but due to lack of support in the c++ standard library, // doing better requires either some ugly hacks or breaking portability. std::random_device seed_gen; + // Request 128 bits of entropy for initialization std::seed_seq seed_sequence{seed_gen(), seed_gen(), seed_gen(), seed_gen()}; std::mt19937 engine(seed_sequence); @@ -3954,38 +4170,36 @@ inline bool has_crlf(const char *s) { } #ifdef CPPHTTPLIB_OPENSSL_SUPPORT -template -inline std::string message_digest(const std::string &s, Init init, - Update update, Final final, - size_t digest_length) { - using namespace std; +inline std::string message_digest(const std::string &s, const EVP_MD *algo) { + auto context = std::unique_ptr( + EVP_MD_CTX_new(), EVP_MD_CTX_free); - std::vector md(digest_length, 0); - CTX ctx; - init(&ctx); - update(&ctx, s.data(), s.size()); - final(md.data(), &ctx); + unsigned int hash_length = 0; + unsigned char hash[EVP_MAX_MD_SIZE]; - stringstream ss; - for (auto c : md) { - ss << setfill('0') << setw(2) << hex << (unsigned int)c; + EVP_DigestInit_ex(context.get(), algo, nullptr); + EVP_DigestUpdate(context.get(), s.c_str(), s.size()); + EVP_DigestFinal_ex(context.get(), hash, &hash_length); + + std::stringstream ss; + for (auto i = 0u; i < hash_length; ++i) { + ss << std::hex << std::setw(2) << std::setfill('0') + << (unsigned int)hash[i]; } + return ss.str(); } inline std::string MD5(const std::string &s) { - return message_digest(s, MD5_Init, MD5_Update, MD5_Final, - MD5_DIGEST_LENGTH); + return message_digest(s, EVP_md5()); } inline std::string SHA_256(const std::string &s) { - return message_digest(s, SHA256_Init, SHA256_Update, SHA256_Final, - SHA256_DIGEST_LENGTH); + return message_digest(s, EVP_sha256()); } inline std::string SHA_512(const std::string &s) { - return message_digest(s, SHA512_Init, SHA512_Update, SHA512_Final, - SHA512_DIGEST_LENGTH); + return message_digest(s, EVP_sha512()); } #endif @@ -4022,10 +4236,14 @@ class WSInit { public: WSInit() { WSADATA wsaData; - WSAStartup(0x0002, &wsaData); + if (WSAStartup(0x0002, &wsaData) == 0) is_valid_ = true; } - ~WSInit() { WSACleanup(); } + ~WSInit() { + if (is_valid_) WSACleanup(); + } + + bool is_valid_ = false; }; static WSInit wsinit_; @@ -4036,45 +4254,57 @@ inline std::pair make_digest_authentication_header( const Request &req, const std::map &auth, size_t cnonce_count, const std::string &cnonce, const std::string &username, const std::string &password, bool is_proxy = false) { - using namespace std; - - string nc; + std::string nc; { - stringstream ss; - ss << setfill('0') << setw(8) << hex << cnonce_count; + std::stringstream ss; + ss << std::setfill('0') << std::setw(8) << std::hex << cnonce_count; nc = ss.str(); } - auto qop = auth.at("qop"); - if (qop.find("auth-int") != std::string::npos) { - qop = "auth-int"; - } else { - qop = "auth"; + std::string qop; + if (auth.find("qop") != auth.end()) { + qop = auth.at("qop"); + if (qop.find("auth-int") != std::string::npos) { + qop = "auth-int"; + } else if (qop.find("auth") != std::string::npos) { + qop = "auth"; + } else { + qop.clear(); + } } std::string algo = "MD5"; if (auth.find("algorithm") != auth.end()) { algo = auth.at("algorithm"); } - string response; + std::string response; { - auto H = algo == "SHA-256" - ? detail::SHA_256 - : algo == "SHA-512" ? detail::SHA_512 : detail::MD5; + auto H = algo == "SHA-256" ? detail::SHA_256 + : algo == "SHA-512" ? detail::SHA_512 + : detail::MD5; auto A1 = username + ":" + auth.at("realm") + ":" + password; auto A2 = req.method + ":" + req.path; if (qop == "auth-int") { A2 += ":" + H(req.body); } - response = H(H(A1) + ":" + auth.at("nonce") + ":" + nc + ":" + cnonce + - ":" + qop + ":" + H(A2)); + if (qop.empty()) { + response = H(H(A1) + ":" + auth.at("nonce") + ":" + H(A2)); + } else { + response = H(H(A1) + ":" + auth.at("nonce") + ":" + nc + ":" + cnonce + + ":" + qop + ":" + H(A2)); + } } + auto opaque = (auth.find("opaque") != auth.end()) ? auth.at("opaque") : ""; + auto field = "Digest username=\"" + username + "\", realm=\"" + auth.at("realm") + "\", nonce=\"" + auth.at("nonce") + "\", uri=\"" + req.path + "\", algorithm=" + algo + - ", qop=" + qop + ", nc=\"" + nc + "\", cnonce=\"" + cnonce + - "\", response=\"" + response + "\""; + (qop.empty() ? ", response=\"" + : ", qop=" + qop + ", nc=" + nc + ", cnonce=\"" + + cnonce + "\", response=\"") + + response + "\"" + + (opaque.empty() ? "" : ", opaque=\"" + opaque + "\""); auto key = is_proxy ? "Proxy-Authorization" : "Authorization"; return std::make_pair(key, field); @@ -4144,6 +4374,49 @@ private: } // namespace detail +inline std::string hosted_at(const char *hostname) { + std::vector addrs; + hosted_at(hostname, addrs); + if (addrs.empty()) { return std::string(); } + return addrs[0]; +} + +inline void hosted_at(const char *hostname, std::vector &addrs) { + struct addrinfo hints; + struct addrinfo *result; + + memset(&hints, 0, sizeof(struct addrinfo)); + hints.ai_family = AF_UNSPEC; + hints.ai_socktype = SOCK_STREAM; + hints.ai_protocol = 0; + + if (getaddrinfo(hostname, nullptr, &hints, &result)) { +#if defined __linux__ && !defined __ANDROID__ + res_init(); +#endif + return; + } + + for (auto rp = result; rp; rp = rp->ai_next) { + const auto &addr = + *reinterpret_cast(rp->ai_addr); + std::string ip; + int dummy = -1; + if (detail::get_remote_ip_and_port(addr, sizeof(struct sockaddr_storage), + ip, dummy)) { + addrs.push_back(ip); + } + } +} + +inline std::string append_query_params(const char *path, const Params ¶ms) { + std::string path_with_query = path; + const static std::regex re("[^?]+\\?.*"); + auto delm = std::regex_match(path, re) ? '&' : '?'; + path_with_query += delm + detail::params_to_query_str(params); + return path_with_query; +} + // Header utilities inline std::pair make_range_header(Ranges ranges) { std::string field = "bytes="; @@ -4219,7 +4492,7 @@ inline size_t Request::get_param_value_count(const char *key) const { inline bool Request::is_multipart_form_data() const { const auto &content_type = get_header_value("Content-Type"); - return !content_type.find("multipart/form-data"); + return !content_type.rfind("multipart/form-data", 0); } inline bool Request::has_file(const char *key) const { @@ -4353,7 +4626,7 @@ inline SocketStream::SocketStream(socket_t sock, time_t read_timeout_sec, : sock_(sock), read_timeout_sec_(read_timeout_sec), read_timeout_usec_(read_timeout_usec), write_timeout_sec_(write_timeout_sec), - write_timeout_usec_(write_timeout_usec) {} + write_timeout_usec_(write_timeout_usec), read_buff_(read_buff_size_, 0) {} inline SocketStream::~SocketStream() {} @@ -4366,31 +4639,60 @@ inline bool SocketStream::is_writable() const { } inline ssize_t SocketStream::read(char *ptr, size_t size) { +#ifdef _WIN32 + size = + (std::min)(size, static_cast((std::numeric_limits::max)())); +#else + size = (std::min)(size, + static_cast((std::numeric_limits::max)())); +#endif + + if (read_buff_off_ < read_buff_content_size_) { + auto remaining_size = read_buff_content_size_ - read_buff_off_; + if (size <= remaining_size) { + memcpy(ptr, read_buff_.data() + read_buff_off_, size); + read_buff_off_ += size; + return static_cast(size); + } else { + memcpy(ptr, read_buff_.data() + read_buff_off_, remaining_size); + read_buff_off_ += remaining_size; + return static_cast(remaining_size); + } + } + if (!is_readable()) { return -1; } -#ifdef _WIN32 - if (size > static_cast((std::numeric_limits::max)())) { - return -1; + read_buff_off_ = 0; + read_buff_content_size_ = 0; + + if (size < read_buff_size_) { + auto n = read_socket(sock_, read_buff_.data(), read_buff_size_, + CPPHTTPLIB_RECV_FLAGS); + if (n <= 0) { + return n; + } else if (n <= static_cast(size)) { + memcpy(ptr, read_buff_.data(), static_cast(n)); + return n; + } else { + memcpy(ptr, read_buff_.data(), size); + read_buff_off_ = size; + read_buff_content_size_ = static_cast(n); + return static_cast(size); + } + } else { + return read_socket(sock_, ptr, size, CPPHTTPLIB_RECV_FLAGS); } - return recv(sock_, ptr, static_cast(size), CPPHTTPLIB_RECV_FLAGS); -#else - return handle_EINTR( - [&]() { return recv(sock_, ptr, size, CPPHTTPLIB_RECV_FLAGS); }); -#endif } inline ssize_t SocketStream::write(const char *ptr, size_t size) { if (!is_writable()) { return -1; } #ifdef _WIN32 - if (size > static_cast((std::numeric_limits::max)())) { - return -1; - } - return send(sock_, ptr, static_cast(size), CPPHTTPLIB_SEND_FLAGS); -#else - return handle_EINTR( - [&]() { return send(sock_, ptr, size, CPPHTTPLIB_SEND_FLAGS); }); + size = + (std::min)(size, static_cast((std::numeric_limits::max)())); #endif + + return send_socket(sock_, ptr, size, CPPHTTPLIB_SEND_FLAGS); } inline void SocketStream::get_remote_ip_and_port(std::string &ip, @@ -4692,6 +4994,14 @@ inline bool Server::parse_request_line(const char *s, Request &req) { if (req.version != "HTTP/1.1" && req.version != "HTTP/1.0") { return false; } { + // Skip URL fragment + for (size_t i = 0; i < req.target.size(); i++) { + if (req.target[i] == '#') { + req.target.erase(i); + break; + } + } + size_t count = 0; detail::split(req.target.data(), req.target.data() + req.target.size(), '?', @@ -4879,6 +5189,10 @@ inline bool Server::read_content(Stream &strm, Request &req, Response &res) { })) { const auto &content_type = req.get_header_value("Content-Type"); if (!content_type.find("application/x-www-form-urlencoded")) { + if (req.body.size() > CPPHTTPLIB_REQUEST_URI_MAX_LENGTH) { + res.status = 413; // NOTE: should be 414? + return false; + } detail::parse_query_text(req.body, req.params); } return true; @@ -4915,7 +5229,7 @@ inline bool Server::read_content_core(Stream &strm, Request &req, Response &res, /* For debug size_t pos = 0; while (pos < n) { - auto read_size = std::min(1, n - pos); + auto read_size = (std::min)(1, n - pos); auto ret = multipart_form_data_parser.parse( buf + pos, read_size, multipart_receiver, mulitpart_header); if (!ret) { return false; } @@ -4984,15 +5298,13 @@ inline socket_t Server::create_server_socket(const char *host, int port, int socket_flags, SocketOptions socket_options) const { return detail::create_socket( - host, port, address_family_, socket_flags, tcp_nodelay_, + host, "", port, address_family_, socket_flags, tcp_nodelay_, std::move(socket_options), [](socket_t sock, struct addrinfo &ai) -> bool { if (::bind(sock, ai.ai_addr, static_cast(ai.ai_addrlen))) { return false; } - if (::listen(sock, 5)) { // Listen through 5 channels - return false; - } + if (::listen(sock, CPPHTTPLIB_LISTEN_BACKLOG)) { return false; } return true; }); } @@ -5061,16 +5373,31 @@ inline bool Server::listen_internal() { } { +#ifdef _WIN32 + auto timeout = static_cast(read_timeout_sec_ * 1000 + + read_timeout_usec_ / 1000); + setsockopt(sock, SOL_SOCKET, SO_RCVTIMEO, (char *)&timeout, + sizeof(timeout)); +#else timeval tv; tv.tv_sec = static_cast(read_timeout_sec_); tv.tv_usec = static_cast(read_timeout_usec_); setsockopt(sock, SOL_SOCKET, SO_RCVTIMEO, (char *)&tv, sizeof(tv)); +#endif } { + +#ifdef _WIN32 + auto timeout = static_cast(write_timeout_sec_ * 1000 + + write_timeout_usec_ / 1000); + setsockopt(sock, SOL_SOCKET, SO_SNDTIMEO, (char *)&timeout, + sizeof(timeout)); +#else timeval tv; tv.tv_sec = static_cast(write_timeout_sec_); tv.tv_usec = static_cast(write_timeout_usec_); setsockopt(sock, SOL_SOCKET, SO_SNDTIMEO, (char *)&tv, sizeof(tv)); +#endif } #if __cplusplus > 201703L @@ -5394,6 +5721,9 @@ Server::process_request(Stream &strm, bool close_connection, // Rounting bool routed = false; +#ifdef CPPHTTPLIB_NO_EXCEPTIONS + routed = routing(req, res, strm); +#else try { routed = routing(req, res, strm); } catch (std::exception &e) { @@ -5408,6 +5738,7 @@ Server::process_request(Stream &strm, bool close_connection, res.status = 500; res.set_header("EXCEPTION_WHAT", "UNKNOWN"); } +#endif if (routed) { if (res.status == -1) { res.status = req.ranges.empty() ? 200 : 206; } @@ -5422,8 +5753,9 @@ inline bool Server::is_valid() const { return true; } inline bool Server::process_and_close_socket(socket_t sock) { auto ret = detail::process_server_socket( - sock, keep_alive_max_count_, keep_alive_timeout_sec_, read_timeout_sec_, - read_timeout_usec_, write_timeout_sec_, write_timeout_usec_, + svr_sock_, sock, keep_alive_max_count_, keep_alive_timeout_sec_, + read_timeout_sec_, read_timeout_usec_, write_timeout_sec_, + write_timeout_usec_, [this](Stream &strm, bool close_connection, bool &connection_closed) { return process_request(strm, close_connection, connection_closed, nullptr); @@ -5503,16 +5835,22 @@ inline void ClientImpl::copy_settings(const ClientImpl &rhs) { inline socket_t ClientImpl::create_client_socket(Error &error) const { if (!proxy_host_.empty() && proxy_port_ != -1) { return detail::create_client_socket( - proxy_host_.c_str(), proxy_port_, address_family_, tcp_nodelay_, + proxy_host_.c_str(), "", proxy_port_, address_family_, tcp_nodelay_, socket_options_, connection_timeout_sec_, connection_timeout_usec_, read_timeout_sec_, read_timeout_usec_, write_timeout_sec_, write_timeout_usec_, interface_, error); } + + // Check is custom IP specified for host_ + std::string ip; + auto it = addr_map_.find(host_); + if (it != addr_map_.end()) ip = it->second; + return detail::create_client_socket( - host_.c_str(), port_, address_family_, tcp_nodelay_, socket_options_, - connection_timeout_sec_, connection_timeout_usec_, read_timeout_sec_, - read_timeout_usec_, write_timeout_sec_, write_timeout_usec_, interface_, - error); + host_.c_str(), ip.c_str(), port_, address_family_, tcp_nodelay_, + socket_options_, connection_timeout_sec_, connection_timeout_usec_, + read_timeout_sec_, read_timeout_usec_, write_timeout_sec_, + write_timeout_usec_, interface_, error); } inline bool ClientImpl::create_and_connect_socket(Socket &socket, @@ -5557,13 +5895,17 @@ inline void ClientImpl::close_socket(Socket &socket) { inline bool ClientImpl::read_response_line(Stream &strm, const Request &req, Response &res) { - std::array buf; + std::array buf{}; detail::stream_line_reader line_reader(strm, buf.data(), buf.size()); if (!line_reader.getline()) { return false; } +#ifdef CPPHTTPLIB_ALLOW_LF_AS_LINE_TERMINATOR const static std::regex re("(HTTP/1\\.[01]) (\\d{3})(?: (.*?))?\r\n"); +#else + const static std::regex re("(HTTP/1\\.[01]) (\\d{3})(?: (.*?))?\r?\n"); +#endif std::cmatch m; if (!std::regex_match(line_reader.ptr(), m, re)) { @@ -5592,13 +5934,14 @@ inline bool ClientImpl::send(Request &req, Response &res, Error &error) { { std::lock_guard guard(socket_mutex_); + // Set this to false immediately - if it ever gets set to true by the end of // the request, we know another thread instructed us to close the socket. socket_should_be_closed_when_request_is_done_ = false; auto is_alive = false; if (socket_.is_open()) { - is_alive = detail::select_write(socket_.sock, 0, 0) > 0; + is_alive = detail::is_socket_alive(socket_.sock); if (!is_alive) { // Attempt to avoid sigpipe by shutting down nongracefully if it seems // like the other side has already closed the connection Also, there @@ -5854,9 +6197,12 @@ inline bool ClientImpl::write_request(Stream &strm, Request &req, if (!req.has_header("Accept")) { req.headers.emplace("Accept", "*/*"); } +#ifndef CPPHTTPLIB_NO_DEFAULT_USER_AGENT if (!req.has_header("User-Agent")) { - req.headers.emplace("User-Agent", "cpp-httplib/0.9"); + auto agent = std::string("cpp-httplib/") + CPPHTTPLIB_VERSION; + req.headers.emplace("User-Agent", agent); } +#endif if (req.body.empty()) { if (req.content_provider_) { @@ -5934,7 +6280,12 @@ inline bool ClientImpl::write_request(Stream &strm, Request &req, return write_content_with_provider(strm, req, error); } - return detail::write_data(strm, req.body.data(), req.body.size()); + if (!detail::write_data(strm, req.body.data(), req.body.size())) { + error = Error::Write; + return false; + } + + return true; } inline std::unique_ptr ClientImpl::send_with_content_provider( @@ -5944,11 +6295,6 @@ inline std::unique_ptr ClientImpl::send_with_content_provider( ContentProviderWithoutLength content_provider_without_length, const char *content_type, Error &error) { - // Request req; - // req.method = method; - // req.headers = headers; - // req.path = path; - if (content_type) { req.headers.emplace("Content-Type", content_type); } #ifdef CPPHTTPLIB_ZLIB_SUPPORT @@ -6237,7 +6583,7 @@ inline Result ClientImpl::Get(const char *path, const Params ¶ms, const Headers &headers, Progress progress) { if (params.empty()) { return Get(path, headers); } - std::string path_with_query = detail::append_query_params(path, params); + std::string path_with_query = append_query_params(path, params); return Get(path_with_query.c_str(), headers, progress); } @@ -6257,7 +6603,7 @@ inline Result ClientImpl::Get(const char *path, const Params ¶ms, return Get(path, headers, response_handler, content_receiver, progress); } - std::string path_with_query = detail::append_query_params(path, params); + std::string path_with_query = append_query_params(path, params); return Get(path_with_query.c_str(), headers, response_handler, content_receiver, progress); } @@ -6632,6 +6978,11 @@ inline void ClientImpl::set_follow_location(bool on) { follow_location_ = on; } inline void ClientImpl::set_url_encode(bool on) { url_encode_ = on; } +inline void +ClientImpl::set_hostname_addr_map(std::map addr_map) { + addr_map_ = std::move(addr_map); +} + inline void ClientImpl::set_default_headers(Headers headers) { default_headers_ = std::move(headers); } @@ -6771,14 +7122,13 @@ bool ssl_connect_or_accept_nonblocking(socket_t sock, SSL *ssl, } template -inline bool -process_server_socket_ssl(SSL *ssl, socket_t sock, size_t keep_alive_max_count, - time_t keep_alive_timeout_sec, - time_t read_timeout_sec, time_t read_timeout_usec, - time_t write_timeout_sec, time_t write_timeout_usec, - T callback) { +inline bool process_server_socket_ssl( + const std::atomic &svr_sock, SSL *ssl, socket_t sock, + size_t keep_alive_max_count, time_t keep_alive_timeout_sec, + time_t read_timeout_sec, time_t read_timeout_usec, time_t write_timeout_sec, + time_t write_timeout_usec, T callback) { return process_server_socket_core( - sock, keep_alive_max_count, keep_alive_timeout_sec, + svr_sock, sock, keep_alive_max_count, keep_alive_timeout_sec, [&](bool close_connection, bool &connection_closed) { SSLSocketStream strm(sock, ssl, read_timeout_sec, read_timeout_usec, write_timeout_sec, write_timeout_usec); @@ -6872,39 +7222,63 @@ inline bool SSLSocketStream::is_writable() const { } inline ssize_t SSLSocketStream::read(char *ptr, size_t size) { + size_t readbytes = 0; if (SSL_pending(ssl_) > 0) { - return SSL_read(ssl_, ptr, static_cast(size)); - } else if (is_readable()) { - auto ret = SSL_read(ssl_, ptr, static_cast(size)); - if (ret < 0) { - auto err = SSL_get_error(ssl_, ret); - int n = 1000; -#ifdef _WIN32 - while (--n >= 0 && - (err == SSL_ERROR_WANT_READ || - err == SSL_ERROR_SYSCALL && WSAGetLastError() == WSAETIMEDOUT)) { -#else - while (--n >= 0 && err == SSL_ERROR_WANT_READ) { -#endif - if (SSL_pending(ssl_) > 0) { - return SSL_read(ssl_, ptr, static_cast(size)); - } else if (is_readable()) { - std::this_thread::sleep_for(std::chrono::milliseconds(1)); - ret = SSL_read(ssl_, ptr, static_cast(size)); - if (ret >= 0) { return ret; } - err = SSL_get_error(ssl_, ret); - } else { - return -1; - } - } - } - return ret; + auto ret = SSL_read_ex(ssl_, ptr, size, &readbytes); + if (ret == 1) { return static_cast(readbytes); } + if (SSL_get_error(ssl_, ret) == SSL_ERROR_ZERO_RETURN) { return 0; } + return -1; } + if (!is_readable()) { return -1; } + + auto ret = SSL_read_ex(ssl_, ptr, size, &readbytes); + if (ret == 1) { return static_cast(readbytes); } + auto err = SSL_get_error(ssl_, ret); + int n = 1000; +#ifdef _WIN32 + while (--n >= 0 && + (err == SSL_ERROR_WANT_READ || + (err == SSL_ERROR_SYSCALL && WSAGetLastError() == WSAETIMEDOUT))) { +#else + while (--n >= 0 && err == SSL_ERROR_WANT_READ) { +#endif + if (SSL_pending(ssl_) > 0) { + ret = SSL_read_ex(ssl_, ptr, size, &readbytes); + if (ret == 1) { return static_cast(readbytes); } + if (SSL_get_error(ssl_, ret) == SSL_ERROR_ZERO_RETURN) { return 0; } + return -1; + } + if (!is_readable()) { return -1; } + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + ret = SSL_read_ex(ssl_, ptr, size, &readbytes); + if (ret == 1) { return static_cast(readbytes); } + err = SSL_get_error(ssl_, ret); + } + if (err == SSL_ERROR_ZERO_RETURN) { return 0; } return -1; } inline ssize_t SSLSocketStream::write(const char *ptr, size_t size) { - if (is_writable()) { return SSL_write(ssl_, ptr, static_cast(size)); } + if (!is_writable()) { return -1; } + size_t written = 0; + auto ret = SSL_write_ex(ssl_, ptr, size, &written); + if (ret == 1) { return static_cast(written); } + auto err = SSL_get_error(ssl_, ret); + int n = 1000; +#ifdef _WIN32 + while (--n >= 0 && + (err == SSL_ERROR_WANT_WRITE || + (err == SSL_ERROR_SYSCALL && WSAGetLastError() == WSAETIMEDOUT))) { +#else + while (--n >= 0 && err == SSL_ERROR_WANT_WRITE) { +#endif + if (!is_writable()) { return -1; } + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + ret = SSL_write_ex(ssl_, ptr, size, &written); + if (ret == 1) { return static_cast(written); } + err = SSL_get_error(ssl_, ret); + } + if (err == SSL_ERROR_ZERO_RETURN) { return 0; } return -1; } @@ -6922,18 +7296,22 @@ static SSLInit sslinit_; // SSL HTTP server implementation inline SSLServer::SSLServer(const char *cert_path, const char *private_key_path, const char *client_ca_cert_file_path, - const char *client_ca_cert_dir_path) { - ctx_ = SSL_CTX_new(TLS_method()); + const char *client_ca_cert_dir_path, + const char *private_key_password) { + ctx_ = SSL_CTX_new(TLS_server_method()); if (ctx_) { SSL_CTX_set_options(ctx_, - SSL_OP_ALL | SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3 | - SSL_OP_NO_COMPRESSION | + SSL_OP_NO_COMPRESSION | SSL_OP_NO_SESSION_RESUMPTION_ON_RENEGOTIATION); - // auto ecdh = EC_KEY_new_by_curve_name(NID_X9_62_prime256v1); - // SSL_CTX_set_tmp_ecdh(ctx_, ecdh); - // EC_KEY_free(ecdh); + SSL_CTX_set_min_proto_version(ctx_, TLS1_1_VERSION); + + // add default password callback before opening encrypted private key + if (private_key_password != nullptr && (private_key_password[0] != '\0')) { + SSL_CTX_set_default_passwd_cb_userdata(ctx_, + (char *)private_key_password); + } if (SSL_CTX_use_certificate_chain_file(ctx_, cert_path) != 1 || SSL_CTX_use_PrivateKey_file(ctx_, private_key_path, SSL_FILETYPE_PEM) != @@ -6941,46 +7319,46 @@ inline SSLServer::SSLServer(const char *cert_path, const char *private_key_path, SSL_CTX_free(ctx_); ctx_ = nullptr; } else if (client_ca_cert_file_path || client_ca_cert_dir_path) { - // if (client_ca_cert_file_path) { - // auto list = SSL_load_client_CA_file(client_ca_cert_file_path); - // SSL_CTX_set_client_CA_list(ctx_, list); - // } - SSL_CTX_load_verify_locations(ctx_, client_ca_cert_file_path, client_ca_cert_dir_path); SSL_CTX_set_verify( - ctx_, - SSL_VERIFY_PEER | - SSL_VERIFY_FAIL_IF_NO_PEER_CERT, // SSL_VERIFY_CLIENT_ONCE, - nullptr); + ctx_, SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT, nullptr); } } } inline SSLServer::SSLServer(X509 *cert, EVP_PKEY *private_key, X509_STORE *client_ca_cert_store) { - ctx_ = SSL_CTX_new(SSLv23_server_method()); + ctx_ = SSL_CTX_new(TLS_server_method()); if (ctx_) { SSL_CTX_set_options(ctx_, - SSL_OP_ALL | SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3 | - SSL_OP_NO_COMPRESSION | + SSL_OP_NO_COMPRESSION | SSL_OP_NO_SESSION_RESUMPTION_ON_RENEGOTIATION); + SSL_CTX_set_min_proto_version(ctx_, TLS1_1_VERSION); + if (SSL_CTX_use_certificate(ctx_, cert) != 1 || SSL_CTX_use_PrivateKey(ctx_, private_key) != 1) { SSL_CTX_free(ctx_); ctx_ = nullptr; } else if (client_ca_cert_store) { - SSL_CTX_set_cert_store(ctx_, client_ca_cert_store); SSL_CTX_set_verify( - ctx_, - SSL_VERIFY_PEER | - SSL_VERIFY_FAIL_IF_NO_PEER_CERT, // SSL_VERIFY_CLIENT_ONCE, - nullptr); + ctx_, SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT, nullptr); + } + } +} + +inline SSLServer::SSLServer( + const std::function &setup_ssl_ctx_callback) { + ctx_ = SSL_CTX_new(TLS_method()); + if (ctx_) { + if (!setup_ssl_ctx_callback(*ctx_)) { + SSL_CTX_free(ctx_); + ctx_ = nullptr; } } } @@ -6991,6 +7369,8 @@ inline SSLServer::~SSLServer() { inline bool SSLServer::is_valid() const { return ctx_; } +inline SSL_CTX *SSLServer::ssl_context() const { return ctx_; } + inline bool SSLServer::process_and_close_socket(socket_t sock) { auto ssl = detail::ssl_new( sock, ctx_, ctx_mutex_, @@ -7003,7 +7383,7 @@ inline bool SSLServer::process_and_close_socket(socket_t sock) { bool ret = false; if (ssl) { ret = detail::process_server_socket_ssl( - ssl, sock, keep_alive_max_count_, keep_alive_timeout_sec_, + svr_sock_, ssl, sock, keep_alive_max_count_, keep_alive_timeout_sec_, read_timeout_sec_, read_timeout_usec_, write_timeout_sec_, write_timeout_usec_, [this, ssl](Stream &strm, bool close_connection, @@ -7034,12 +7414,13 @@ inline SSLClient::SSLClient(const std::string &host, int port, const std::string &client_cert_path, const std::string &client_key_path) : ClientImpl(host, port, client_cert_path, client_key_path) { - ctx_ = SSL_CTX_new(SSLv23_client_method()); + ctx_ = SSL_CTX_new(TLS_client_method()); detail::split(&host_[0], &host_[host_.size()], '.', [&](const char *b, const char *e) { host_components_.emplace_back(std::string(b, e)); }); + if (!client_cert_path.empty() && !client_key_path.empty()) { if (SSL_CTX_use_certificate_file(ctx_, client_cert_path.c_str(), SSL_FILETYPE_PEM) != 1 || @@ -7054,12 +7435,13 @@ inline SSLClient::SSLClient(const std::string &host, int port, inline SSLClient::SSLClient(const std::string &host, int port, X509 *client_cert, EVP_PKEY *client_key) : ClientImpl(host, port) { - ctx_ = SSL_CTX_new(SSLv23_client_method()); + ctx_ = SSL_CTX_new(TLS_client_method()); detail::split(&host_[0], &host_[host_.size()], '.', [&](const char *b, const char *e) { host_components_.emplace_back(std::string(b, e)); }); + if (client_cert != nullptr && client_key != nullptr) { if (SSL_CTX_use_certificate(ctx_, client_cert) != 1 || SSL_CTX_use_PrivateKey(ctx_, client_key) != 1) { @@ -7419,8 +7801,10 @@ inline Client::Client(const std::string &scheme_host_port, #else if (!scheme.empty() && scheme != "http") { #endif +#ifndef CPPHTTPLIB_NO_EXCEPTIONS std::string msg = "'" + scheme + "' scheme is not supported."; throw std::invalid_argument(msg); +#endif return; } @@ -7721,6 +8105,11 @@ inline size_t Client::is_socket_open() const { return cli_->is_socket_open(); } inline void Client::stop() { cli_->stop(); } +inline void +Client::set_hostname_addr_map(std::map addr_map) { + cli_->set_hostname_addr_map(std::move(addr_map)); +} + inline void Client::set_default_headers(Headers headers) { cli_->set_default_headers(std::move(headers)); }