mirror of https://github.com/xemu-project/xemu.git
target/arm: Prepare bfdotadd() callers for FEAT_EBF support
We use bfdotadd() in four callsites for various helper functions. Currently this all assumes that we have the FPCR.EBF=0 semantics. For FPCR.EBF=1 we will need to: * call a different routine to bfdotadd() because we need to do a fused multiply-add rather than separate multiply and add steps * use a different float_status that honours the FPCR rounding mode and denormal-flushing fields * pass in an extra float_status that has been set up to perform round-to-odd rounding To prepare for this, refactor all the callsites so that instead of for (...) { x = bfdotadd(...); } they are: float_status fpst, fpst_odd; if (is_ebf(env, &fpst, &fpst_odd)) { for (...) { x = bfdotadd_ebf(..., &fpst, &fpst_odd); } } else { for (...) { x = bfdotadd(..., &fpst); } } For the moment the is_ebf() function always returns false, sets up fpst for EBF=0 semantics and never sets up fpst_odd; bfdotadd_ebf() will assert if called. We'll fill in the handling for EBF=1 in the next commit. This change should be a zero-behaviour-change refactor. Signed-off-by: Peter Maydell <peter.maydell@linaro.org> Reviewed-by: Richard Henderson <richard.henderson@linaro.org>
This commit is contained in:
parent
2da2d7dc90
commit
09b0d9e0ad
|
@ -1085,32 +1085,62 @@ void HELPER(sme_bfmopa)(void *vza, void *vzn, void *vzm,
|
||||||
intptr_t row, col, oprsz = simd_maxsz(desc);
|
intptr_t row, col, oprsz = simd_maxsz(desc);
|
||||||
uint32_t neg = simd_data(desc) * 0x80008000u;
|
uint32_t neg = simd_data(desc) * 0x80008000u;
|
||||||
uint16_t *pn = vpn, *pm = vpm;
|
uint16_t *pn = vpn, *pm = vpm;
|
||||||
|
float_status fpst, fpst_odd;
|
||||||
|
|
||||||
for (row = 0; row < oprsz; ) {
|
if (is_ebf(env, &fpst, &fpst_odd)) {
|
||||||
uint16_t prow = pn[H2(row >> 4)];
|
for (row = 0; row < oprsz; ) {
|
||||||
do {
|
uint16_t prow = pn[H2(row >> 4)];
|
||||||
void *vza_row = vza + tile_vslice_offset(row);
|
do {
|
||||||
uint32_t n = *(uint32_t *)(vzn + H1_4(row));
|
void *vza_row = vza + tile_vslice_offset(row);
|
||||||
|
uint32_t n = *(uint32_t *)(vzn + H1_4(row));
|
||||||
|
|
||||||
n = f16mop_adj_pair(n, prow, neg);
|
n = f16mop_adj_pair(n, prow, neg);
|
||||||
|
|
||||||
for (col = 0; col < oprsz; ) {
|
for (col = 0; col < oprsz; ) {
|
||||||
uint16_t pcol = pm[H2(col >> 4)];
|
uint16_t pcol = pm[H2(col >> 4)];
|
||||||
do {
|
do {
|
||||||
if (prow & pcol & 0b0101) {
|
if (prow & pcol & 0b0101) {
|
||||||
uint32_t *a = vza_row + H1_4(col);
|
uint32_t *a = vza_row + H1_4(col);
|
||||||
uint32_t m = *(uint32_t *)(vzm + H1_4(col));
|
uint32_t m = *(uint32_t *)(vzm + H1_4(col));
|
||||||
|
|
||||||
m = f16mop_adj_pair(m, pcol, 0);
|
m = f16mop_adj_pair(m, pcol, 0);
|
||||||
*a = bfdotadd(*a, n, m);
|
*a = bfdotadd_ebf(*a, n, m, &fpst, &fpst_odd);
|
||||||
}
|
}
|
||||||
col += 4;
|
col += 4;
|
||||||
pcol >>= 4;
|
pcol >>= 4;
|
||||||
} while (col & 15);
|
} while (col & 15);
|
||||||
}
|
}
|
||||||
row += 4;
|
row += 4;
|
||||||
prow >>= 4;
|
prow >>= 4;
|
||||||
} while (row & 15);
|
} while (row & 15);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (row = 0; row < oprsz; ) {
|
||||||
|
uint16_t prow = pn[H2(row >> 4)];
|
||||||
|
do {
|
||||||
|
void *vza_row = vza + tile_vslice_offset(row);
|
||||||
|
uint32_t n = *(uint32_t *)(vzn + H1_4(row));
|
||||||
|
|
||||||
|
n = f16mop_adj_pair(n, prow, neg);
|
||||||
|
|
||||||
|
for (col = 0; col < oprsz; ) {
|
||||||
|
uint16_t pcol = pm[H2(col >> 4)];
|
||||||
|
do {
|
||||||
|
if (prow & pcol & 0b0101) {
|
||||||
|
uint32_t *a = vza_row + H1_4(col);
|
||||||
|
uint32_t m = *(uint32_t *)(vzm + H1_4(col));
|
||||||
|
|
||||||
|
m = f16mop_adj_pair(m, pcol, 0);
|
||||||
|
*a = bfdotadd(*a, n, m, &fpst);
|
||||||
|
}
|
||||||
|
col += 4;
|
||||||
|
pcol >>= 4;
|
||||||
|
} while (col & 15);
|
||||||
|
}
|
||||||
|
row += 4;
|
||||||
|
prow >>= 4;
|
||||||
|
} while (row & 15);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -2790,39 +2790,58 @@ DO_MMLA_B(gvec_usmmla_b, do_usmmla_b)
|
||||||
* BFloat16 Dot Product
|
* BFloat16 Dot Product
|
||||||
*/
|
*/
|
||||||
|
|
||||||
float32 bfdotadd(float32 sum, uint32_t e1, uint32_t e2)
|
bool is_ebf(CPUARMState *env, float_status *statusp, float_status *oddstatusp)
|
||||||
{
|
{
|
||||||
/* FPCR is ignored for BFDOT and BFMMLA. */
|
/* FPCR is ignored for BFDOT and BFMMLA. */
|
||||||
float_status bf_status = {
|
*statusp = (float_status){
|
||||||
.tininess_before_rounding = float_tininess_before_rounding,
|
.tininess_before_rounding = float_tininess_before_rounding,
|
||||||
.float_rounding_mode = float_round_to_odd_inf,
|
.float_rounding_mode = float_round_to_odd_inf,
|
||||||
.flush_to_zero = true,
|
.flush_to_zero = true,
|
||||||
.flush_inputs_to_zero = true,
|
.flush_inputs_to_zero = true,
|
||||||
.default_nan_mode = true,
|
.default_nan_mode = true,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
float32 bfdotadd(float32 sum, uint32_t e1, uint32_t e2, float_status *fpst)
|
||||||
|
{
|
||||||
float32 t1, t2;
|
float32 t1, t2;
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* Extract each BFloat16 from the element pair, and shift
|
* Extract each BFloat16 from the element pair, and shift
|
||||||
* them such that they become float32.
|
* them such that they become float32.
|
||||||
*/
|
*/
|
||||||
t1 = float32_mul(e1 << 16, e2 << 16, &bf_status);
|
t1 = float32_mul(e1 << 16, e2 << 16, fpst);
|
||||||
t2 = float32_mul(e1 & 0xffff0000u, e2 & 0xffff0000u, &bf_status);
|
t2 = float32_mul(e1 & 0xffff0000u, e2 & 0xffff0000u, fpst);
|
||||||
t1 = float32_add(t1, t2, &bf_status);
|
t1 = float32_add(t1, t2, fpst);
|
||||||
t1 = float32_add(sum, t1, &bf_status);
|
t1 = float32_add(sum, t1, fpst);
|
||||||
|
|
||||||
return t1;
|
return t1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
float32 bfdotadd_ebf(float32 sum, uint32_t e1, uint32_t e2,
|
||||||
|
float_status *fpst, float_status *fpst_odd)
|
||||||
|
{
|
||||||
|
g_assert_not_reached();
|
||||||
|
}
|
||||||
|
|
||||||
void HELPER(gvec_bfdot)(void *vd, void *vn, void *vm, void *va,
|
void HELPER(gvec_bfdot)(void *vd, void *vn, void *vm, void *va,
|
||||||
CPUARMState *env, uint32_t desc)
|
CPUARMState *env, uint32_t desc)
|
||||||
{
|
{
|
||||||
intptr_t i, opr_sz = simd_oprsz(desc);
|
intptr_t i, opr_sz = simd_oprsz(desc);
|
||||||
float32 *d = vd, *a = va;
|
float32 *d = vd, *a = va;
|
||||||
uint32_t *n = vn, *m = vm;
|
uint32_t *n = vn, *m = vm;
|
||||||
|
float_status fpst, fpst_odd;
|
||||||
|
|
||||||
for (i = 0; i < opr_sz / 4; ++i) {
|
if (is_ebf(env, &fpst, &fpst_odd)) {
|
||||||
d[i] = bfdotadd(a[i], n[i], m[i]);
|
for (i = 0; i < opr_sz / 4; ++i) {
|
||||||
|
d[i] = bfdotadd_ebf(a[i], n[i], m[i], &fpst, &fpst_odd);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (i = 0; i < opr_sz / 4; ++i) {
|
||||||
|
d[i] = bfdotadd(a[i], n[i], m[i], &fpst);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
clear_tail(d, opr_sz, simd_maxsz(desc));
|
clear_tail(d, opr_sz, simd_maxsz(desc));
|
||||||
}
|
}
|
||||||
|
@ -2836,12 +2855,23 @@ void HELPER(gvec_bfdot_idx)(void *vd, void *vn, void *vm,
|
||||||
intptr_t eltspersegment = MIN(16 / 4, elements);
|
intptr_t eltspersegment = MIN(16 / 4, elements);
|
||||||
float32 *d = vd, *a = va;
|
float32 *d = vd, *a = va;
|
||||||
uint32_t *n = vn, *m = vm;
|
uint32_t *n = vn, *m = vm;
|
||||||
|
float_status fpst, fpst_odd;
|
||||||
|
|
||||||
for (i = 0; i < elements; i += eltspersegment) {
|
if (is_ebf(env, &fpst, &fpst_odd)) {
|
||||||
uint32_t m_idx = m[i + H4(index)];
|
for (i = 0; i < elements; i += eltspersegment) {
|
||||||
|
uint32_t m_idx = m[i + H4(index)];
|
||||||
|
|
||||||
for (j = i; j < i + eltspersegment; j++) {
|
for (j = i; j < i + eltspersegment; j++) {
|
||||||
d[j] = bfdotadd(a[j], n[j], m_idx);
|
d[j] = bfdotadd_ebf(a[j], n[j], m_idx, &fpst, &fpst_odd);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (i = 0; i < elements; i += eltspersegment) {
|
||||||
|
uint32_t m_idx = m[i + H4(index)];
|
||||||
|
|
||||||
|
for (j = i; j < i + eltspersegment; j++) {
|
||||||
|
d[j] = bfdotadd(a[j], n[j], m_idx, &fpst);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
clear_tail(d, opr_sz, simd_maxsz(desc));
|
clear_tail(d, opr_sz, simd_maxsz(desc));
|
||||||
|
@ -2853,37 +2883,72 @@ void HELPER(gvec_bfmmla)(void *vd, void *vn, void *vm, void *va,
|
||||||
intptr_t s, opr_sz = simd_oprsz(desc);
|
intptr_t s, opr_sz = simd_oprsz(desc);
|
||||||
float32 *d = vd, *a = va;
|
float32 *d = vd, *a = va;
|
||||||
uint32_t *n = vn, *m = vm;
|
uint32_t *n = vn, *m = vm;
|
||||||
|
float_status fpst, fpst_odd;
|
||||||
|
|
||||||
for (s = 0; s < opr_sz / 4; s += 4) {
|
if (is_ebf(env, &fpst, &fpst_odd)) {
|
||||||
float32 sum00, sum01, sum10, sum11;
|
for (s = 0; s < opr_sz / 4; s += 4) {
|
||||||
|
float32 sum00, sum01, sum10, sum11;
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* Process the entire segment at once, writing back the
|
* Process the entire segment at once, writing back the
|
||||||
* results only after we've consumed all of the inputs.
|
* results only after we've consumed all of the inputs.
|
||||||
*
|
*
|
||||||
* Key to indices by column:
|
* Key to indices by column:
|
||||||
* i j i k j k
|
* i j i k j k
|
||||||
*/
|
*/
|
||||||
sum00 = a[s + H4(0 + 0)];
|
sum00 = a[s + H4(0 + 0)];
|
||||||
sum00 = bfdotadd(sum00, n[s + H4(0 + 0)], m[s + H4(0 + 0)]);
|
sum00 = bfdotadd_ebf(sum00, n[s + H4(0 + 0)], m[s + H4(0 + 0)], &fpst, &fpst_odd);
|
||||||
sum00 = bfdotadd(sum00, n[s + H4(0 + 1)], m[s + H4(0 + 1)]);
|
sum00 = bfdotadd_ebf(sum00, n[s + H4(0 + 1)], m[s + H4(0 + 1)], &fpst, &fpst_odd);
|
||||||
|
|
||||||
sum01 = a[s + H4(0 + 1)];
|
sum01 = a[s + H4(0 + 1)];
|
||||||
sum01 = bfdotadd(sum01, n[s + H4(0 + 0)], m[s + H4(2 + 0)]);
|
sum01 = bfdotadd_ebf(sum01, n[s + H4(0 + 0)], m[s + H4(2 + 0)], &fpst, &fpst_odd);
|
||||||
sum01 = bfdotadd(sum01, n[s + H4(0 + 1)], m[s + H4(2 + 1)]);
|
sum01 = bfdotadd_ebf(sum01, n[s + H4(0 + 1)], m[s + H4(2 + 1)], &fpst, &fpst_odd);
|
||||||
|
|
||||||
sum10 = a[s + H4(2 + 0)];
|
sum10 = a[s + H4(2 + 0)];
|
||||||
sum10 = bfdotadd(sum10, n[s + H4(2 + 0)], m[s + H4(0 + 0)]);
|
sum10 = bfdotadd_ebf(sum10, n[s + H4(2 + 0)], m[s + H4(0 + 0)], &fpst, &fpst_odd);
|
||||||
sum10 = bfdotadd(sum10, n[s + H4(2 + 1)], m[s + H4(0 + 1)]);
|
sum10 = bfdotadd_ebf(sum10, n[s + H4(2 + 1)], m[s + H4(0 + 1)], &fpst, &fpst_odd);
|
||||||
|
|
||||||
sum11 = a[s + H4(2 + 1)];
|
sum11 = a[s + H4(2 + 1)];
|
||||||
sum11 = bfdotadd(sum11, n[s + H4(2 + 0)], m[s + H4(2 + 0)]);
|
sum11 = bfdotadd_ebf(sum11, n[s + H4(2 + 0)], m[s + H4(2 + 0)], &fpst, &fpst_odd);
|
||||||
sum11 = bfdotadd(sum11, n[s + H4(2 + 1)], m[s + H4(2 + 1)]);
|
sum11 = bfdotadd_ebf(sum11, n[s + H4(2 + 1)], m[s + H4(2 + 1)], &fpst, &fpst_odd);
|
||||||
|
|
||||||
d[s + H4(0 + 0)] = sum00;
|
d[s + H4(0 + 0)] = sum00;
|
||||||
d[s + H4(0 + 1)] = sum01;
|
d[s + H4(0 + 1)] = sum01;
|
||||||
d[s + H4(2 + 0)] = sum10;
|
d[s + H4(2 + 0)] = sum10;
|
||||||
d[s + H4(2 + 1)] = sum11;
|
d[s + H4(2 + 1)] = sum11;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (s = 0; s < opr_sz / 4; s += 4) {
|
||||||
|
float32 sum00, sum01, sum10, sum11;
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Process the entire segment at once, writing back the
|
||||||
|
* results only after we've consumed all of the inputs.
|
||||||
|
*
|
||||||
|
* Key to indices by column:
|
||||||
|
* i j i k j k
|
||||||
|
*/
|
||||||
|
sum00 = a[s + H4(0 + 0)];
|
||||||
|
sum00 = bfdotadd(sum00, n[s + H4(0 + 0)], m[s + H4(0 + 0)], &fpst);
|
||||||
|
sum00 = bfdotadd(sum00, n[s + H4(0 + 1)], m[s + H4(0 + 1)], &fpst);
|
||||||
|
|
||||||
|
sum01 = a[s + H4(0 + 1)];
|
||||||
|
sum01 = bfdotadd(sum01, n[s + H4(0 + 0)], m[s + H4(2 + 0)], &fpst);
|
||||||
|
sum01 = bfdotadd(sum01, n[s + H4(0 + 1)], m[s + H4(2 + 1)], &fpst);
|
||||||
|
|
||||||
|
sum10 = a[s + H4(2 + 0)];
|
||||||
|
sum10 = bfdotadd(sum10, n[s + H4(2 + 0)], m[s + H4(0 + 0)], &fpst);
|
||||||
|
sum10 = bfdotadd(sum10, n[s + H4(2 + 1)], m[s + H4(0 + 1)], &fpst);
|
||||||
|
|
||||||
|
sum11 = a[s + H4(2 + 1)];
|
||||||
|
sum11 = bfdotadd(sum11, n[s + H4(2 + 0)], m[s + H4(2 + 0)], &fpst);
|
||||||
|
sum11 = bfdotadd(sum11, n[s + H4(2 + 1)], m[s + H4(2 + 1)], &fpst);
|
||||||
|
|
||||||
|
d[s + H4(0 + 0)] = sum00;
|
||||||
|
d[s + H4(0 + 1)] = sum01;
|
||||||
|
d[s + H4(2 + 0)] = sum10;
|
||||||
|
d[s + H4(2 + 1)] = sum11;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
clear_tail(d, opr_sz, simd_maxsz(desc));
|
clear_tail(d, opr_sz, simd_maxsz(desc));
|
||||||
}
|
}
|
||||||
|
|
|
@ -223,13 +223,46 @@ int64_t do_sqrdmlah_d(int64_t, int64_t, int64_t, bool, bool);
|
||||||
* bfdotadd:
|
* bfdotadd:
|
||||||
* @sum: addend
|
* @sum: addend
|
||||||
* @e1, @e2: multiplicand vectors
|
* @e1, @e2: multiplicand vectors
|
||||||
|
* @fpst: floating-point status to use
|
||||||
*
|
*
|
||||||
* BFloat16 2-way dot product of @e1 & @e2, accumulating with @sum.
|
* BFloat16 2-way dot product of @e1 & @e2, accumulating with @sum.
|
||||||
* The @e1 and @e2 operands correspond to the 32-bit source vector
|
* The @e1 and @e2 operands correspond to the 32-bit source vector
|
||||||
* slots and contain two Bfloat16 values each.
|
* slots and contain two Bfloat16 values each.
|
||||||
*
|
*
|
||||||
* Corresponds to the ARM pseudocode function BFDotAdd.
|
* Corresponds to the ARM pseudocode function BFDotAdd, specialized
|
||||||
|
* for the FPCR.EBF == 0 case.
|
||||||
*/
|
*/
|
||||||
float32 bfdotadd(float32 sum, uint32_t e1, uint32_t e2);
|
float32 bfdotadd(float32 sum, uint32_t e1, uint32_t e2, float_status *fpst);
|
||||||
|
/**
|
||||||
|
* bfdotadd_ebf:
|
||||||
|
* @sum: addend
|
||||||
|
* @e1, @e2: multiplicand vectors
|
||||||
|
* @fpst: floating-point status to use
|
||||||
|
* @fpst_odd: floating-point status to use for round-to-odd operations
|
||||||
|
*
|
||||||
|
* BFloat16 2-way dot product of @e1 & @e2, accumulating with @sum.
|
||||||
|
* The @e1 and @e2 operands correspond to the 32-bit source vector
|
||||||
|
* slots and contain two Bfloat16 values each.
|
||||||
|
*
|
||||||
|
* Corresponds to the ARM pseudocode function BFDotAdd, specialized
|
||||||
|
* for the FPCR.EBF == 1 case.
|
||||||
|
*/
|
||||||
|
float32 bfdotadd_ebf(float32 sum, uint32_t e1, uint32_t e2,
|
||||||
|
float_status *fpst, float_status *fpst_odd);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* is_ebf:
|
||||||
|
* @env: CPU state
|
||||||
|
* @statusp: pointer to floating point status to fill in
|
||||||
|
* @oddstatusp: pointer to floating point status to fill in for round-to-odd
|
||||||
|
*
|
||||||
|
* Determine whether a BFDotAdd operation should use FPCR.EBF = 0
|
||||||
|
* or FPCR.EBF = 1 semantics. On return, has initialized *statusp
|
||||||
|
* and *oddstatusp to suitable float_status arguments to use with either
|
||||||
|
* bfdotadd() or bfdotadd_ebf().
|
||||||
|
* Returns true for EBF = 1, false for EBF = 0. (The caller should use this
|
||||||
|
* to decide whether to call bfdotadd() or bfdotadd_ebf().)
|
||||||
|
*/
|
||||||
|
bool is_ebf(CPUARMState *env, float_status *statusp, float_status *oddstatusp);
|
||||||
|
|
||||||
#endif /* TARGET_ARM_VEC_INTERNAL_H */
|
#endif /* TARGET_ARM_VEC_INTERNAL_H */
|
||||||
|
|
Loading…
Reference in New Issue