tls: P256: change logic so that we don't need double-wide vectors everywhere

Change sp_256to512z_mont_{mul,sqr}_8 to not require/zero upper 256 bits.
There is only one place where we actually used that (and that's why there
used to be zeroing memset of top half!). Fix up that place.
As a bonus, 256x256->512 multiply no longer needs to care for
"r overlaps a or b" case.

This shrinks sp_point structure as well, not just temporaries.

function                                             old     new   delta
sp_256to512z_mont_mul_8                              150       -    -150
sp_256_mont_mul_8                                      -     147    +147
sp_256to512z_mont_sqr_8                                7       -      -7
sp_256_mont_sqr_8                                      -       7      +7
sp_256_ecc_mulmod_8                                  494     543     +49
sp_512to256_mont_reduce_8                            243     249      +6
sp_256_point_from_bin2x32                             73      70      -3
sp_256_proj_point_dbl_8                              353     345      -8
sp_256_proj_point_add_8                              544     499     -45
------------------------------------------------------------------------------
(add/remove: 2/2 grow/shrink: 2/3 up/down: 209/-213)           Total: -4 bytes

Signed-off-by: Denys Vlasenko <vda.linux@googlemail.com>
This commit is contained in:
Denys Vlasenko 2021-11-27 19:15:43 +01:00
parent 9c671fe3dd
commit f92ae1dc4b

View File

