Atomic fix

Fix possible pointer arithmetic ops.
Fix fat atomics (currently unused).
Nekotekina 2020-02-08 22:42:54 +03:00
parent efc8c3f4a9
commit 7ea4eb0095
1 changed file with 121 additions and 63 deletions
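
Before the diff, a hedged sketch of what the pointer-arithmetic fix enables (usage assumed from the diff, not taken from the commit): for pointer specializations the argument type becomes `ptr_rt` = `ullong`, so the arithmetic ops accept a plain element count instead of a second pointer, which previously could not compile inside `v += rhs`.

```cpp
// Assumed usage of atomic_t as declared in this header (illustrative only):
// with ptr_rt = ullong for pointer types, these calls now compile and perform
// element-wise pointer arithmetic inside the atomic operation.
void advance(atomic_t<int*>& cursor)
{
    int* prev = cursor.fetch_add(1); // returns the old pointer, steps 1 element
    cursor += 4;                     // steps 4 elements atomically
    cursor -= 2;                     // steps back 2 elements
    (void)prev;
}
```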


```diff
@@ -683,6 +683,8 @@ class atomic_t
 protected:
     using type = typename std::remove_cv<T>::type;
 
+    using ptr_rt = std::conditional_t<std::is_pointer_v<type>, ullong, type>;
+
     static_assert(alignof(type) == sizeof(type), "atomic_t<> error: unexpected alignment, use alignas() if necessary");
 
     type m_data;
@@ -827,7 +829,7 @@ public:
         return atomic_storage<type>::exchange(m_data, rhs);
     }
 
-    type fetch_add(const type& rhs)
+    auto fetch_add(const ptr_rt& rhs)
     {
         if constexpr(std::is_integral<type>::value)
         {
@@ -840,7 +842,7 @@ public:
         });
     }
 
-    type add_fetch(const type& rhs)
+    auto add_fetch(const ptr_rt& rhs)
     {
         if constexpr(std::is_integral<type>::value)
         {
@@ -854,7 +856,7 @@ public:
         });
     }
 
-    auto operator +=(const type& rhs)
+    auto operator +=(const ptr_rt& rhs)
     {
         if constexpr(std::is_integral<type>::value)
         {
@@ -867,7 +869,7 @@ public:
         });
     }
 
-    type fetch_sub(const type& rhs)
+    auto fetch_sub(const ptr_rt& rhs)
     {
         if constexpr(std::is_integral<type>::value)
         {
@@ -880,7 +882,7 @@ public:
         });
     }
 
-    type sub_fetch(const type& rhs)
+    auto sub_fetch(const ptr_rt& rhs)
     {
         if constexpr(std::is_integral<type>::value)
         {
@@ -894,7 +896,7 @@ public:
         });
     }
 
-    auto operator -=(const type& rhs)
+    auto operator -=(const ptr_rt& rhs)
     {
         if constexpr(std::is_integral<type>::value)
         {
@@ -907,7 +909,7 @@ public:
         });
     }
 
-    type fetch_and(const type& rhs)
+    auto fetch_and(const type& rhs)
     {
         if constexpr(std::is_integral<type>::value)
         {
@@ -920,7 +922,7 @@ public:
         });
     }
 
-    type and_fetch(const type& rhs)
+    auto and_fetch(const type& rhs)
     {
         if constexpr(std::is_integral<type>::value)
         {
@@ -947,7 +949,7 @@ public:
         });
     }
 
-    type fetch_or(const type& rhs)
+    auto fetch_or(const type& rhs)
     {
         if constexpr(std::is_integral<type>::value)
         {
@@ -960,7 +962,7 @@ public:
         });
     }
 
-    type or_fetch(const type& rhs)
+    auto or_fetch(const type& rhs)
     {
         if constexpr(std::is_integral<type>::value)
         {
@@ -987,7 +989,7 @@ public:
         });
     }
 
-    type fetch_xor(const type& rhs)
+    auto fetch_xor(const type& rhs)
     {
         if constexpr(std::is_integral<type>::value)
         {
@@ -1000,7 +1002,7 @@ public:
         });
     }
 
-    type xor_fetch(const type& rhs)
+    auto xor_fetch(const type& rhs)
     {
         if constexpr(std::is_integral<type>::value)
         {
@@ -1158,11 +1160,18 @@ class atomic_with_lock_bit
     // Simply internal type
     using type = std::conditional_t<std::is_pointer_v<T>, std::uintptr_t, T>;
 
+    // Used for pointer arithmetics
+    using ptr_rt = std::conditional_t<std::is_pointer_v<T>, ullong, T>;
+
+    static constexpr auto c_lock_bit = BitWidth + 1;
+    static constexpr auto c_dirty = type{1} << BitWidth;
+
     // Check space for lock bit
-    static_assert(BitWidth < sizeof(T) * 8, "No space for lock bit");
-    static_assert(sizeof(T) <= 8, "Not supported");
+    static_assert(BitWidth <= sizeof(T) * 8 - 2, "No space for lock bit");
+    static_assert(sizeof(T) <= 8 || (!std::is_pointer_v<T> && !std::is_integral_v<T>), "Not supported");
+    static_assert(!std::is_same_v<std::decay_t<T>, bool>, "Bool not supported, use integral with size 1.");
     static_assert(std::is_pointer_v<T> == (BitWidth == 0), "BitWidth should be 0 for pointers");
-    static_assert(!std::is_pointer_v<T> || (alignof(std::remove_pointer_t<T>) > 1), "Pointer type should have align 2 or more");
+    static_assert(!std::is_pointer_v<T> || (alignof(std::remove_pointer_t<T>) >= 4), "Pointer type should have align 4 or more");
 
     // Use the most significant bit as a mutex
     atomic_t<type> m_data;
```
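
Reading the new constants together (an editorial gloss, not part of the commit): the low `BitWidth` bits hold the value, bit `BitWidth` is a "dirty" flag marking that waiters may exist, and bit `BitWidth + 1` is the lock bit. Two bits are now reserved, hence `BitWidth <= sizeof(T) * 8 - 2`. For pointers (`BitWidth == 0`) the dirty and lock flags occupy bits 0 and 1, which is why pointees must now be aligned to at least 4.

```cpp
// Editorial illustration (not commit code): the bit layout of fat_atomic_u6,
// i.e. atomic_with_lock_bit<u8, 6>. Constant names mirror the diff; the masks
// are derived from it.
#include <cstdint>

constexpr unsigned BitWidth = 6;
constexpr std::uint8_t value_mask = (std::uint8_t{1} << BitWidth) - 1; // 0b0011'1111
constexpr std::uint8_t c_dirty    = std::uint8_t{1} << BitWidth;      // 0b0100'0000 (waiters present)
constexpr std::uint8_t lock_mask  = std::uint8_t{2} << BitWidth;      // 0b1000'0000 (bit c_lock_bit)

// The three fields partition the byte without overlap
static_assert((value_mask | c_dirty | lock_mask) == 0xff);
static_assert((value_mask & (c_dirty | lock_mask)) == 0);
```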
```diff
@@ -1172,33 +1181,29 @@ public:
     static bool is_locked(type old_val)
     {
-        if constexpr (std::is_signed_v<type> && BitWidth == sizeof(T) * 8 - 1)
+        if constexpr (std::is_signed_v<type> && BitWidth == sizeof(T) * 8 - 2)
         {
             return old_val < 0;
         }
         else if constexpr (std::is_pointer_v<T>)
         {
-            return (old_val & 1) != 0;
+            return (old_val & 2) != 0;
         }
         else
         {
-            return (old_val & (type{1} << BitWidth)) != 0;
+            return (old_val & (type{2} << BitWidth)) != 0;
         }
     }
 
     static type clamp_value(type old_val)
     {
-        if constexpr (std::is_signed_v<type>)
+        if constexpr (std::is_pointer_v<T>)
         {
-            return static_cast<type>(static_cast<std::make_unsigned_t<type>>(old_val) << (sizeof(T) * 8 - BitWidth)) >> (sizeof(T) * 8 - BitWidth);
-        }
-        else if constexpr (std::is_pointer_v<T>)
-        {
-            return old_val & 0xffff'ffff'ffff'fffeull;
+            return old_val & (~type{0} << 2);
         }
         else
         {
-            return old_val & static_cast<type>(0xffff'ffff'ffff'ffffull >> (64 - BitWidth));
+            return old_val & ((type{1} << BitWidth) - type{1});
         }
     }
@@ -1226,18 +1231,33 @@ public:
     void raw_release(type value)
     {
         m_data.release(clamp_value(value));
-        m_data.notify_all();
+
+        // TODO: test dirty bit for notification
+        if (true)
+        {
+            m_data.notify_all();
+        }
     }
 
     void lock()
     {
-        while (m_data.bts(BitWidth)) [[unlikely]]
+        while (m_data.bts(c_lock_bit)) [[unlikely]]
         {
             type old_val = m_data.load();
 
-            if (is_locked(old_val))
+            if (is_locked(old_val)) [[likely]]
             {
-                m_data.wait(old_val);
+                if ((old_val & c_dirty) == 0)
+                {
+                    // Try to set dirty bit if not set already
+                    if (!m_data.compare_and_swap_test(old_val, old_val | c_dirty))
+                    {
+                        // Situation changed
+                        continue;
+                    }
+                }
+
+                m_data.wait(old_val | c_dirty);
                 old_val = m_data.load();
             }
         }
@@ -1245,13 +1265,27 @@ public:
     bool try_lock()
     {
-        return !m_data.bts(BitWidth);
+        return !m_data.bts(c_lock_bit);
     }
 
     void unlock()
     {
-        m_data.btr(BitWidth);
-        m_data.notify_all();
+        type old_val = m_data.load();
+
+        if constexpr (std::is_pointer_v<T>)
+        {
+            m_data.and_fetch(~type{0} << 2);
+        }
+        else
+        {
+            m_data.and_fetch((type{1} << BitWidth) - type{1});
+        }
+
+        // Test dirty bit for notification
+        if (old_val & c_dirty)
+        {
+            m_data.notify_all();
+        }
     }
 
     T load()
     {
```
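
The `lock()`/`unlock()` pair above is the heart of the change: a contended waiter first publishes itself by CAS-ing the dirty bit in, then sleeps; `unlock()` clears the lock and dirty bits and calls `notify_all()` only when the dirty bit was set, so uncontended unlocks skip the wake-up entirely. A standalone sketch of the same handshake in portable C++20 (an assumed-equivalent pattern, not the commit's code; flag names are illustrative):

```cpp
#include <atomic>

std::atomic<unsigned> state{0};
constexpr unsigned locked_bit = 1; // plays the role of bts(c_lock_bit)
constexpr unsigned dirty_bit  = 2; // plays the role of c_dirty

void lock()
{
    // Try to set the lock bit; loop while it was already set
    while (state.fetch_or(locked_bit) & locked_bit)
    {
        unsigned old = state.load();

        if (old & locked_bit)
        {
            // Publish a waiter before sleeping, as in the diff
            if (!(old & dirty_bit) && !state.compare_exchange_strong(old, old | dirty_bit))
            {
                continue; // situation changed, retry
            }

            state.wait(old | dirty_bit); // returns once the value differs
        }
    }
}

void unlock()
{
    // Clear lock + dirty in one step (no value bits in this sketch);
    // notify only if a waiter announced itself
    if (state.exchange(0) & dirty_bit)
    {
        state.notify_all();
    }
}
```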
```diff
@@ -1260,7 +1294,15 @@ public:
         type old_val = m_data.load();
 
         while (is_locked(old_val)) [[unlikely]]
         {
-            m_data.wait(old_val);
+            if ((old_val & c_dirty) == 0)
+            {
+                if (!m_data.compare_and_swap_test(old_val, old_val | c_dirty))
+                {
+                    continue;
+                }
+            }
+
+            m_data.wait(old_val | c_dirty);
             old_val = m_data.load();
         }
@@ -1273,23 +1315,39 @@ public:
         type old_val = m_data.load();
 
         while (is_locked(old_val) || !m_data.compare_and_swap_test(old_val, clamp_value(reinterpret_cast<type>(value)))) [[unlikely]]
         {
+            if ((old_val & c_dirty) == 0)
+            {
+                if (!m_data.compare_and_swap_test(old_val, old_val | c_dirty))
+                {
+                    continue;
+                }
+            }
+
             m_data.wait(old_val);
             old_val = m_data.load();
         }
     }
 
-    template <typename F, typename RT = std::invoke_result_t<F, type&>>
+    template <typename F, typename RT = std::invoke_result_t<F, T&>>
     RT atomic_op(F func)
     {
         type _new, old;
-        old.m_data = m_data.load();
+        old = m_data.load();
 
         while (true)
         {
-            if (is_locked(old.m_data)) [[unlikely]]
+            if (is_locked(old)) [[unlikely]]
             {
-                m_data.wait(old.m_data);
-                old.m_data = m_data.load();
+                if ((old & c_dirty) == 0)
+                {
+                    if (!m_data.compare_and_swap_test(old, old | c_dirty))
+                    {
+                        continue;
+                    }
+                }
+
+                m_data.wait(old);
+                old = m_data.load();
                 continue;
             }
@@ -1299,7 +1357,7 @@ public:
             {
                 std::invoke(func, reinterpret_cast<T&>(_new));
 
-                if (atomic_storage<type>::compare_exchange(m_data, old.m_data, clamp_value(_new.m_data))) [[likely]]
+                if (atomic_storage<type>::compare_exchange(m_data.raw(), old, clamp_value(_new))) [[likely]]
                 {
                     return;
                 }
@@ -1308,7 +1366,7 @@ public:
             {
                 RT result = std::invoke(func, reinterpret_cast<T&>(_new));
 
-                if (atomic_storage<type>::compare_exchange(m_data, old.m_data, clamp_value(_new.m_data))) [[likely]]
+                if (atomic_storage<type>::compare_exchange(m_data.raw(), old, clamp_value(_new))) [[likely]]
                 {
                     return result;
                 }
@@ -1316,15 +1374,15 @@ public:
         }
     }
 
-    type fetch_add(const type& rhs)
+    auto fetch_add(const ptr_rt& rhs)
     {
         return atomic_op([&](T& v)
         {
-            return std::exchange(v, v += rhs);
+            return std::exchange(v, (v += rhs));
         });
     }
 
-    auto operator +=(const type& rhs)
+    auto operator +=(const ptr_rt& rhs)
     {
         return atomic_op([&](T& v)
         {
@@ -1332,15 +1390,15 @@ public:
         });
     }
 
-    type fetch_sub(const type& rhs)
+    auto fetch_sub(const ptr_rt& rhs)
     {
         return atomic_op([&](T& v)
         {
-            return std::exchange(v, v -= rhs);
+            return std::exchange(v, (v -= rhs));
         });
     }
 
-    auto operator -=(const type& rhs)
+    auto operator -=(const ptr_rt& rhs)
     {
         return atomic_op([&](T& v)
         {
@@ -1348,15 +1406,15 @@ public:
         });
     }
 
-    type fetch_and(const type& rhs)
+    auto fetch_and(const T& rhs)
     {
         return atomic_op([&](T& v)
         {
-            return std::exchange(v, v &= rhs);
+            return std::exchange(v, (v &= rhs));
         });
     }
 
-    auto operator &=(const type& rhs)
+    auto operator &=(const T& rhs)
     {
         return atomic_op([&](T& v)
         {
@@ -1364,15 +1422,15 @@ public:
         });
     }
 
-    type fetch_or(const type& rhs)
+    auto fetch_or(const T& rhs)
     {
         return atomic_op([&](T& v)
         {
-            return std::exchange(v, v |= rhs);
+            return std::exchange(v, (v |= rhs));
        });
     }
 
-    auto operator |=(const type& rhs)
+    auto operator |=(const T& rhs)
     {
         return atomic_op([&](T& v)
         {
@@ -1380,15 +1438,15 @@ public:
         });
     }
 
-    type fetch_xor(const type& rhs)
+    auto fetch_xor(const T& rhs)
     {
         return atomic_op([&](T& v)
         {
-            return std::exchange(v, v ^= rhs);
+            return std::exchange(v, (v ^= rhs));
         });
     }
 
-    auto operator ^=(const type& rhs)
+    auto operator ^=(const T& rhs)
     {
         return atomic_op([&](T& v)
         {
@@ -1430,22 +1488,22 @@ public:
 };
 
 using fat_atomic_u1 = atomic_with_lock_bit<u8, 1>;
-using fat_atomic_u7 = atomic_with_lock_bit<u8, 7>;
-using fat_atomic_s7 = atomic_with_lock_bit<s8, 7>;
+using fat_atomic_u6 = atomic_with_lock_bit<u8, 6>;
+using fat_atomic_s6 = atomic_with_lock_bit<s8, 6>;
 using fat_atomic_u8 = atomic_with_lock_bit<u16, 8>;
 using fat_atomic_s8 = atomic_with_lock_bit<s16, 8>;
-using fat_atomic_u15 = atomic_with_lock_bit<u16, 15>;
-using fat_atomic_s15 = atomic_with_lock_bit<s16, 15>;
+using fat_atomic_u14 = atomic_with_lock_bit<u16, 14>;
+using fat_atomic_s14 = atomic_with_lock_bit<s16, 14>;
 using fat_atomic_u16 = atomic_with_lock_bit<u32, 16>;
 using fat_atomic_s16 = atomic_with_lock_bit<s32, 16>;
-using fat_atomic_u31 = atomic_with_lock_bit<u32, 31>;
-using fat_atomic_s31 = atomic_with_lock_bit<s32, 31>;
+using fat_atomic_u30 = atomic_with_lock_bit<u32, 30>;
+using fat_atomic_s30 = atomic_with_lock_bit<s32, 30>;
 using fat_atomic_u32 = atomic_with_lock_bit<u64, 32>;
 using fat_atomic_s32 = atomic_with_lock_bit<s64, 32>;
-using fat_atomic_u63 = atomic_with_lock_bit<u64, 63>;
-using fat_atomic_s63 = atomic_with_lock_bit<s64, 63>;
+using fat_atomic_u62 = atomic_with_lock_bit<u64, 62>;
+using fat_atomic_s62 = atomic_with_lock_bit<s64, 62>;
 
 template <typename Ptr>
 using fat_atomic_ptr = atomic_with_lock_bit<Ptr*, 0>;
```
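
The renamed aliases reflect the second reserved flag bit: each type now gives up two bits (dirty + lock) instead of one, so e.g. `fat_atomic_u6` replaces `fat_atomic_u7`. A hedged usage sketch (construction and the exact public member set are assumptions, based only on what the diff shows):

```cpp
// Assumed usage (not from the commit): a 6-bit counter plus its embedded
// lock packed into a single byte.
fat_atomic_u6 counter{};

void bump_and_inspect()
{
    counter += 1;     // arithmetic goes through the atomic_op CAS loop
    counter.lock();   // bts on c_lock_bit; contended waiters use the dirty bit
    // ... exclusive section guarding some non-atomic state ...
    counter.unlock(); // notify_all() fires only if a waiter set c_dirty
}
```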