blob: 414c78a9c2143b700b7200dbf3216d285b93620c
1 | /* |
2 | * Copyright (c) 2013, Kenneth MacKay |
3 | * All rights reserved. |
4 | * |
5 | * Redistribution and use in source and binary forms, with or without |
6 | * modification, are permitted provided that the following conditions are |
7 | * met: |
8 | * * Redistributions of source code must retain the above copyright |
9 | * notice, this list of conditions and the following disclaimer. |
10 | * * Redistributions in binary form must reproduce the above copyright |
11 | * notice, this list of conditions and the following disclaimer in the |
12 | * documentation and/or other materials provided with the distribution. |
13 | * |
14 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS |
15 | * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT |
16 | * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR |
17 | * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT |
18 | * HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, |
19 | * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT |
20 | * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, |
21 | * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY |
22 | * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT |
23 | * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE |
24 | * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. |
25 | */ |
26 | |
#include <linux/fips.h>
#include <linux/random.h>
#include <linux/slab.h>
#include <linux/string.h>
#include <linux/swab.h>
#include <crypto/ecdh.h>

#include "ecc.h"
#include "ecc_curve_defs.h"
35 | |
/* 128-bit unsigned value kept as two 64-bit halves. It carries the
 * double-width products and running sums of the multi-digit
 * multiply and square routines.
 */
typedef struct {
	u64 m_low;	/* least significant 64 bits */
	u64 m_high;	/* most significant 64 bits */
} uint128_t;
40 | |
41 | static inline const struct ecc_curve *ecc_get_curve(unsigned int curve_id) |
42 | { |
43 | switch (curve_id) { |
44 | /* In FIPS mode only allow P256 and higher */ |
45 | case ECC_CURVE_NIST_P192: |
46 | return fips_enabled ? NULL : &nist_p192; |
47 | case ECC_CURVE_NIST_P256: |
48 | return &nist_p256; |
49 | default: |
50 | return NULL; |
51 | } |
52 | } |
53 | |
54 | static u64 *ecc_alloc_digits_space(unsigned int ndigits) |
55 | { |
56 | size_t len = ndigits * sizeof(u64); |
57 | |
58 | if (!len) |
59 | return NULL; |
60 | |
61 | return kmalloc(len, GFP_KERNEL); |
62 | } |
63 | |
64 | static void ecc_free_digits_space(u64 *space) |
65 | { |
66 | kzfree(space); |
67 | } |
68 | |
69 | static struct ecc_point *ecc_alloc_point(unsigned int ndigits) |
70 | { |
71 | struct ecc_point *p = kmalloc(sizeof(*p), GFP_KERNEL); |
72 | |
73 | if (!p) |
74 | return NULL; |
75 | |
76 | p->x = ecc_alloc_digits_space(ndigits); |
77 | if (!p->x) |
78 | goto err_alloc_x; |
79 | |
80 | p->y = ecc_alloc_digits_space(ndigits); |
81 | if (!p->y) |
82 | goto err_alloc_y; |
83 | |
84 | p->ndigits = ndigits; |
85 | |
86 | return p; |
87 | |
88 | err_alloc_y: |
89 | ecc_free_digits_space(p->x); |
90 | err_alloc_x: |
91 | kfree(p); |
92 | return NULL; |
93 | } |
94 | |
95 | static void ecc_free_point(struct ecc_point *p) |
96 | { |
97 | if (!p) |
98 | return; |
99 | |
100 | kzfree(p->x); |
101 | kzfree(p->y); |
102 | kzfree(p); |
103 | } |
104 | |
/* Sets the first ndigits digits of vli to zero. */
static void vli_clear(u64 *vli, unsigned int ndigits)
{
	u64 *end = vli + ndigits;

	while (vli < end)
		*vli++ = 0;
}
112 | |
/* Returns true when all ndigits digits of vli are zero. */
static bool vli_is_zero(const u64 *vli, unsigned int ndigits)
{
	unsigned int idx = 0;

	/* Stop at the first non-zero digit. */
	while (idx < ndigits && !vli[idx])
		idx++;

	return idx == ndigits;
}
125 | |
/* Returns a nonzero value (the isolated bit, in place) if bit number
 * `bit` of vli is set, zero otherwise. Bit 0 is the least significant
 * bit of vli[0].
 */
static u64 vli_test_bit(const u64 *vli, unsigned int bit)
{
	const u64 mask = (u64)1 << (bit & 63);

	return vli[bit >> 6] & mask;
}
131 | |
/* Returns how many digits of vli are in use, i.e. the index of the
 * highest non-zero digit plus one (0 when vli is zero).
 */
static unsigned int vli_num_digits(const u64 *vli, unsigned int ndigits)
{
	unsigned int count = ndigits;

	/* Walk down from the top: the upper digits are the ones expected
	 * to be non-zero, so this usually stops immediately.
	 */
	while (count > 0 && !vli[count - 1])
		count--;

	return count;
}
145 | |
146 | /* Counts the number of bits required for vli. */ |
147 | static unsigned int vli_num_bits(const u64 *vli, unsigned int ndigits) |
148 | { |
149 | unsigned int i, num_digits; |
150 | u64 digit; |
151 | |
152 | num_digits = vli_num_digits(vli, ndigits); |
153 | if (num_digits == 0) |
154 | return 0; |
155 | |
156 | digit = vli[num_digits - 1]; |
157 | for (i = 0; digit; i++) |
158 | digit >>= 1; |
159 | |
160 | return ((num_digits - 1) * 64 + i); |
161 | } |
162 | |
/* Copies ndigits digits from src to dest, lowest digit first. */
static void vli_set(u64 *dest, const u64 *src, unsigned int ndigits)
{
	unsigned int left;

	for (left = ndigits; left; left--)
		*dest++ = *src++;
}
171 | |
/* Compares two ndigits-wide values.
 * Returns 1 if left > right, -1 if left < right and 0 if they are equal.
 */
