bn: Use int instead of u32 for sizes

The loops relied on unsigned integer overflow, which is not immediately
obvious. Replace them with less clever variants that are clearer.

Also implement bn_compare using std::memcmp.
This commit is contained in:
Léo Lam 2018-05-16 00:27:43 +02:00
parent 56e91bfdc1
commit b9dd94b9b2
2 changed files with 34 additions and 53 deletions

View File

@ -3,61 +3,43 @@
// http://www.gnu.org/licenses/old-licenses/gpl-2.0.txt // http://www.gnu.org/licenses/old-licenses/gpl-2.0.txt
#include <cstdio> #include <cstdio>
#include <string.h> #include <cstring>
#include "Common/CommonTypes.h" #include "Common/CommonTypes.h"
#include "Common/Crypto/bn.h" #include "Common/Crypto/bn.h"
static void bn_zero(u8* d, u32 n) static void bn_zero(u8* d, int n)
{ {
memset(d, 0, n); std::memset(d, 0, n);
} }
static void bn_copy(u8* d, const u8* a, u32 n) static void bn_copy(u8* d, const u8* a, int n)
{ {
memcpy(d, a, n); std::memcpy(d, a, n);
} }
int bn_compare(const u8* a, const u8* b, u32 n) int bn_compare(const u8* a, const u8* b, int n)
{ {
u32 i; return std::memcmp(a, b, n);
for (i = 0; i < n; i++)
{
if (a[i] < b[i])
return -1;
if (a[i] > b[i])
return 1;
} }
return 0; void bn_sub_modulus(u8* a, const u8* N, int n)
}
void bn_sub_modulus(u8* a, const u8* N, u32 n)
{ {
u32 i; u8 c = 0;
u32 dig; for (int i = n - 1; i >= 0; --i)
u8 c;
c = 0;
for (i = n - 1; i < n; i--)
{ {
dig = N[i] + c; u32 dig = N[i] + c;
c = (a[i] < dig); c = (a[i] < dig);
a[i] -= dig; a[i] -= dig;
} }
} }
void bn_add(u8* d, const u8* a, const u8* b, const u8* N, u32 n) void bn_add(u8* d, const u8* a, const u8* b, const u8* N, int n)
{ {
u32 i; u8 c = 0;
u32 dig; for (int i = n - 1; i >= 0; --i)
u8 c;
c = 0;
for (i = n - 1; i < n; i--)
{ {
dig = a[i] + b[i] + c; u32 dig = a[i] + b[i] + c;
c = (dig >= 0x100); c = (dig >= 0x100);
d[i] = dig; d[i] = dig;
} }
@ -69,32 +51,30 @@ void bn_add(u8* d, const u8* a, const u8* b, const u8* N, u32 n)
bn_sub_modulus(d, N, n); bn_sub_modulus(d, N, n);
} }
void bn_mul(u8* d, const u8* a, const u8* b, const u8* N, u32 n) void bn_mul(u8* d, const u8* a, const u8* b, const u8* N, int n)
{ {
u32 i;
u8 mask;
bn_zero(d, n); bn_zero(d, n);
for (i = 0; i < n; i++) for (int i = 0; i < n; i++)
for (mask = 0x80; mask != 0; mask >>= 1) {
for (u8 mask = 0x80; mask != 0; mask >>= 1)
{ {
bn_add(d, d, d, N, n); bn_add(d, d, d, N, n);
if ((a[i] & mask) != 0) if ((a[i] & mask) != 0)
bn_add(d, d, b, N, n); bn_add(d, d, b, N, n);
} }
} }
}
void bn_exp(u8* d, const u8* a, const u8* N, u32 n, const u8* e, u32 en) void bn_exp(u8* d, const u8* a, const u8* N, int n, const u8* e, int en)
{ {
u8 t[512]; u8 t[512];
u32 i;
u8 mask;
bn_zero(d, n); bn_zero(d, n);
d[n - 1] = 1; d[n - 1] = 1;
for (i = 0; i < en; i++) for (int i = 0; i < en; i++)
for (mask = 0x80; mask != 0; mask >>= 1) {
for (u8 mask = 0x80; mask != 0; mask >>= 1)
{ {
bn_mul(t, d, d, N, n); bn_mul(t, d, d, N, n);
if ((e[i] & mask) != 0) if ((e[i] & mask) != 0)
@ -103,9 +83,10 @@ void bn_exp(u8* d, const u8* a, const u8* N, u32 n, const u8* e, u32 en)
bn_copy(d, t, n); bn_copy(d, t, n);
} }
} }
}
// only for prime N -- stupid but lazy, see if I care // only for prime N -- stupid but lazy, see if I care
void bn_inv(u8* d, const u8* a, const u8* N, u32 n) void bn_inv(u8* d, const u8* a, const u8* N, int n)
{ {
u8 t[512], s[512]; u8 t[512], s[512];

View File

@ -8,9 +8,9 @@
// bignum arithmetic // bignum arithmetic
int bn_compare(const u8* a, const u8* b, u32 n); int bn_compare(const u8* a, const u8* b, int n);
void bn_sub_modulus(u8* a, const u8* N, u32 n); void bn_sub_modulus(u8* a, const u8* N, int n);
void bn_add(u8* d, const u8* a, const u8* b, const u8* N, u32 n); void bn_add(u8* d, const u8* a, const u8* b, const u8* N, int n);
void bn_mul(u8* d, const u8* a, const u8* b, const u8* N, u32 n); void bn_mul(u8* d, const u8* a, const u8* b, const u8* N, int n);
void bn_inv(u8* d, const u8* a, const u8* N, u32 n); // only for prime N void bn_inv(u8* d, const u8* a, const u8* N, int n); // only for prime N
void bn_exp(u8* d, const u8* a, const u8* N, u32 n, const u8* e, u32 en); void bn_exp(u8* d, const u8* a, const u8* N, int n, const u8* e, int en);