#pragma once /* This header implements bs_t<> class for scoped enum types (enum class). To enable bs_t<>, enum scope must contain `__bitset_enum_max` entry. enum class flagzz : u32 { flag1, // Bit indices start from zero flag2, __bitset_enum_max // It must be the last value }; This also enables helper operators for this enum type. Examples: `+flagzz::flag1` - unary `+` operator convert flagzz value to bs_t `flagzz::flag1 + flagzz::flag2` - bitset union `flagzz::flag1 - flagzz::flag2` - bitset difference Intersection (&) and symmetric difference (^) is also available. */ #include "util/types.hpp" #include "util/atomic.hpp" #include "Utilities/StrFmt.h" template concept BitSetEnum = std::is_enum_v && requires(T x) { T::__bitset_enum_max; }; template class atomic_bs_t; // Bitset type for enum class with available bits [0, T::__bitset_enum_max) template class bs_t final { public: // Underlying type using under = std::underlying_type_t; private: // Underlying value under m_data; friend class atomic_bs_t; // Value constructor constexpr explicit bs_t(int, under data) : m_data(data) { } public: static constexpr usz bitmax = sizeof(T) * 8; static constexpr usz bitsize = static_cast(T::__bitset_enum_max); static_assert(std::is_enum::value, "bs_t<> error: invalid type (must be enum)"); static_assert(bitsize <= bitmax, "bs_t<> error: invalid __bitset_enum_max"); static_assert(bitsize != bitmax || std::is_unsigned::value, "bs_t<> error: invalid __bitset_enum_max (sign bit)"); // Helper function static constexpr under shift(T value) { return static_cast(1) << static_cast(value); } bs_t() = default; // Construct from a single bit constexpr bs_t(T bit) : m_data(shift(bit)) { } // Test for empty bitset constexpr explicit operator bool() const { return m_data != 0; } // Extract underlying data constexpr explicit operator under() const { return m_data; } // Copy constexpr bs_t operator +() const { return *this; } constexpr bs_t& operator +=(bs_t rhs) { m_data |= static_cast(rhs); return *this; } constexpr bs_t& operator -=(bs_t rhs) { m_data &= ~static_cast(rhs); return *this; } constexpr bs_t& operator &=(bs_t rhs) { m_data &= static_cast(rhs); return *this; } constexpr bs_t& operator ^=(bs_t rhs) { m_data ^= static_cast(rhs); return *this; } friend constexpr bs_t operator +(bs_t lhs, bs_t rhs) { return bs_t(0, lhs.m_data | rhs.m_data); } friend constexpr bs_t operator -(bs_t lhs, bs_t rhs) { return bs_t(0, lhs.m_data & ~rhs.m_data); } friend constexpr bs_t operator &(bs_t lhs, bs_t rhs) { return bs_t(0, lhs.m_data & rhs.m_data); } friend constexpr bs_t operator ^(bs_t lhs, bs_t rhs) { return bs_t(0, lhs.m_data ^ rhs.m_data); } constexpr bool operator ==(bs_t rhs) const { return m_data == rhs.m_data; } constexpr bool test_and_set(T bit) { bool r = (m_data & shift(bit)) != 0; m_data |= shift(bit); return r; } constexpr bool test_and_reset(T bit) { bool r = (m_data & shift(bit)) != 0; m_data &= ~shift(bit); return r; } constexpr bool test_and_complement(T bit) { bool r = (m_data & shift(bit)) != 0; m_data ^= shift(bit); return r; } constexpr bool all_of(bs_t arg) { return (m_data & arg.m_data) == arg.m_data; } constexpr bool none_of(bs_t arg) { return (m_data & arg.m_data) == 0; } }; // Unary '+' operator: promote plain enum value to bitset value template constexpr bs_t operator +(T bit) { return bs_t(bit); } // Binary '+' operator: bitset union template requires (std::is_constructible_v, U>) constexpr bs_t operator +(T lhs, const U& rhs) { return bs_t(lhs) + bs_t(rhs); } // Binary '+' operator: bitset union template requires (std::is_constructible_v, U> && !std::is_enum_v) constexpr bs_t operator +(const U& lhs, T rhs) { return bs_t(lhs) + bs_t(rhs); } // Binary '-' operator: bitset difference template requires (std::is_constructible_v, U>) constexpr bs_t operator -(T lhs, const U& rhs) { return bs_t(lhs) - bs_t(rhs); } // Binary '-' operator: bitset difference template requires (std::is_constructible_v, U> && !std::is_enum_v) constexpr bs_t operator -(const U& lhs, T rhs) { return bs_t(lhs) - bs_t(rhs); } // Binary '&' operator: bitset intersection template requires (std::is_constructible_v, U>) constexpr bs_t operator &(T lhs, const U& rhs) { return bs_t(lhs) & bs_t(rhs); } // Binary '&' operator: bitset intersection template requires (std::is_constructible_v, U> && !std::is_enum_v) constexpr bs_t operator &(const U& lhs, T rhs) { return bs_t(lhs) & bs_t(rhs); } // Binary '^' operator: bitset symmetric difference template requires (std::is_constructible_v, U>) constexpr bs_t operator ^(T lhs, const U& rhs) { return bs_t(lhs) ^ bs_t(rhs); } // Binary '^' operator: bitset symmetric difference template requires (std::is_constructible_v, U> && !std::is_enum_v) constexpr bs_t operator ^(const U& lhs, T rhs) { return bs_t(lhs) ^ bs_t(rhs); } // Atomic bitset specialization with optimized operations template class atomic_bs_t : public atomic_t<::bs_t> { // Corresponding bitset type using bs_t = ::bs_t; // Base class using base = atomic_t<::bs_t>; // Use underlying m_data using base::m_data; public: // Underlying type using under = typename bs_t::under; atomic_bs_t() = default; atomic_bs_t(const atomic_bs_t&) = delete; atomic_bs_t& operator =(const atomic_bs_t&) = delete; explicit constexpr atomic_bs_t(bs_t value) : base(value) { } explicit constexpr atomic_bs_t(T bit) : base(bit) { } using base::operator bs_t; explicit operator bool() const { return static_cast(base::load()); } explicit operator under() const { return static_cast(base::load()); } bs_t operator +() const { return base::load(); } bs_t fetch_add(const bs_t& rhs) { return bs_t(0, atomic_storage::fetch_or(m_data.m_data, rhs.m_data)); } bs_t add_fetch(const bs_t& rhs) { return bs_t(0, atomic_storage::or_fetch(m_data.m_data, rhs.m_data)); } bs_t operator +=(const bs_t& rhs) { return add_fetch(rhs); } bs_t fetch_sub(const bs_t& rhs) { return bs_t(0, atomic_storage::fetch_and(m_data.m_data, ~rhs.m_data)); } bs_t sub_fetch(const bs_t& rhs) { return bs_t(0, atomic_storage::and_fetch(m_data.m_data, ~rhs.m_data)); } bs_t operator -=(const bs_t& rhs) { return sub_fetch(rhs); } bs_t fetch_and(const bs_t& rhs) { return bs_t(0, atomic_storage::fetch_and(m_data.m_data, rhs.m_data)); } bs_t and_fetch(const bs_t& rhs) { return bs_t(0, atomic_storage::and_fetch(m_data.m_data, rhs.m_data)); } bs_t operator &=(const bs_t& rhs) { return and_fetch(rhs); } bs_t fetch_xor(const bs_t& rhs) { return bs_t(0, atomic_storage::fetch_xor(m_data.m_data, rhs.m_data)); } bs_t xor_fetch(const bs_t& rhs) { return bs_t(0, atomic_storage::xor_fetch(m_data.m_data, rhs.m_data)); } bs_t operator ^=(const bs_t& rhs) { return xor_fetch(rhs); } auto fetch_or(const bs_t&) = delete; auto or_fetch(const bs_t&) = delete; auto operator |=(const bs_t&) = delete; bool test_and_set(T rhs) { return atomic_storage::bts(m_data.m_data, static_cast(static_cast(rhs))); } bool test_and_reset(T rhs) { return atomic_storage::btr(m_data.m_data, static_cast(static_cast(rhs))); } bool test_and_invert(T rhs) { return atomic_storage::btc(m_data.m_data, static_cast(static_cast(rhs))); } bool bit_test_set(uint bit) = delete; bool bit_test_reset(uint bit) = delete; bool bit_test_invert(uint bit) = delete; bool all_of(bs_t arg) { return base::load().all_of(arg); } bool none_of(bs_t arg) { return base::load().none_of(arg); } }; template struct fmt_unveil, void> { // Format as is using type = bs_t; static inline u64 get(const bs_t& bitset) { return static_cast>(bitset); } };