PPU/simd.hpp: minor changes in DP instructions

This commit is contained in:
Nekotekina 2022-01-21 13:42:06 +03:00
parent 0de9960772
commit a4d94a83b9
2 changed files with 35 additions and 19 deletions

View File

@ -2839,12 +2839,9 @@ auto VSUM4SBS()
// VSUM4SBS: per 32-bit lane, sum the four signed bytes of a and add them to b with signed saturation.
// d receives the result; sat accumulates saturation evidence when set_sat is among Flags.
static const auto exec = [](auto&& d, auto&& a, auto&& b, auto&& sat)
{
	// Saturating form: the unsigned multiplier bytes are all 1, so the dot product is a plain byte sum
	auto r = gv_dots_u8s8x4(gv_bcst8(1), a, b);
	if constexpr (((Flags == set_sat) || ...))
	{
		// The non-saturating sum differs from r exactly in the lanes where saturation happened
		sat = gv_or32(gv_xor32(gv_hadds8x4(std::move(a), std::move(b)), r), std::move(sat));
	}
	d = std::move(r);
};
@ -2859,12 +2856,9 @@ auto VSUM4SHS()
// VSUM4SHS: per 32-bit lane, sum the two signed halfwords of a and add them to b with signed saturation.
// d receives the result; sat accumulates saturation evidence when set_sat is among Flags.
static const auto exec = [](auto&& d, auto&& a, auto&& b, auto&& sat)
{
	// Saturating form: multiplying by a broadcast 1 turns the dot product into a plain halfword sum
	auto r = gv_dots_s16x2(a, gv_bcst16(1), b);
	if constexpr (((Flags == set_sat) || ...))
	{
		// The non-saturating sum differs from r exactly in the lanes where saturation happened
		sat = gv_or32(gv_xor32(gv_hadds16x2(std::move(a), std::move(b)), r), std::move(sat));
	}
	d = std::move(r);
};

View File

@ -1970,16 +1970,16 @@ inline v128 gv_hadds8x2(const v128& a)
#endif
}
// Sum each group of four signed bytes of a into a 32-bit lane, then add accumulator c (no saturation)
inline v128 gv_hadds8x4(const v128& a, const v128& c)
{
#if (defined(__AVX512VL__) && defined(__AVX512VNNI__)) || defined(__AVXVNNI__)
	// The unsigned operand is all ones, so the dot product reduces to a signed byte sum
	return _mm_dpbusd_epi32(c, _mm_set1_epi8(1), a);
#elif defined(__SSSE3__)
	// maddubs sums byte pairs into 16-bit lanes, madd then sums 16-bit pairs into 32-bit lanes
	return _mm_add_epi32(_mm_madd_epi16(_mm_maddubs_epi16(_mm_set1_epi8(1), a), _mm_set1_epi16(1)), c);
#elif defined(ARCH_X64)
	// SSE2 only: sign-extend the odd and even bytes by shifting, add them, then widen with madd
	return _mm_add_epi32(_mm_madd_epi16(_mm_add_epi16(_mm_srai_epi16(a, 8), _mm_srai_epi16(_mm_slli_epi16(a, 8), 8)), _mm_set1_epi16(1)), c);
#elif defined(ARCH_ARM64)
	return vaddq_s32(vpaddlq_s16(vpaddlq_s8(a)), c);
#endif
}
@ -2007,12 +2007,14 @@ inline v128 gv_haddu8x4(const v128& a)
#endif
}
// Sum each pair of signed 16-bit elements of a into a 32-bit lane, then add accumulator c (no saturation)
inline v128 gv_hadds16x2(const v128& a, const v128& c)
{
#if (defined(__AVX512VL__) && defined(__AVX512VNNI__)) || defined(__AVXVNNI__)
	// The multiplier must be 1 in every 16-bit lane; _mm_set1_epi8(1) would give 0x0101 (257) per lane
	return _mm_dpwssd_epi32(c, a, _mm_set1_epi16(1));
#elif defined(ARCH_X64)
	return _mm_add_epi32(_mm_madd_epi16(a, _mm_set1_epi16(1)), c);
#elif defined(ARCH_ARM64)
	return vaddq_s32(vpaddlq_s16(a), c);
#endif
}
@ -2099,6 +2101,26 @@ inline v128 gv_dotu16x2(const v128& a, const v128& b)
#endif
}
// Per 32-bit lane: multiply the four unsigned bytes of a by the matching signed bytes of b,
// sum the four products and add them to accumulator c with signed saturation
inline v128 gv_dots_u8s8x4(const v128& a, const v128& b, const v128& c)
{
#if (defined(__AVX512VL__) && defined(__AVX512VNNI__)) || defined(__AVXVNNI__)
	return _mm_dpbusds_epi32(c, a, b);
#elif defined(ARCH_X64)
	// Widen odd and even bytes into 16-bit lanes: zero-extend a, sign-extend b
	const __m128i a_odd = _mm_srli_epi16(a, 8);
	const __m128i a_even = _mm_and_si128(a, _mm_set1_epi16(0x00ff));
	const __m128i b_odd = _mm_srai_epi16(b, 8);
	const __m128i b_even = _mm_srai_epi16(_mm_slli_epi16(b, 8), 8);

	// Each madd produces the sum of two of the four byte products of a 32-bit lane
	const __m128i sum_odd = _mm_madd_epi16(a_odd, b_odd);
	const __m128i sum_even = _mm_madd_epi16(a_even, b_even);
	return gv_adds_s32(c, _mm_add_epi32(sum_odd, sum_even));
#elif defined(ARCH_ARM64)
	// Widen each half to 16 bits, multiply, then sum adjacent products into 32-bit lanes
	const auto prod_lo = vmulq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(a))), vmovl_s8(vget_low_s8(b)));
	const auto prod_hi = vmulq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(a))), vmovl_s8(vget_high_s8(b)));
	const auto pairs_lo = vpaddlq_s16(prod_lo);
	const auto pairs_hi = vpaddlq_s16(prod_hi);

	// De-interleave the pair sums so that adding them completes each group of four
	const auto evens = vuzp1q_s32(pairs_lo, pairs_hi);
	const auto odds = vuzp2q_s32(pairs_lo, pairs_hi);
	return vqaddq_s32(c, vaddq_s32(evens, odds));
#endif
}
// Signed s16 from a and b, 32-bit accumulator c; signed saturation
inline v128 gv_dots_s16x2(const v128& a, const v128& b, const v128& c)
{