static int vli_cmp(const u64 *left, const u64 *right, unsigned int ndigits)
{
	unsigned int idx = ndigits;

	/* Scan from the most significant digit; the first difference
	 * decides the result.
	 */
	while (idx--) {
		if (left[idx] != right[idx])
			return left[idx] > right[idx] ? 1 : -1;
	}

	return 0;
}
186 | |
/* Computes result = in << shift and returns the bits shifted out of the
 * top digit. result may be the same buffer as in.
 * The caller must keep 0 < shift < 64.
 */
static u64 vli_lshift(u64 *result, const u64 *in, unsigned int shift,
		      unsigned int ndigits)
{
	const unsigned int spill = 64 - shift;
	u64 carry_in = 0;
	unsigned int idx;

	for (idx = 0; idx < ndigits; idx++) {
		u64 digit = in[idx];

		result[idx] = (digit << shift) | carry_in;
		/* Bits leaving this digit enter the bottom of the next one. */
		carry_in = digit >> spill;
	}

	return carry_in;
}
205 | |
/* Shifts vli right by one bit in place; the bit shifted out of the
 * lowest digit is discarded.
 */
static void vli_rshift1(u64 *vli, unsigned int ndigits)
{
	u64 carry_in = 0;
	unsigned int idx = ndigits;

	/* Work from the top digit down so each digit's low bit can be
	 * handed to the digit below it.
	 */
	while (idx--) {
		u64 digit = vli[idx];

		vli[idx] = (digit >> 1) | carry_in;
		carry_in = digit << 63;
	}
}
220 | |
/* Computes result = left + right over ndigits digits and returns the
 * carry out of the top digit (0 or 1). result may alias an operand.
 */
static u64 vli_add(u64 *result, const u64 *left, const u64 *right,
		   unsigned int ndigits)
{
	u64 carry = 0;
	unsigned int idx;

	for (idx = 0; idx < ndigits; idx++) {
		const u64 l = left[idx];
		const u64 sum = l + right[idx] + carry;

		/* sum == l means right + carry was 0 or wrapped to exactly
		 * 2^64; either way the incoming carry is passed through.
		 */
		if (sum > l)
			carry = 0;
		else if (sum < l)
			carry = 1;

		result[idx] = sum;
	}

	return carry;
}
240 | |
/* Computes result = left - right over ndigits digits and returns the
 * borrow out of the top digit (0 or 1). result may alias an operand.
 */