@ -49,9 +49,9 @@ typedef int32_t signed_sp_digit;
*/
typedef struct sp_point {
sp_digit x[2 * 8];
sp_digit y[2 * 8];
sp_digit z[2 * 8];
sp_digit x[8];
sp_digit y[8];
sp_digit z[8];
int infinity;
} sp_point;
@ -456,12 +456,11 @@ static void sp_256_sub_8_p256_mod(sp_digit* r)
#endif
/* Multiply a and b into r. (r = a * b)
* r should be [16] array (512 bits).
* r should be [16] array (512 bits), and must not coincide with a or b.
*/
static void sp_256to512_mul_8(sp_digit* r, const sp_digit* a, const sp_digit* b)
{
#if ALLOW_ASM && defined(__GNUC__) && defined(__i386__)
sp_digit rr[15]; /* in case r coincides with a or b */
int k;
uint32_t accl;
uint32_t acch;
@ -493,16 +492,15 @@ static void sp_256to512_mul_8(sp_digit* r, const sp_digit* a, const sp_digit* b)
j--;
i++;
} while (i != 8 && i <= k);
rr[k] = accl;
r[k] = accl;
accl = acch;
acch = acc_hi;
}
r[15] = accl;
memcpy(r, rr, sizeof(rr));
#elif ALLOW_ASM && defined(__GNUC__) && defined(__x86_64__)
const uint64_t* aa = (const void*)a;
const uint64_t* bb = (const void*)b;
uint64_t rr[8];
const uint64_t* rr = (const void*)r;
int k;
uint64_t accl;
uint64_t acch;
@ -539,11 +537,8 @@ static void sp_256to512_mul_8(sp_digit* r, const sp_digit* a, const sp_digit* b)
acch = acc_hi;
}
rr[7] = accl;
memcpy(r, rr, sizeof(rr));
#elif 0
//TODO: arm assembly (untested)
sp_digit tmp[16];
asm volatile (
"\n mov r5, #0"
"\n mov r6, #0"
@ -575,12 +570,10 @@ static void sp_256to512_mul_8(sp_digit* r, const sp_digit* a, const sp_digit* b)
"\n cmp r5, #56"
"\n ble 1b"
"\n str r6, [%[r], r5]"
: [r] "r" (tmp), [a] "r" (a), [b] "r" (b)
: [r] "r" (r), [a] "r" (a), [b] "r" (b)
: "memory", "r3", "r4", "r5", "r6", "r7", "r8", "r9", "r10", "r12", "r14"
);
memcpy(r, tmp, sizeof(tmp));
#else
sp_digit rr[15]; /* in case r coincides with a or b */
int i, j, k;
uint64_t acc;
@ -600,11 +593,10 @@ static void sp_256to512_mul_8(sp_digit* r, const sp_digit* a, const sp_digit* b)
j--;
i++;
} while (i != 8 && i <= k);
rr[k] = acc;
r[k] = acc;
acc = (acc >> 32) | ((uint64_t)acc_hi << 32);
}
r[15] = acc;
memcpy(r, rr, sizeof(rr));
#endif
}
@ -709,30 +701,11 @@ static void sp_256_mont_tpl_8(sp_digit* r, const sp_digit* a /*, const sp_digit*
}
/* Shift the result in the high 256 bits down to the bottom.
* High half is cleared to zeros.
*/
#if BB_UNALIGNED_MEMACCESS_OK && ULONG_MAX > 0xffffffff
static void sp_512to256_mont_shift_8(sp_digit* rr)
static void sp_512to256_mont_shift_8(sp_digit* r, sp_digit* a)
{
uint64_t *r = (void*)rr;
int i;
for (i = 0; i < 4; i++) {
r[i] = r[i+4];
r[i+4] = 0;
memcpy(r, a + 8, sizeof(*r) * 8);
}
}
#else
static void sp_512to256_mont_shift_8(sp_digit* r)
{
int i;
for (i = 0; i < 8; i++) {
r[i] = r[i+8];
r[i+8] = 0;
}
}
#endif
/* Mul a by scalar b and add into r. (r += a * b)
* a = p256_mod
@ -868,11 +841,12 @@ static int sp_256_mul_add_8(sp_digit* r /*, const sp_digit* a, sp_digit b*/)
* Note: the result is NOT guaranteed to be less than p256_mod!
* (it is only guaranteed to fit into 256 bits).
*
* a Double-wide number to reduce in place.
* r Result.
* a Double-wide number to reduce. Clobbered.
* m The single precision number representing the modulus.
* mp The digit representing the negative inverse of m mod 2^n.
*/
static void sp_512to256_mont_reduce_8(sp_digit* a/*, const sp_digit* m, sp_digit mp*/)
static void sp_512to256_mont_reduce_8(sp_digit* r, sp_digit* a/*, const sp_digit* m, sp_digit mp*/)
{
// const sp_digit* m = p256_mod;
sp_digit mp = p256_mp_mod;
@ -895,10 +869,10 @@ static void sp_512to256_mont_reduce_8(sp_digit* a/*, const sp_digit* m, sp_digit
goto inc_next_word0;
}
}
sp_512to256_mont_shift_8(a);
sp_512to256_mont_shift_8(r, a);
if (word16th != 0)
sp_256_sub_8_p256_mod(a);
sp_256_norm_8(a);
sp_256_sub_8_p256_mod(r);
sp_256_norm_8(r);
}
else { /* Same code for explicit mp == 1 (which is always the case for P256) */
sp_digit word16th = 0;
@ -915,10 +889,10 @@ static void sp_512to256_mont_reduce_8(sp_digit* a/*, const sp_digit* m, sp_digit
goto inc_next_word;
}
}
sp_512to256_mont_shift_8(a);
sp_512to256_mont_shift_8(r, a);
if (word16th != 0)
sp_256_sub_8_p256_mod(a);
sp_256_norm_8(a);
sp_256_sub_8_p256_mod(r);
sp_256_norm_8(r);
}
}
@ -926,35 +900,34 @@ static void sp_512to256_mont_reduce_8(sp_digit* a/*, const sp_digit* m, sp_digit
* (r = a * b mod m)
*
* r Result of multiplication.
* Should be [16] array (512 bits), but high half is cleared to zeros (used as scratch pad).
* a First number to multiply in Montogmery form.
* b Second number to multiply in Montogmery form.
* m Modulus (prime).
* mp Montogmery mulitplier.
*/
static void sp_256to512z_mont_mul_8(sp_digit* r, const sp_digit* a, const sp_digit* b
static void sp_256_mont_mul_8(sp_digit* r, const sp_digit* a, const sp_digit* b
/*, const sp_digit* m, sp_digit mp*/)
{
//const sp_digit* m = p256_mod;
//sp_digit mp = p256_mp_mod;
sp_256to512_mul_8(r, a, b);
sp_512to256_mont_reduce_8(r /*, m, mp*/);
sp_digit t[2 * 8];
sp_256to512_mul_8(t, a, b);
sp_512to256_mont_reduce_8(r, t /*, m, mp*/);
}
/* Square the Montgomery form number. (r = a * a mod m)
*
* r Result of squaring.
* Should be [16] array (512 bits), but high half is cleared to zeros (used as scratch pad).
* a Number to square in Montogmery form.
* m Modulus (prime).
* mp Montogmery mulitplier.
*/
static void sp_256to512z_mont_sqr_8(sp_digit* r, const sp_digit* a
static void sp_256_mont_sqr_8(sp_digit* r, const sp_digit* a
/*, const sp_digit* m, sp_digit mp*/)
{
//const sp_digit* m = p256_mod;
//sp_digit mp = p256_mp_mod;
sp_256to512z_mont_mul_8(r, a, a /*, m, mp*/);
sp_256_mont_mul_8(r, a, a /*, m, mp*/);
}
/* Invert the number, in Montgomery form, modulo the modulus (prime) of the
@ -964,11 +937,8 @@ static void sp_256to512z_mont_sqr_8(sp_digit* r, const sp_digit* a
* a Number to invert.
*/
#if 0
/* Mod-2 for the P256 curve. */
static const uint32_t p256_mod_2[8] = {
0xfffffffd,0xffffffff,0xffffffff,0x00000000,
0x00000000,0x00000000,0x00000001,0xffffffff,
};
//p256_mod - 2:
//ffffffff 00000001 00000000 00000000 00000000 ffffffff ffffffff ffffffff - 2
//Bit pattern:
//2 2 2 2 2 2 2 1...1
//5 5 4 3 2 1 0 9...0 9...1
@ -977,15 +947,15 @@ static const uint32_t p256_mod_2[8] = {
#endif
static void sp_256_mont_inv_8(sp_digit* r, sp_digit* a)
{
sp_digit t[2*8];
sp_digit t[8];
int i;
memcpy(t, a, sizeof(sp_digit) * 8);
for (i = 254; i >= 0; i--) {
sp_256to512z_mont_sqr_8(t, t /*, p256_mod, p256_mp_mod*/);
sp_256_mont_sqr_8(t, t /*, p256_mod, p256_mp_mod*/);
/*if (p256_mod_2[i / 32] & ((sp_digit)1 << (i % 32)))*/
if (i >= 224 || i == 192 || (i <= 95 && i != 1))
sp_256to512z_mont_mul_8(t, t, a /*, p256_mod, p256_mp_mod*/);
sp_256_mont_mul_8(t, t, a /*, p256_mod, p256_mp_mod*/);
}
memcpy(r, t, sizeof(sp_digit) * 8);
}
@ -1056,25 +1026,28 @@ static void sp_256_mod_mul_norm_8(sp_digit* r, const sp_digit* a)
*/
static void sp_256_map_8(sp_point* r, sp_point* p)
{
sp_digit t1[2*8];
sp_digit t2[2*8];
sp_digit t1[8];
sp_digit t2[8];
sp_digit rr[2 * 8];
sp_256_mont_inv_8(t1, p->z);
sp_256to512z_mont_sqr_8(t2, t1 /*, p256_mod, p256_mp_mod*/);
sp_256to512z_mont_mul_8(t1, t2, t1 /*, p256_mod, p256_mp_mod*/);
sp_256_mont_sqr_8(t2, t1 /*, p256_mod, p256_mp_mod*/);
sp_256_mont_mul_8(t1, t2, t1 /*, p256_mod, p256_mp_mod*/);
/* x /= z^2 */
sp_256to512z_mont_mul_8(r->x, p->x, t2 /*, p256_mod, p256_mp_mod*/);
sp_512to256_mont_reduce_8(r->x /*, p256_mod, p256_mp_mod*/);
sp_256_mont_mul_8(rr, p->x, t2 /*, p256_mod, p256_mp_mod*/);
memset(rr + 8, 0, sizeof(rr) / 2);
sp_512to256_mont_reduce_8(r->x, rr /*, p256_mod, p256_mp_mod*/);
/* Reduce x to less than modulus */
if (sp_256_cmp_8(r->x, p256_mod) >= 0)
sp_256_sub_8_p256_mod(r->x);
sp_256_norm_8(r->x);
/* y /= z^3 */
sp_256to512z_mont_mul_8(r->y, p->y, t1 /*, p256_mod, p256_mp_mod*/);
sp_512to256_mont_reduce_8(r->y /*, p256_mod, p256_mp_mod*/);
sp_256_mont_mul_8(rr, p->y, t1 /*, p256_mod, p256_mp_mod*/);
memset(rr + 8, 0, sizeof(rr) / 2);
sp_512to256_mont_reduce_8(r->y, rr /*, p256_mod, p256_mp_mod*/);
/* Reduce y to less than modulus */
if (sp_256_cmp_8(r->y, p256_mod) >= 0)
sp_256_sub_8_p256_mod(r->y);
@ -1091,8 +1064,8 @@ static void sp_256_map_8(sp_point* r, sp_point* p)
*/
static void sp_256_proj_point_dbl_8(sp_point* r, sp_point* p)
{
sp_digit t1[2*8];
sp_digit t2[2*8];
sp_digit t1[8];
sp_digit t2[8];
/* Put point to double into result */
if (r != p)
@ -1101,17 +1074,10 @@ static void sp_256_proj_point_dbl_8(sp_point* r, sp_point* p)
if (r->infinity)
return;
if (SP_DEBUG) {
/* unused part of t2, may result in spurios
* differences in debug output. Clear it.
*/
memset(t2, 0, sizeof(t2));
}
/* T1 = Z * Z */
sp_256to512z_mont_sqr_8(t1, r->z /*, p256_mod, p256_mp_mod*/);
sp_256_mont_sqr_8(t1, r->z /*, p256_mod, p256_mp_mod*/);
/* Z = Y * Z */
sp_256to512z_mont_mul_8(r->z, r->y, r->z /*, p256_mod, p256_mp_mod*/);
sp_256_mont_mul_8(r->z, r->y, r->z /*, p256_mod, p256_mp_mod*/);
/* Z = 2Z */
sp_256_mont_dbl_8(r->z, r->z /*, p256_mod*/);
/* T2 = X - T1 */
@ -1119,21 +1085,21 @@ static void sp_256_proj_point_dbl_8(sp_point* r, sp_point* p)
/* T1 = X + T1 */
sp_256_mont_add_8(t1, r->x, t1 /*, p256_mod*/);
/* T2 = T1 * T2 */
sp_256to512z_mont_mul_8(t2, t1, t2 /*, p256_mod, p256_mp_mod*/);
sp_256_mont_mul_8(t2, t1, t2 /*, p256_mod, p256_mp_mod*/);
/* T1 = 3T2 */
sp_256_mont_tpl_8(t1, t2 /*, p256_mod*/);
/* Y = 2Y */
sp_256_mont_dbl_8(r->y, r->y /*, p256_mod*/);
/* Y = Y * Y */
sp_256to512z_mont_sqr_8(r->y, r->y /*, p256_mod, p256_mp_mod*/);
sp_256_mont_sqr_8(r->y, r->y /*, p256_mod, p256_mp_mod*/);
/* T2 = Y * Y */
sp_256to512z_mont_sqr_8(t2, r->y /*, p256_mod, p256_mp_mod*/);
sp_256_mont_sqr_8(t2, r->y /*, p256_mod, p256_mp_mod*/);
/* T2 = T2/2 */
sp_256_div2_8(t2 /*, p256_mod*/);
/* Y = Y * X */
sp_256to512z_mont_mul_8(r->y, r->y, r->x /*, p256_mod, p256_mp_mod*/);
sp_256_mont_mul_8(r->y, r->y, r->x /*, p256_mod, p256_mp_mod*/);
/* X = T1 * T1 */
sp_256to512z_mont_mul_8(r->x, t1, t1 /*, p256_mod, p256_mp_mod*/);
sp_256_mont_mul_8(r->x, t1, t1 /*, p256_mod, p256_mp_mod*/);
/* X = X - Y */
sp_256_mont_sub_8(r->x, r->x, r->y /*, p256_mod*/);
/* X = X - Y */
@ -1141,7 +1107,7 @@ static void sp_256_proj_point_dbl_8(sp_point* r, sp_point* p)
/* Y = Y - X */
sp_256_mont_sub_8(r->y, r->y, r->x /*, p256_mod*/);
/* Y = Y * T1 */
sp_256to512z_mont_mul_8(r->y, r->y, t1 /*, p256_mod, p256_mp_mod*/);
sp_256_mont_mul_8(r->y, r->y, t1 /*, p256_mod, p256_mp_mod*/);
/* Y = Y - T2 */
sp_256_mont_sub_8(r->y, r->y, t2 /*, p256_mod*/);
dump_512("y2 %s\n", r->y);
@ -1155,11 +1121,11 @@ static void sp_256_proj_point_dbl_8(sp_point* r, sp_point* p)
*/
static NOINLINE void sp_256_proj_point_add_8(sp_point* r, sp_point* p, sp_point* q)
{
sp_digit t1[2*8];
sp_digit t2[2*8];
sp_digit t3[2*8];
sp_digit t4[2*8];
sp_digit t5[2*8];
sp_digit t1[8];
sp_digit t2[8];
sp_digit t3[8];
sp_digit t4[8];
sp_digit t5[8];
/* Ensure only the first point is the same as the result. */
if (q == r) {
@ -1186,36 +1152,36 @@ static NOINLINE void sp_256_proj_point_add_8(sp_point* r, sp_point* p, sp_point*
}
/* U1 = X1*Z2^2 */
sp_256to512z_mont_sqr_8(t1, q->z /*, p256_mod, p256_mp_mod*/);
sp_256to512z_mont_mul_8(t3, t1, q->z /*, p256_mod, p256_mp_mod*/);
sp_256to512z_mont_mul_8(t1, t1, r->x /*, p256_mod, p256_mp_mod*/);
sp_256_mont_sqr_8(t1, q->z /*, p256_mod, p256_mp_mod*/);
sp_256_mont_mul_8(t3, t1, q->z /*, p256_mod, p256_mp_mod*/);
sp_256_mont_mul_8(t1, t1, r->x /*, p256_mod, p256_mp_mod*/);
/* U2 = X2*Z1^2 */
sp_256to512z_mont_sqr_8(t2, r->z /*, p256_mod, p256_mp_mod*/);
sp_256to512z_mont_mul_8(t4, t2, r->z /*, p256_mod, p256_mp_mod*/);
sp_256to512z_mont_mul_8(t2, t2, q->x /*, p256_mod, p256_mp_mod*/);
sp_256_mont_sqr_8(t2, r->z /*, p256_mod, p256_mp_mod*/);
sp_256_mont_mul_8(t4, t2, r->z /*, p256_mod, p256_mp_mod*/);
sp_256_mont_mul_8(t2, t2, q->x /*, p256_mod, p256_mp_mod*/);
/* S1 = Y1*Z2^3 */
sp_256to512z_mont_mul_8(t3, t3, r->y /*, p256_mod, p256_mp_mod*/);
sp_256_mont_mul_8(t3, t3, r->y /*, p256_mod, p256_mp_mod*/);
/* S2 = Y2*Z1^3 */
sp_256to512z_mont_mul_8(t4, t4, q->y /*, p256_mod, p256_mp_mod*/);
sp_256_mont_mul_8(t4, t4, q->y /*, p256_mod, p256_mp_mod*/);
/* H = U2 - U1 */
sp_256_mont_sub_8(t2, t2, t1 /*, p256_mod*/);
/* R = S2 - S1 */
sp_256_mont_sub_8(t4, t4, t3 /*, p256_mod*/);
/* Z3 = H*Z1*Z2 */
sp_256to512z_mont_mul_8(r->z, r->z, q->z /*, p256_mod, p256_mp_mod*/);
sp_256to512z_mont_mul_8(r->z, r->z, t2 /*, p256_mod, p256_mp_mod*/);
sp_256_mont_mul_8(r->z, r->z, q->z /*, p256_mod, p256_mp_mod*/);
sp_256_mont_mul_8(r->z, r->z, t2 /*, p256_mod, p256_mp_mod*/);
/* X3 = R^2 - H^3 - 2*U1*H^2 */
sp_256to512z_mont_sqr_8(r->x, t4 /*, p256_mod, p256_mp_mod*/);
sp_256to512z_mont_sqr_8(t5, t2 /*, p256_mod, p256_mp_mod*/);
sp_256to512z_mont_mul_8(r->y, t1, t5 /*, p256_mod, p256_mp_mod*/);
sp_256to512z_mont_mul_8(t5, t5, t2 /*, p256_mod, p256_mp_mod*/);
sp_256_mont_sqr_8(r->x, t4 /*, p256_mod, p256_mp_mod*/);
sp_256_mont_sqr_8(t5, t2 /*, p256_mod, p256_mp_mod*/);
sp_256_mont_mul_8(r->y, t1, t5 /*, p256_mod, p256_mp_mod*/);
sp_256_mont_mul_8(t5, t5, t2 /*, p256_mod, p256_mp_mod*/);
sp_256_mont_sub_8(r->x, r->x, t5 /*, p256_mod*/);
sp_256_mont_dbl_8(t1, r->y /*, p256_mod*/);
sp_256_mont_sub_8(r->x, r->x, t1 /*, p256_mod*/);
/* Y3 = R*(U1*H^2 - X3) - S1*H^3 */
sp_256_mont_sub_8(r->y, r->y, r->x /*, p256_mod*/);
sp_256to512z_mont_mul_8(r->y, r->y, t4 /*, p256_mod, p256_mp_mod*/);
sp_256to512z_mont_mul_8(t5, t5, t3 /*, p256_mod, p256_mp_mod*/);
sp_256_mont_mul_8(r->y, r->y, t4 /*, p256_mod, p256_mp_mod*/);
sp_256_mont_mul_8(t5, t5, t3 /*, p256_mod, p256_mp_mod*/);
sp_256_mont_sub_8(r->y, r->y, t5 /*, p256_mod*/);
}