tls: get rid of constant-time add/sub operations

function                                             old     new   delta
sp_256_sub_10                                          -      22     +22
static.sp_256_mont_reduce_10                         176     178      +2
sp_256_mod_mul_norm_10                              1440    1439      -1
sp_256_proj_point_dbl_10                             453     446      -7
sp_256_ecc_mulmod_10                                1229    1216     -13
static.sp_256_mont_sub_10                             52      30     -22
static.sp_256_cond_sub_10                             32       -     -32
------------------------------------------------------------------------------
(add/remove: 1/1 grow/shrink: 1/4 up/down: 24/-75)            Total: -51 bytes

Signed-off-by: Denys Vlasenko <vda.linux@googlemail.com>
This commit is contained in:
Denys Vlasenko 2021-04-26 21:58:04 +02:00
parent 120401249a
commit 9a40be433d

View File

@ -203,26 +203,12 @@ static void sp_256_add_10(sp_digit* r, const sp_digit* a, const sp_digit* b)
r[i] = a[i] + b[i]; r[i] = a[i] + b[i];
} }
/* Conditionally add a and b using the mask m. /* Sub b from a into r. (r = a - b) */
* m is -1 to add and 0 when not. static void sp_256_sub_10(sp_digit* r, const sp_digit* a, const sp_digit* b)
*/
static void sp_256_cond_add_10(sp_digit* r, const sp_digit* a,
const sp_digit* b, const sp_digit m)
{ {
int i; int i;
for (i = 0; i < 10; i++) for (i = 0; i < 10; i++)
r[i] = a[i] + (b[i] & m); r[i] = a[i] - b[i];
}
/* Conditionally subtract b from a using the mask m.
* m is -1 to subtract and 0 when not.
*/
static void sp_256_cond_sub_10(sp_digit* r, const sp_digit* a,
const sp_digit* b, const sp_digit m)
{
int i;
for (i = 0; i < 10; i++)
r[i] = a[i] - (b[i] & m);
} }
/* Shift number left one bit. Bottom bit is lost. */ /* Shift number left one bit. Bottom bit is lost. */
@ -352,7 +338,8 @@ static void sp_256_mul_add_10(sp_digit* r, const sp_digit* a, sp_digit b)
/* Divide the number by 2 mod the modulus (prime). (r = a / 2 % m) */ /* Divide the number by 2 mod the modulus (prime). (r = a / 2 % m) */
static void sp_256_div2_10(sp_digit* r, const sp_digit* a, const sp_digit* m) static void sp_256_div2_10(sp_digit* r, const sp_digit* a, const sp_digit* m)
{ {
sp_256_cond_add_10(r, a, m, 0 - (a[0] & 1)); if (a[0] & 1)
sp_256_add_10(r, a, m);
sp_256_norm_10(r); sp_256_norm_10(r);
sp_256_rshift1_10(r, r); sp_256_rshift1_10(r, r);
} }
@ -382,7 +369,8 @@ static void sp_256_mont_add_10(sp_digit* r, const sp_digit* a, const sp_digit* b
{ {
sp_256_add_10(r, a, b); sp_256_add_10(r, a, b);
sp_256_norm_10(r); sp_256_norm_10(r);
sp_256_cond_sub_10(r, r, m, 0 - ((r[9] >> 22) > 0)); if ((r[9] >> 22) > 0)
sp_256_sub_10(r, r, m);
sp_256_norm_10(r); sp_256_norm_10(r);
} }
@ -391,7 +379,8 @@ static void sp_256_mont_dbl_10(sp_digit* r, const sp_digit* a, const sp_digit* m
{ {
sp_256_add_10(r, a, a); sp_256_add_10(r, a, a);
sp_256_norm_10(r); sp_256_norm_10(r);
sp_256_cond_sub_10(r, r, m, 0 - ((r[9] >> 22) > 0)); if ((r[9] >> 22) > 0)
sp_256_sub_10(r, r, m);
sp_256_norm_10(r); sp_256_norm_10(r);
} }
@ -400,28 +389,23 @@ static void sp_256_mont_tpl_10(sp_digit* r, const sp_digit* a, const sp_digit* m
{ {
sp_256_add_10(r, a, a); sp_256_add_10(r, a, a);
sp_256_norm_10(r); sp_256_norm_10(r);
sp_256_cond_sub_10(r, r, m, 0 - ((r[9] >> 22) > 0)); if ((r[9] >> 22) > 0)
sp_256_sub_10(r, r, m);
sp_256_norm_10(r); sp_256_norm_10(r);
sp_256_add_10(r, r, a); sp_256_add_10(r, r, a);
sp_256_norm_10(r); sp_256_norm_10(r);
sp_256_cond_sub_10(r, r, m, 0 - ((r[9] >> 22) > 0)); if ((r[9] >> 22) > 0)
sp_256_sub_10(r, r, m);
sp_256_norm_10(r); sp_256_norm_10(r);
} }
/* Sub b from a into r. (r = a - b) */
static void sp_256_sub_10(sp_digit* r, const sp_digit* a, const sp_digit* b)
{
int i;
for (i = 0; i < 10; i++)
r[i] = a[i] - b[i];
}
/* Subtract two Montgomery form numbers (r = a - b % m) */ /* Subtract two Montgomery form numbers (r = a - b % m) */
static void sp_256_mont_sub_10(sp_digit* r, const sp_digit* a, const sp_digit* b, static void sp_256_mont_sub_10(sp_digit* r, const sp_digit* a, const sp_digit* b,
const sp_digit* m) const sp_digit* m)
{ {
sp_256_sub_10(r, a, b); sp_256_sub_10(r, a, b);
sp_256_cond_add_10(r, r, m, r[9] >> 22); if (r[9] >> 22)
sp_256_add_10(r, r, m);
sp_256_norm_10(r); sp_256_norm_10(r);
} }
@ -460,7 +444,8 @@ static void sp_256_mont_reduce_10(sp_digit* a, const sp_digit* m, sp_digit mp)
} }
sp_256_mont_shift_10(a, a); sp_256_mont_shift_10(a, a);
sp_256_cond_sub_10(a, a, m, 0 - ((a[9] >> 22) > 0)); if ((a[9] >> 22) > 0)
sp_256_sub_10(a, a, m);
sp_256_norm_10(a); sp_256_norm_10(a);
} }
@ -590,7 +575,6 @@ static void sp_256_map_10(sp_point* r, sp_point* p)
{ {
sp_digit t1[2*10]; sp_digit t1[2*10];
sp_digit t2[2*10]; sp_digit t2[2*10];
int32_t n;
sp_256_mont_inv_10(t1, p->z); sp_256_mont_inv_10(t1, p->z);
@ -602,8 +586,8 @@ static void sp_256_map_10(sp_point* r, sp_point* p)
memset(r->x + 10, 0, sizeof(r->x) / 2); memset(r->x + 10, 0, sizeof(r->x) / 2);
sp_256_mont_reduce_10(r->x, p256_mod, p256_mp_mod); sp_256_mont_reduce_10(r->x, p256_mod, p256_mp_mod);
/* Reduce x to less than modulus */ /* Reduce x to less than modulus */
n = sp_256_cmp_10(r->x, p256_mod); if (sp_256_cmp_10(r->x, p256_mod) >= 0)
sp_256_cond_sub_10(r->x, r->x, p256_mod, 0 - (n >= 0)); sp_256_sub_10(r->x, r->x, p256_mod);
sp_256_norm_10(r->x); sp_256_norm_10(r->x);
/* y /= z^3 */ /* y /= z^3 */
@ -611,8 +595,8 @@ static void sp_256_map_10(sp_point* r, sp_point* p)
memset(r->y + 10, 0, sizeof(r->y) / 2); memset(r->y + 10, 0, sizeof(r->y) / 2);
sp_256_mont_reduce_10(r->y, p256_mod, p256_mp_mod); sp_256_mont_reduce_10(r->y, p256_mod, p256_mp_mod);
/* Reduce y to less than modulus */ /* Reduce y to less than modulus */
n = sp_256_cmp_10(r->y, p256_mod); if (sp_256_cmp_10(r->y, p256_mod) >= 0)
sp_256_cond_sub_10(r->y, r->y, p256_mod, 0 - (n >= 0)); sp_256_sub_10(r->y, r->y, p256_mod);
sp_256_norm_10(r->y); sp_256_norm_10(r->y);
memset(r->z, 0, sizeof(r->z)); memset(r->z, 0, sizeof(r->z));