static u64 vli_sub(u64 *result, const u64 *left, const u64 *right,
		   unsigned int ndigits)
{
	u64 borrow = 0;
	unsigned int idx;

	for (idx = 0; idx < ndigits; idx++) {
		const u64 l = left[idx];
		const u64 diff = l - right[idx] - borrow;

		/* diff == l means right + borrow was 0 or wrapped to exactly
		 * 2^64; either way the incoming borrow is passed through.
		 */
		if (diff < l)
			borrow = 0;
		else if (diff > l)
			borrow = 1;

		result[idx] = diff;
	}

	return borrow;
}
260 | |
261 | static uint128_t mul_64_64(u64 left, u64 right) |
262 | { |
263 | u64 a0 = left & 0xffffffffull; |
264 | u64 a1 = left >> 32; |
265 | u64 b0 = right & 0xffffffffull; |
266 | u64 b1 = right >> 32; |
267 | u64 m0 = a0 * b0; |
268 | u64 m1 = a0 * b1; |
269 | u64 m2 = a1 * b0; |
270 | u64 m3 = a1 * b1; |
271 | uint128_t result; |
272 | |
273 | m2 += (m0 >> 32); |
274 | m2 += m1; |
275 | |
276 | /* Overflow */ |
277 | if (m2 < m1) |
278 | m3 += 0x100000000ull; |
279 | |
280 | result.m_low = (m0 & 0xffffffffull) | (m2 << 32); |
281 | result.m_high = m3 + (m2 >> 32); |
282 | |
283 | return result; |
284 | } |
285 | |
286 | static uint128_t add_128_128(uint128_t a, uint128_t b) |
287 | { |
288 | uint128_t result; |
289 | |
290 | result.m_low = a.m_low + b.m_low; |
291 | result.m_high = a.m_high + b.m_high + (result.m_low < a.m_low); |
292 | |
293 | return result; |
294 | } |
295 | |
296 | static void vli_mult(u64 *result, const u64 *left, const u64 *right, |
297 | unsigned int ndigits) |
298 | { |
299 | uint128_t r01 = { 0, 0 }; |
300 | u64 r2 = 0; |
301 | unsigned int i, k; |
302 | |
303 | /* Compute each digit of result in sequence, maintaining the |
304 | * carries. |
305 | */ |
306 | for (k = 0; k < ndigits * 2 - 1; k++) { |
307 | unsigned int min; |
308 | |
309 | if (k < ndigits) |
310 | min = 0; |
311 | else |
312 | min = (k + 1) - ndigits; |
313 | |
314 | for (i = min; i <= k && i < ndigits; i++) { |
315 | uint128_t product; |
316 | |
317 | product = mul_64_64(left[i], right[k - i]); |
318 | |
319 | r01 = add_128_128(r01, product); |
320 | r2 += (r01.m_high < product.m_high); |
321 | } |
322 | |
323 | result[k] = r01.m_low; |
324 | r01.m_low = r01.m_high; |
325 | r01.m_high = r2; |
326 | r2 = 0; |
327 | } |
328 | |
329 | result[ndigits * 2 - 1] = r01.m_low; |
330 | } |
331 | |
332 | static void vli_square(u64 *result, const u64 *left, unsigned int ndigits) |
333 | { |
334 | uint128_t r01 = { 0, 0 }; |
335 | u64 r2 = 0; |
336 | int i, k; |
337 | |
338 | for (k = 0; k < ndigits * 2 - 1; k++) { |
339 | unsigned int min; |
340 | |
341 | if (k < ndigits) |
342 | min = 0; |
343 | else |
344 | min = (k + 1) - ndigits; |
345 | |
346 | for (i = min; i <= k && i <= k - i; i++) { |
347 | uint128_t product; |
348 | |
349 | product = mul_64_64(left[i], left[k - i]); |
350 | |
351 | if (i < k - i) { |
352 | r2 += product.m_high >> 63; |
353 | product.m_high = (product.m_high << 1) | |
354 | (product.m_low >> 63); |
355 | product.m_low <<= 1; |
356 | } |
357 | |
358 | r01 = add_128_128(r01, product); |
359 | r2 += (r01.m_high < product.m_high); |
360 | } |
361 | |
362 | result[k] = r01.m_low; |
363 | r01.m_low = r01.m_high; |
364 | r01.m_high = r2; |
365 | r2 = 0; |
366 | } |
367 | |
368 | result[ndigits * 2 - 1] = r01.m_low; |
369 | } |
370 | |
371 | /* Computes result = (left + right) % mod. |
372 | * Assumes that left < mod and right < mod, result != mod. |
373 | */ |
374 | static void vli_mod_add(u64 *result, const u64 *left, const u64 *right, |
375 | const u64 *mod, unsigned int ndigits) |
376 | { |
377 | u64 carry; |
378 | |
379 | carry = vli_add(result, left, right, ndigits); |
380 | |
381 | /* result > mod (result = mod + remainder), so subtract mod to |
382 | * get remainder. |
383 | */ |
384 | if (carry || vli_cmp(result, mod, ndigits) >= 0) |
385 | vli_sub(result, result, mod, ndigits); |
386 | } |
387 | |
388 | /* Computes result = (left - right) % mod. |
389 | * Assumes that left < mod and right < mod, result != mod. |
390 | */ |
391 | static void vli_mod_sub(u64 *result, const u64 *left, const u64 *right, |
392 | const u64 *mod, unsigned int ndigits) |
393 | { |
394 | u64 borrow = vli_sub(result, left, right, ndigits); |
395 | |
396 | /* In this case, p_result == -diff == (max int) - diff. |
397 | * Since -x % d == d - x, we can get the correct result from |
398 | * result + mod (with overflow). |
399 | */ |
400 | if (borrow) |
401 | vli_add(result, result, mod, ndigits); |
402 | } |
403 | |
404 | /* Computes p_result = p_product % curve_p. |
405 | * See algorithm 5 and 6 from |
406 | * http://www.isys.uni-klu.ac.at/PDF/2001-0126-MT.pdf |
407 | */ |
408 | static void vli_mmod_fast_192(u64 *result, const u64 *product, |
409 | const u64 *curve_prime, u64 *tmp) |
410 | { |
411 | const unsigned int ndigits = 3; |
412 | int carry; |
413 | |
414 | vli_set(result, product, ndigits); |
415 | |
416 | vli_set(tmp, &product[3], ndigits); |
417 | carry = vli_add(result, result, tmp, ndigits); |
418 | |
419 | tmp[0] = 0; |
420 | tmp[1] = product[3]; |
421 | tmp[2] = product[4]; |
422 | carry += vli_add(result, result, tmp, ndigits); |
423 | |
424 | tmp[0] = tmp[1] = product[5]; |
425 | tmp[2] = 0; |
426 | carry += vli_add(result, result, tmp, ndigits); |
427 | |
428 | while (carry || vli_cmp(curve_prime, result, ndigits) != 1) |
429 | carry -= vli_sub(result, result, curve_prime, ndigits); |
430 | } |
431 | |
432 | /* Computes result = product % curve_prime |
433 | * from http://www.nsa.gov/ia/_files/nist-routines.pdf |
434 | */ |
435 | static void vli_mmod_fast_256(u64 *result, const u64 *product, |
436 | const u64 *curve_prime, u64 *tmp) |
437 | { |
438 | int carry; |
439 | const unsigned int ndigits = 4; |
440 | |
441 | /* t */ |
442 | vli_set(result, product, ndigits); |
443 | |
444 | /* s1 */ |
445 | tmp[0] = 0; |
446 | tmp[1] = product[5] & 0xffffffff00000000ull; |
447 | tmp[2] = product[6]; |
448 | tmp[3] = product[7]; |
449 | carry = vli_lshift(tmp, tmp, 1, ndigits); |
450 | carry += vli_add(result, result, tmp, ndigits); |
451 | |
452 | /* s2 */ |
453 | tmp[1] = product[6] << 32; |
454 | tmp[2] = (product[6] >> 32) | (product[7] << 32); |
455 | tmp[3] = product[7] >> 32; |
456 | carry += vli_lshift(tmp, tmp, 1, ndigits); |
457 | carry += vli_add(result, result, tmp, ndigits); |
458 | |
459 | /* s3 */ |
460 | tmp[0] = product[4]; |
461 | tmp[1] = product[5] & 0xffffffff; |
462 | tmp[2] = 0; |
463 | tmp[3] = product[7]; |
464 | carry += vli_add(result, result, tmp, ndigits); |
465 | |
466 | /* s4 */ |
467 | tmp[0] = (product[4] >> 32) | (product[5] << 32); |
468 | tmp[1] = (product[5] >> 32) | (product[6] & 0xffffffff00000000ull); |
469 | tmp[2] = product[7]; |
470 | tmp[3] = (product[6] >> 32) | (product[4] << 32); |
471 | carry += vli_add(result, result, tmp, ndigits); |
472 | |
473 | /* d1 */ |
474 | tmp[0] = (product[5] >> 32) | (product[6] << 32); |
475 | tmp[1] = (product[6] >> 32); |
476 | tmp[2] = 0; |
477 | tmp[3] = (product[4] & 0xffffffff) | (product[5] << 32); |
478 | carry -= vli_sub(result, result, tmp, ndigits); |
479 | |
480 | /* d2 */ |
481 | tmp[0] = product[6]; |
482 | tmp[1] = product[7]; |
483 | tmp[2] = 0; |
484 | tmp[3] = (product[4] >> 32) | (product[5] & 0xffffffff00000000ull); |
485 | carry -= vli_sub(result, result, tmp, ndigits); |
486 | |
487 | /* d3 */ |
488 | tmp[0] = (product[6] >> 32) | (product[7] << 32); |
489 | tmp[1] = (product[7] >> 32) | (product[4] << 32); |
490 | tmp[2] = (product[4] >> 32) | (product[5] << 32); |
491 | tmp[3] = (product[6] << 32); |
492 | carry -= vli_sub(result, result, tmp, ndigits); |
493 | |
494 | /* d4 */ |
495 | tmp[0] = product[7]; |
496 | tmp[1] = product[4] & 0xffffffff00000000ull; |
497 | tmp[2] = product[5]; |
498 | tmp[3] = product[6] & 0xffffffff00000000ull; |
499 | carry -= vli_sub(result, result, tmp, ndigits); |
500 | |
501 | if (carry < 0) { |
502 | do { |
503 | carry += vli_add(result, result, curve_prime, ndigits); |
504 | } while (carry < 0); |
505 | } else { |
506 | while (carry || vli_cmp(curve_prime, result, ndigits) != 1) |
507 | carry -= vli_sub(result, result, curve_prime, ndigits); |
508 | } |
509 | } |
510 | |
511 | /* Computes result = product % curve_prime |
512 | * from http://www.nsa.gov/ia/_files/nist-routines.pdf |
513 | */ |
514 | static bool vli_mmod_fast(u64 *result, u64 *product, |
515 | const u64 *curve_prime, unsigned int ndigits) |
516 | { |
517 | u64 tmp[2 * ndigits]; |
518 | |
519 | switch (ndigits) { |
520 | case 3: |
521 | vli_mmod_fast_192(result, product, curve_prime, tmp); |
522 | break; |
523 | case 4: |
524 | vli_mmod_fast_256(result, product, curve_prime, tmp); |
525 | break; |
526 | default: |
527 | pr_err("unsupports digits size!\n"); |
528 | return false; |
529 | } |
530 | |
531 | return true; |
532 | } |
533 | |
534 | /* Computes result = (left * right) % curve_prime. */ |
535 | static void vli_mod_mult_fast(u64 *result, const u64 *left, const u64 *right, |
536 | const u64 *curve_prime, unsigned int ndigits) |
537 | { |
538 | u64 product[2 * ndigits]; |
539 | |
540 | vli_mult(product, left, right, ndigits); |
541 | vli_mmod_fast(result, product, curve_prime, ndigits); |
542 | } |
543 | |
544 | /* Computes result = left^2 % curve_prime. */ |
545 | static void vli_mod_square_fast(u64 *result, const u64 *left, |
546 | const u64 *curve_prime, unsigned int ndigits) |
547 | { |
548 | u64 product[2 * ndigits]; |
549 | |
550 | vli_square(product, left, ndigits); |
551 | vli_mmod_fast(result, product, curve_prime, ndigits); |
552 | } |
553 | |
554 | #define EVEN(vli) (!(vli[0] & 1)) |
555 | /* Computes result = (1 / p_input) % mod. All VLIs are the same size. |
556 | * See "From Euclid's GCD to Montgomery Multiplication to the Great Divide" |
557 | * https://labs.oracle.com/techrep/2001/smli_tr-2001-95.pdf |
558 | */ |
559 | static void vli_mod_inv(u64 *result, const u64 *input, const u64 *mod, |
560 | unsigned int ndigits) |
561 | { |
562 | u64 a[ndigits], b[ndigits]; |
563 | u64 u[ndigits], v[ndigits]; |
564 | u64 carry; |
565 | int cmp_result; |
566 | |
567 | if (vli_is_zero(input, ndigits)) { |
568 | vli_clear(result, ndigits); |
569 | return; |
570 | } |
571 | |
572 | vli_set(a, input, ndigits); |
573 | vli_set(b, mod, ndigits); |
574 | vli_clear(u, ndigits); |
575 | u[0] = 1; |
576 | vli_clear(v, ndigits); |
577 | |
578 | while ((cmp_result = vli_cmp(a, b, ndigits)) != 0) { |
579 | carry = 0; |
580 | |
581 | if (EVEN(a)) { |
582 | vli_rshift1(a, ndigits); |
583 | |
584 | if (!EVEN(u)) |
585 | carry = vli_add(u, u, mod, ndigits); |
586 | |
587 | vli_rshift1(u, ndigits); |
588 | if (carry) |
589 | u[ndigits - 1] |= 0x8000000000000000ull; |
590 | } else if (EVEN(b)) { |
591 | vli_rshift1(b, ndigits); |
592 | |
593 | if (!EVEN(v)) |
594 | carry = vli_add(v, v, mod, ndigits); |
595 | |
596 | vli_rshift1(v, ndigits); |
597 | if (carry) |
598 | v[ndigits - 1] |= 0x8000000000000000ull; |
599 | } else if (cmp_result > 0) { |
600 | vli_sub(a, a, b, ndigits); |
601 | vli_rshift1(a, ndigits); |
602 | |
603 | if (vli_cmp(u, v, ndigits) < 0) |
604 | vli_add(u, u, mod, ndigits); |
605 | |
606 | vli_sub(u, u, v, ndigits); |
607 | if (!EVEN(u)) |
608 | carry = vli_add(u, u, mod, ndigits); |
609 | |
610 | vli_rshift1(u, ndigits); |
611 | if (carry) |
612 | u[ndigits - 1] |= 0x8000000000000000ull; |
613 | } else { |
614 | vli_sub(b, b, a, ndigits); |
615 | vli_rshift1(b, ndigits); |
616 | |
617 | if (vli_cmp(v, u, ndigits) < 0) |
618 | vli_add(v, v, mod, ndigits); |
619 | |
620 | vli_sub(v, v, u, ndigits); |
621 | if (!EVEN(v)) |
622 | carry = vli_add(v, v, mod, ndigits); |
623 | |
624 | vli_rshift1(v, ndigits); |
625 | if (carry) |
626 | v[ndigits - 1] |= 0x8000000000000000ull; |
627 | } |
628 | } |
629 | |
630 | vli_set(result, u, ndigits); |
631 | } |
632 | |
633 | /* ------ Point operations ------ */ |
634 | |
635 | /* Returns true if p_point is the point at infinity, false otherwise. */ |
636 | static bool ecc_point_is_zero(const struct ecc_point *point) |
637 | { |
638 | return (vli_is_zero(point->x, point->ndigits) && |
639 | vli_is_zero(point->y, point->ndigits)); |
640 | } |
641 | |
642 | /* Point multiplication algorithm using Montgomery's ladder with co-Z |
643 | * coordinates. From http://eprint.iacr.org/2011/338.pdf |
644 | */ |
645 | |
646 | /* Double in place */ |
647 | static void ecc_point_double_jacobian(u64 *x1, u64 *y1, u64 *z1, |
648 | u64 *curve_prime, unsigned int ndigits) |
649 | { |
650 | /* t1 = x, t2 = y, t3 = z */ |
651 | u64 t4[ndigits]; |
652 | u64 t5[ndigits]; |
653 | |
654 | if (vli_is_zero(z1, ndigits)) |
655 | return; |
656 | |
657 | /* t4 = y1^2 */ |
658 | vli_mod_square_fast(t4, y1, curve_prime, ndigits); |
659 | /* t5 = x1*y1^2 = A */ |
660 | vli_mod_mult_fast(t5, x1, t4, curve_prime, ndigits); |
661 | /* t4 = y1^4 */ |
662 | vli_mod_square_fast(t4, t4, curve_prime, ndigits); |
663 | /* t2 = y1*z1 = z3 */ |
664 | vli_mod_mult_fast(y1, y1, z1, curve_prime, ndigits); |
665 | /* t3 = z1^2 */ |
666 | vli_mod_square_fast(z1, z1, curve_prime, ndigits); |
667 | |
668 | /* t1 = x1 + z1^2 */ |
669 | vli_mod_add(x1, x1, z1, curve_prime, ndigits); |
670 | /* t3 = 2*z1^2 */ |
671 | vli_mod_add(z1, z1, z1, curve_prime, ndigits); |
672 | /* t3 = x1 - z1^2 */ |
673 | vli_mod_sub(z1, x1, z1, curve_prime, ndigits); |
674 | /* t1 = x1^2 - z1^4 */ |
675 | vli_mod_mult_fast(x1, x1, z1, curve_prime, ndigits); |
676 | |
677 | /* t3 = 2*(x1^2 - z1^4) */ |
678 | vli_mod_add(z1, x1, x1, curve_prime, ndigits); |
679 | /* t1 = 3*(x1^2 - z1^4) */ |
680 | vli_mod_add(x1, x1, z1, curve_prime, ndigits); |
681 | if (vli_test_bit(x1, 0)) { |
682 | u64 carry = vli_add(x1, x1, curve_prime, ndigits); |
683 | |
684 | vli_rshift1(x1, ndigits); |
685 | x1[ndigits - 1] |= carry << 63; |
686 | } else { |
687 | vli_rshift1(x1, ndigits); |
688 | } |
689 | /* t1 = 3/2*(x1^2 - z1^4) = B */ |
690 | |
691 | /* t3 = B^2 */ |
692 | vli_mod_square_fast(z1, x1, curve_prime, ndigits); |
693 | /* t3 = B^2 - A */ |
694 | vli_mod_sub(z1, z1, t5, curve_prime, ndigits); |
695 | /* t3 = B^2 - 2A = x3 */ |
696 | vli_mod_sub(z1, z1, t5, curve_prime, ndigits); |
697 | /* t5 = A - x3 */ |
698 | vli_mod_sub(t5, t5, z1, curve_prime, ndigits); |
699 | /* t1 = B * (A - x3) */ |
700 | vli_mod_mult_fast(x1, x1, t5, curve_prime, ndigits); |
701 | /* t4 = B * (A - x3) - y1^4 = y3 */ |
702 | vli_mod_sub(t4, x1, t4, curve_prime, ndigits); |
703 | |
704 | vli_set(x1, z1, ndigits); |
705 | vli_set(z1, y1, ndigits); |
706 | vli_set(y1, t4, ndigits); |
707 | } |
708 | |
709 | /* Modify (x1, y1) => (x1 * z^2, y1 * z^3) */ |
710 | static void apply_z(u64 *x1, u64 *y1, u64 *z, u64 *curve_prime, |
711 | unsigned int ndigits) |
712 | { |
713 | u64 t1[ndigits]; |
714 | |
715 | vli_mod_square_fast(t1, z, curve_prime, ndigits); /* z^2 */ |
716 | vli_mod_mult_fast(x1, x1, t1, curve_prime, ndigits); /* x1 * z^2 */ |
717 | vli_mod_mult_fast(t1, t1, z, curve_prime, ndigits); /* z^3 */ |
718 | vli_mod_mult_fast(y1, y1, t1, curve_prime, ndigits); /* y1 * z^3 */ |
719 | } |
720 | |
721 | /* P = (x1, y1) => 2P, (x2, y2) => P' */ |
722 | static void xycz_initial_double(u64 *x1, u64 *y1, u64 *x2, u64 *y2, |
723 | u64 *p_initial_z, u64 *curve_prime, |
724 | unsigned int ndigits) |
725 | { |
726 | u64 z[ndigits]; |
727 | |
728 | vli_set(x2, x1, ndigits); |
729 | vli_set(y2, y1, ndigits); |
730 | |
731 | vli_clear(z, ndigits); |
732 | z[0] = 1; |
733 | |
734 | if (p_initial_z) |
735 | vli_set(z, p_initial_z, ndigits); |
736 | |
737 | apply_z(x1, y1, z, curve_prime, ndigits); |
738 | |
739 | ecc_point_double_jacobian(x1, y1, z, curve_prime, ndigits); |
740 | |
741 | apply_z(x2, y2, z, curve_prime, ndigits); |
742 | } |
743 | |
744 | /* Input P = (x1, y1, Z), Q = (x2, y2, Z) |
745 | * Output P' = (x1', y1', Z3), P + Q = (x3, y3, Z3) |
746 | * or P => P', Q => P + Q |
747 | */ |
748 | static void xycz_add(u64 *x1, u64 *y1, u64 *x2, u64 *y2, u64 *curve_prime, |
749 | unsigned int ndigits) |
750 | { |
751 | /* t1 = X1, t2 = Y1, t3 = X2, t4 = Y2 */ |
752 | u64 t5[ndigits]; |
753 | |
754 | /* t5 = x2 - x1 */ |
755 | vli_mod_sub(t5, x2, x1, curve_prime, ndigits); |
756 | /* t5 = (x2 - x1)^2 = A */ |
757 | vli_mod_square_fast(t5, t5, curve_prime, ndigits); |
758 | /* t1 = x1*A = B */ |
759 | vli_mod_mult_fast(x1, x1, t5, curve_prime, ndigits); |
760 | /* t3 = x2*A = C */ |
761 | vli_mod_mult_fast(x2, x2, t5, curve_prime, ndigits); |
762 | /* t4 = y2 - y1 */ |
763 | vli_mod_sub(y2, y2, y1, curve_prime, ndigits); |
764 | /* t5 = (y2 - y1)^2 = D */ |
765 | vli_mod_square_fast(t5, y2, curve_prime, ndigits); |
766 | |
767 | /* t5 = D - B */ |
768 | vli_mod_sub(t5, t5, x1, curve_prime, ndigits); |
769 | /* t5 = D - B - C = x3 */ |
770 | vli_mod_sub(t5, t5, x2, curve_prime, ndigits); |
771 | /* t3 = C - B */ |
772 | vli_mod_sub(x2, x2, x1, curve_prime, ndigits); |
773 | /* t2 = y1*(C - B) */ |
774 | vli_mod_mult_fast(y1, y1, x2, curve_prime, ndigits); |
775 | /* t3 = B - x3 */ |
776 | vli_mod_sub(x2, x1, t5, curve_prime, ndigits); |
777 | /* t4 = (y2 - y1)*(B - x3) */ |
778 | vli_mod_mult_fast(y2, y2, x2, curve_prime, ndigits); |
779 | /* t4 = y3 */ |
780 | vli_mod_sub(y2, y2, y1, curve_prime, ndigits); |
781 | |
782 | vli_set(x2, t5, ndigits); |
783 | } |
784 | |
785 | /* Input P = (x1, y1, Z), Q = (x2, y2, Z) |
786 | * Output P + Q = (x3, y3, Z3), P - Q = (x3', y3', Z3) |
787 | * or P => P - Q, Q => P + Q |
788 | */ |
789 | static void xycz_add_c(u64 *x1, u64 *y1, u64 *x2, u64 *y2, u64 *curve_prime, |
790 | unsigned int ndigits) |
791 | { |
792 | /* t1 = X1, t2 = Y1, t3 = X2, t4 = Y2 */ |
793 | u64 t5[ndigits]; |
794 | u64 t6[ndigits]; |
795 | u64 t7[ndigits]; |
796 | |
797 | /* t5 = x2 - x1 */ |
798 | vli_mod_sub(t5, x2, x1, curve_prime, ndigits); |
799 | /* t5 = (x2 - x1)^2 = A */ |
800 | vli_mod_square_fast(t5, t5, curve_prime, ndigits); |
801 | /* t1 = x1*A = B */ |
802 | vli_mod_mult_fast(x1, x1, t5, curve_prime, ndigits); |
803 | /* t3 = x2*A = C */ |
804 | vli_mod_mult_fast(x2, x2, t5, curve_prime, ndigits); |
805 | /* t4 = y2 + y1 */ |
806 | vli_mod_add(t5, y2, y1, curve_prime, ndigits); |
807 | /* t4 = y2 - y1 */ |
808 | vli_mod_sub(y2, y2, y1, curve_prime, ndigits); |
809 | |
810 | /* t6 = C - B */ |
811 | vli_mod_sub(t6, x2, x1, curve_prime, ndigits); |
812 | /* t2 = y1 * (C - B) */ |
813 | vli_mod_mult_fast(y1, y1, t6, curve_prime, ndigits); |
814 | /* t6 = B + C */ |
815 | vli_mod_add(t6, x1, x2, curve_prime, ndigits); |
816 | /* t3 = (y2 - y1)^2 */ |
817 | vli_mod_square_fast(x2, y2, curve_prime, ndigits); |
818 | /* t3 = x3 */ |
819 | vli_mod_sub(x2, x2, t6, curve_prime, ndigits); |
820 | |
821 | /* t7 = B - x3 */ |
822 | vli_mod_sub(t7, x1, x2, curve_prime, ndigits); |
823 | /* t4 = (y2 - y1)*(B - x3) */ |
824 | vli_mod_mult_fast(y2, y2, t7, curve_prime, ndigits); |
825 | /* t4 = y3 */ |
826 | vli_mod_sub(y2, y2, y1, curve_prime, ndigits); |
827 | |
828 | /* t7 = (y2 + y1)^2 = F */ |
829 | vli_mod_square_fast(t7, t5, curve_prime, ndigits); |
830 | /* t7 = x3' */ |
831 | vli_mod_sub(t7, t7, t6, curve_prime, ndigits); |
832 | /* t6 = x3' - B */ |
833 | vli_mod_sub(t6, t7, x1, curve_prime, ndigits); |
834 | /* t6 = (y2 + y1)*(x3' - B) */ |
835 | vli_mod_mult_fast(t6, t6, t5, curve_prime, ndigits); |
836 | /* t2 = y3' */ |
837 | vli_mod_sub(y1, t6, y1, curve_prime, ndigits); |
838 | |
839 | vli_set(x1, t7, ndigits); |
840 | } |
841 | |
842 | static void ecc_point_mult(struct ecc_point *result, |
843 | const struct ecc_point *point, const u64 *scalar, |
844 | u64 *initial_z, u64 *curve_prime, |
845 | unsigned int ndigits) |
846 | { |
847 | /* R0 and R1 */ |
848 | u64 rx[2][ndigits]; |
849 | u64 ry[2][ndigits]; |
850 | u64 z[ndigits]; |
851 | int i, nb; |
852 | int num_bits = vli_num_bits(scalar, ndigits); |
853 | |
854 | vli_set(rx[1], point->x, ndigits); |
855 | vli_set(ry[1], point->y, ndigits); |
856 | |
857 | xycz_initial_double(rx[1], ry[1], rx[0], ry[0], initial_z, curve_prime, |
858 | ndigits); |
859 | |
860 | for (i = num_bits - 2; i > 0; i--) { |
861 | nb = !vli_test_bit(scalar, i); |
862 | xycz_add_c(rx[1 - nb], ry[1 - nb], rx[nb], ry[nb], curve_prime, |
863 | ndigits); |
864 | xycz_add(rx[nb], ry[nb], rx[1 - nb], ry[1 - nb], curve_prime, |
865 | ndigits); |
866 | } |
867 | |
868 | nb = !vli_test_bit(scalar, 0); |
869 | xycz_add_c(rx[1 - nb], ry[1 - nb], rx[nb], ry[nb], curve_prime, |
870 | ndigits); |
871 | |
872 | /* Find final 1/Z value. */ |
873 | /* X1 - X0 */ |
874 | vli_mod_sub(z, rx[1], rx[0], curve_prime, ndigits); |
875 | /* Yb * (X1 - X0) */ |
876 | vli_mod_mult_fast(z, z, ry[1 - nb], curve_prime, ndigits); |
877 | /* xP * Yb * (X1 - X0) */ |
878 | vli_mod_mult_fast(z, z, point->x, curve_prime, ndigits); |
879 | |
880 | /* 1 / (xP * Yb * (X1 - X0)) */ |
881 | vli_mod_inv(z, z, curve_prime, point->ndigits); |
882 | |
883 | /* yP / (xP * Yb * (X1 - X0)) */ |
884 | vli_mod_mult_fast(z, z, point->y, curve_prime, ndigits); |
885 | /* Xb * yP / (xP * Yb * (X1 - X0)) */ |
886 | vli_mod_mult_fast(z, z, rx[1 - nb], curve_prime, ndigits); |
887 | /* End 1/Z calculation */ |
888 | |
889 | xycz_add(rx[nb], ry[nb], rx[1 - nb], ry[1 - nb], curve_prime, ndigits); |
890 | |
891 | apply_z(rx[0], ry[0], z, curve_prime, ndigits); |
892 | |
893 | vli_set(result->x, rx[0], ndigits); |
894 | vli_set(result->y, ry[0], ndigits); |
895 | } |
896 | |
897 | static inline void ecc_swap_digits(const u64 *in, u64 *out, |
898 | unsigned int ndigits) |
899 | { |
900 | int i; |
901 | |
902 | for (i = 0; i < ndigits; i++) |
903 | out[i] = __swab64(in[ndigits - 1 - i]); |
904 | } |
905 | |
906 | int ecc_is_key_valid(unsigned int curve_id, unsigned int ndigits, |
907 | const u8 *private_key, unsigned int private_key_len) |
908 | { |
909 | int nbytes; |
910 | const struct ecc_curve *curve = ecc_get_curve(curve_id); |
911 | |
912 | if (!private_key) |
913 | return -EINVAL; |
914 | |
915 | nbytes = ndigits << ECC_DIGITS_TO_BYTES_SHIFT; |
916 | |
917 | if (private_key_len != nbytes) |
918 | return -EINVAL; |
919 | |
920 | if (vli_is_zero((const u64 *)&private_key[0], ndigits)) |
921 | return -EINVAL; |
922 | |
923 | /* Make sure the private key is in the range [1, n-1]. */ |
924 | if (vli_cmp(curve->n, (const u64 *)&private_key[0], ndigits) != 1) |
925 | return -EINVAL; |
926 | |
927 | return 0; |
928 | } |
929 | |
930 | int ecdh_make_pub_key(unsigned int curve_id, unsigned int ndigits, |
931 | const u8 *private_key, unsigned int private_key_len, |
932 | u8 *public_key, unsigned int public_key_len) |
933 | { |
934 | int ret = 0; |
935 | struct ecc_point *pk; |
936 | u64 priv[ndigits]; |
937 | unsigned int nbytes; |
938 | const struct ecc_curve *curve = ecc_get_curve(curve_id); |
939 | |
940 | if (!private_key || !curve) { |
941 | ret = -EINVAL; |
942 | goto out; |
943 | } |
944 | |
945 | ecc_swap_digits((const u64 *)private_key, priv, ndigits); |
946 | |
947 | pk = ecc_alloc_point(ndigits); |
948 | if (!pk) { |
949 | ret = -ENOMEM; |
950 | goto out; |
951 | } |
952 | |
953 | ecc_point_mult(pk, &curve->g, priv, NULL, curve->p, ndigits); |
954 | if (ecc_point_is_zero(pk)) { |
955 | ret = -EAGAIN; |
956 | goto err_free_point; |
957 | } |
958 | |
959 | nbytes = ndigits << ECC_DIGITS_TO_BYTES_SHIFT; |
960 | ecc_swap_digits(pk->x, (u64 *)public_key, ndigits); |
961 | ecc_swap_digits(pk->y, (u64 *)&public_key[nbytes], ndigits); |
962 | |
963 | err_free_point: |
964 | ecc_free_point(pk); |
965 | out: |
966 | return ret; |
967 | } |
968 | |
969 | int crypto_ecdh_shared_secret(unsigned int curve_id, unsigned int ndigits, |
970 | const u8 *private_key, unsigned int private_key_len, |
971 | const u8 *public_key, unsigned int public_key_len, |
972 | u8 *secret, unsigned int secret_len) |
973 | { |
974 | int ret = 0; |
975 | struct ecc_point *product, *pk; |
976 | u64 priv[ndigits]; |
977 | u64 rand_z[ndigits]; |
978 | unsigned int nbytes; |
979 | const struct ecc_curve *curve = ecc_get_curve(curve_id); |
980 | |
981 | if (!private_key || !public_key || !curve) { |
982 | ret = -EINVAL; |
983 | goto out; |
984 | } |
985 | |
986 | nbytes = ndigits << ECC_DIGITS_TO_BYTES_SHIFT; |
987 | |
988 | get_random_bytes(rand_z, nbytes); |
989 | |
990 | pk = ecc_alloc_point(ndigits); |
991 | if (!pk) { |
992 | ret = -ENOMEM; |
993 | goto out; |
994 | } |
995 | |
996 | product = ecc_alloc_point(ndigits); |
997 | if (!product) { |
998 | ret = -ENOMEM; |
999 | goto err_alloc_product; |
1000 | } |
1001 | |
1002 | ecc_swap_digits((const u64 *)public_key, pk->x, ndigits); |
1003 | ecc_swap_digits((const u64 *)&public_key[nbytes], pk->y, ndigits); |
1004 | ecc_swap_digits((const u64 *)private_key, priv, ndigits); |
1005 | |
1006 | ecc_point_mult(product, pk, priv, rand_z, curve->p, ndigits); |
1007 | |
1008 | ecc_swap_digits(product->x, (u64 *)secret, ndigits); |
1009 | |
1010 | if (ecc_point_is_zero(product)) |
1011 | ret = -EFAULT; |
1012 | |
1013 | ecc_free_point(product); |
1014 | err_alloc_product: |
1015 | ecc_free_point(pk); |
1016 | out: |
1017 | return ret; |
1018 | } |
1019 |