P256.sol 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313
  1. // SPDX-License-Identifier: MIT
  2. pragma solidity ^0.8.20;
  3. import {Math} from "../math/Math.sol";
  4. import {Errors} from "../Errors.sol";
  5. /**
  6. * @dev Implementation of secp256r1 verification and recovery functions.
  7. *
  8. * The secp256r1 curve (also known as P256) is a NIST standard curve with wide support in modern devices
  9. * and cryptographic standards. Some notable examples include Apple's Secure Enclave and Android's Keystore
  10. * as well as authentication protocols like FIDO2.
  11. *
  12. * Based on the original https://github.com/itsobvioustech/aa-passkeys-wallet/blob/main/src/Secp256r1.sol[implementation of itsobvioustech].
  13. * Heavily inspired in https://github.com/maxrobot/elliptic-solidity/blob/master/contracts/Secp256r1.sol[maxrobot] and
  14. * https://github.com/tdrerup/elliptic-curve-solidity/blob/master/contracts/curves/EllipticCurve.sol[tdrerup] implementations.
  15. */
  16. library P256 {
  17. struct JPoint {
  18. uint256 x;
  19. uint256 y;
  20. uint256 z;
  21. }
  22. /// @dev Generator (x component)
  23. uint256 internal constant GX = 0x6B17D1F2E12C4247F8BCE6E563A440F277037D812DEB33A0F4A13945D898C296;
  24. /// @dev Generator (y component)
  25. uint256 internal constant GY = 0x4FE342E2FE1A7F9B8EE7EB4A7C0F9E162BCE33576B315ECECBB6406837BF51F5;
  26. /// @dev P (size of the field)
  27. uint256 internal constant P = 0xFFFFFFFF00000001000000000000000000000000FFFFFFFFFFFFFFFFFFFFFFFF;
  28. /// @dev N (order of G)
  29. uint256 internal constant N = 0xFFFFFFFF00000000FFFFFFFFFFFFFFFFBCE6FAADA7179E84F3B9CAC2FC632551;
  30. /// @dev A parameter of the weierstrass equation
  31. uint256 internal constant A = 0xFFFFFFFF00000001000000000000000000000000FFFFFFFFFFFFFFFFFFFFFFFC;
  32. /// @dev B parameter of the weierstrass equation
  33. uint256 internal constant B = 0x5AC635D8AA3A93E7B3EBBD55769886BC651D06B0CC53B0F63BCE3C3E27D2604B;
  34. /// @dev (P + 1) / 4. Useful to compute sqrt
  35. uint256 private constant P1DIV4 = 0x3fffffffc0000000400000000000000000000000400000000000000000000000;
  36. /// @dev N/2 for excluding higher order `s` values
  37. uint256 private constant HALF_N = 0x7fffffff800000007fffffffffffffffde737d56d38bcf4279dce5617e3192a8;
  38. /**
  39. * @dev Verifies a secp256r1 signature using the RIP-7212 precompile and falls back to the Solidity implementation
  40. * if the precompile is not available. This version should work on all chains, but requires the deployment of more
  41. * bytecode.
  42. *
  43. * @param h - hashed message
  44. * @param r - signature half R
  45. * @param s - signature half S
  46. * @param qx - public key coordinate X
  47. * @param qy - public key coordinate Y
  48. *
  49. * IMPORTANT: This function disallows signatures where the `s` value is above `N/2` to prevent malleability.
  50. * To flip the `s` value, compute `s = N - s`.
  51. */
  52. function verify(bytes32 h, bytes32 r, bytes32 s, bytes32 qx, bytes32 qy) internal view returns (bool) {
  53. (bool valid, bool supported) = _tryVerifyNative(h, r, s, qx, qy);
  54. return supported ? valid : verifySolidity(h, r, s, qx, qy);
  55. }
  56. /**
  57. * @dev Same as {verify}, but it will revert if the required precompile is not available.
  58. *
  59. * Make sure any logic (code or precompile) deployed at that address is the expected one,
  60. * otherwise the returned value may be misinterpreted as a positive boolean.
  61. */
  62. function verifyNative(bytes32 h, bytes32 r, bytes32 s, bytes32 qx, bytes32 qy) internal view returns (bool) {
  63. (bool valid, bool supported) = _tryVerifyNative(h, r, s, qx, qy);
  64. if (supported) {
  65. return valid;
  66. } else {
  67. revert Errors.MissingPrecompile(address(0x100));
  68. }
  69. }
  70. /**
  71. * @dev Same as {verify}, but it will return false if the required precompile is not available.
  72. */
  73. function _tryVerifyNative(
  74. bytes32 h,
  75. bytes32 r,
  76. bytes32 s,
  77. bytes32 qx,
  78. bytes32 qy
  79. ) private view returns (bool valid, bool supported) {
  80. if (!_isProperSignature(r, s) || !isValidPublicKey(qx, qy)) {
  81. return (false, true); // signature is invalid, and its not because the precompile is missing
  82. }
  83. (bool success, bytes memory returndata) = address(0x100).staticcall(abi.encode(h, r, s, qx, qy));
  84. return (success && returndata.length == 0x20) ? (abi.decode(returndata, (bool)), true) : (false, false);
  85. }
  86. /**
  87. * @dev Same as {verify}, but only the Solidity implementation is used.
  88. */
  89. function verifySolidity(bytes32 h, bytes32 r, bytes32 s, bytes32 qx, bytes32 qy) internal view returns (bool) {
  90. if (!_isProperSignature(r, s) || !isValidPublicKey(qx, qy)) {
  91. return false;
  92. }
  93. JPoint[16] memory points = _preComputeJacobianPoints(uint256(qx), uint256(qy));
  94. uint256 w = Math.invModPrime(uint256(s), N);
  95. uint256 u1 = mulmod(uint256(h), w, N);
  96. uint256 u2 = mulmod(uint256(r), w, N);
  97. (uint256 x, ) = _jMultShamir(points, u1, u2);
  98. return ((x % N) == uint256(r));
  99. }
  100. /**
  101. * @dev Public key recovery
  102. *
  103. * @param h - hashed message
  104. * @param v - signature recovery param
  105. * @param r - signature half R
  106. * @param s - signature half S
  107. *
  108. * IMPORTANT: This function disallows signatures where the `s` value is above `N/2` to prevent malleability.
  109. * To flip the `s` value, compute `s = N - s` and `v = 1 - v` if (`v = 0 | 1`).
  110. */
  111. function recovery(bytes32 h, uint8 v, bytes32 r, bytes32 s) internal view returns (bytes32, bytes32) {
  112. if (!_isProperSignature(r, s) || v > 1) {
  113. return (0, 0);
  114. }
  115. uint256 rx = uint256(r);
  116. uint256 ry2 = addmod(mulmod(addmod(mulmod(rx, rx, P), A, P), rx, P), B, P); // weierstrass equation y² = x³ + a.x + b
  117. uint256 ry = Math.modExp(ry2, P1DIV4, P); // This formula for sqrt work because P ≡ 3 (mod 4)
  118. if (mulmod(ry, ry, P) != ry2) return (0, 0); // Sanity check
  119. if (ry % 2 != v % 2) ry = P - ry;
  120. JPoint[16] memory points = _preComputeJacobianPoints(rx, ry);
  121. uint256 w = Math.invModPrime(uint256(r), N);
  122. uint256 u1 = mulmod(N - (uint256(h) % N), w, N);
  123. uint256 u2 = mulmod(uint256(s), w, N);
  124. (uint256 x, uint256 y) = _jMultShamir(points, u1, u2);
  125. return (bytes32(x), bytes32(y));
  126. }
  127. /**
  128. * @dev Checks if (x, y) are valid coordinates of a point on the curve.
  129. * In particular this function checks that x <= P and y <= P.
  130. */
  131. function isValidPublicKey(bytes32 x, bytes32 y) internal pure returns (bool result) {
  132. assembly ("memory-safe") {
  133. let lhs := mulmod(y, y, P) // y^2
  134. let rhs := addmod(mulmod(addmod(mulmod(x, x, P), A, P), x, P), B, P) // ((x^2 + a) * x) + b = x^3 + ax + b
  135. result := and(and(lt(x, P), lt(y, P)), eq(lhs, rhs)) // Should conform with the Weierstrass equation
  136. }
  137. }
  138. /**
  139. * @dev Checks if (r, s) is a proper signature.
  140. * In particular, this checks that `s` is in the "lower-range", making the signature non-malleable.
  141. */
  142. function _isProperSignature(bytes32 r, bytes32 s) private pure returns (bool) {
  143. return uint256(r) > 0 && uint256(r) < N && uint256(s) > 0 && uint256(s) <= HALF_N;
  144. }
  145. /**
  146. * @dev Reduce from jacobian to affine coordinates
  147. * @param jx - jacobian coordinate x
  148. * @param jy - jacobian coordinate y
  149. * @param jz - jacobian coordinate z
  150. * @return ax - affine coordinate x
  151. * @return ay - affine coordinate y
  152. */
  153. function _affineFromJacobian(uint256 jx, uint256 jy, uint256 jz) private view returns (uint256 ax, uint256 ay) {
  154. if (jz == 0) return (0, 0);
  155. uint256 zinv = Math.invModPrime(jz, P);
  156. uint256 zzinv = mulmod(zinv, zinv, P);
  157. uint256 zzzinv = mulmod(zzinv, zinv, P);
  158. ax = mulmod(jx, zzinv, P);
  159. ay = mulmod(jy, zzzinv, P);
  160. }
  161. /**
  162. * @dev Point addition on the jacobian coordinates
  163. * Reference: https://www.hyperelliptic.org/EFD/g1p/auto-shortw-jacobian.html#addition-add-1998-cmo-2
  164. */
  165. function _jAdd(
  166. JPoint memory p1,
  167. uint256 x2,
  168. uint256 y2,
  169. uint256 z2
  170. ) private pure returns (uint256 rx, uint256 ry, uint256 rz) {
  171. assembly ("memory-safe") {
  172. let z1 := mload(add(p1, 0x40))
  173. let s1 := mulmod(mload(add(p1, 0x20)), mulmod(mulmod(z2, z2, P), z2, P), P) // s1 = y1*z2³
  174. let s2 := mulmod(y2, mulmod(mulmod(z1, z1, P), z1, P), P) // s2 = y2*z1³
  175. let r := addmod(s2, sub(P, s1), P) // r = s2-s1
  176. let u1 := mulmod(mload(p1), mulmod(z2, z2, P), P) // u1 = x1*z2²
  177. let u2 := mulmod(x2, mulmod(z1, z1, P), P) // u2 = x2*z1²
  178. let h := addmod(u2, sub(P, u1), P) // h = u2-u1
  179. let hh := mulmod(h, h, P) // h²
  180. // x' = r²-h³-2*u1*h²
  181. rx := addmod(
  182. addmod(mulmod(r, r, P), sub(P, mulmod(h, hh, P)), P),
  183. sub(P, mulmod(2, mulmod(u1, hh, P), P)),
  184. P
  185. )
  186. // y' = r*(u1*h²-x')-s1*h³
  187. ry := addmod(
  188. mulmod(r, addmod(mulmod(u1, hh, P), sub(P, rx), P), P),
  189. sub(P, mulmod(s1, mulmod(h, hh, P), P)),
  190. P
  191. )
  192. // z' = h*z1*z2
  193. rz := mulmod(h, mulmod(z1, z2, P), P)
  194. }
  195. }
  196. /**
  197. * @dev Point doubling on the jacobian coordinates
  198. * Reference: https://www.hyperelliptic.org/EFD/g1p/auto-shortw-jacobian.html#doubling-dbl-1998-cmo-2
  199. */
  200. function _jDouble(uint256 x, uint256 y, uint256 z) private pure returns (uint256 rx, uint256 ry, uint256 rz) {
  201. assembly ("memory-safe") {
  202. let yy := mulmod(y, y, P)
  203. let zz := mulmod(z, z, P)
  204. let s := mulmod(4, mulmod(x, yy, P), P) // s = 4*x*y²
  205. let m := addmod(mulmod(3, mulmod(x, x, P), P), mulmod(A, mulmod(zz, zz, P), P), P) // m = 3*x²+a*z⁴
  206. let t := addmod(mulmod(m, m, P), sub(P, mulmod(2, s, P)), P) // t = m²-2*s
  207. // x' = t
  208. rx := t
  209. // y' = m*(s-t)-8*y⁴
  210. ry := addmod(mulmod(m, addmod(s, sub(P, t), P), P), sub(P, mulmod(8, mulmod(yy, yy, P), P)), P)
  211. // z' = 2*y*z
  212. rz := mulmod(2, mulmod(y, z, P), P)
  213. }
  214. }
  215. /**
  216. * @dev Compute P·u1 + Q·u2 using the precomputed points for P and Q (see {_preComputeJacobianPoints}).
  217. *
  218. * Uses Strauss Shamir trick for EC multiplication
  219. * https://stackoverflow.com/questions/50993471/ec-scalar-multiplication-with-strauss-shamir-method
  220. * we optimise on this a bit to do with 2 bits at a time rather than a single bit
  221. * the individual points for a single pass are precomputed
  222. * overall this reduces the number of additions while keeping the same number of doublings
  223. */
  224. function _jMultShamir(JPoint[16] memory points, uint256 u1, uint256 u2) private view returns (uint256, uint256) {
  225. uint256 x = 0;
  226. uint256 y = 0;
  227. uint256 z = 0;
  228. unchecked {
  229. for (uint256 i = 0; i < 128; ++i) {
  230. if (z > 0) {
  231. (x, y, z) = _jDouble(x, y, z);
  232. (x, y, z) = _jDouble(x, y, z);
  233. }
  234. // Read 2 bits of u1, and 2 bits of u2. Combining the two give a lookup index in the table.
  235. uint256 pos = ((u1 >> 252) & 0xc) | ((u2 >> 254) & 0x3);
  236. if (pos > 0) {
  237. if (z == 0) {
  238. (x, y, z) = (points[pos].x, points[pos].y, points[pos].z);
  239. } else {
  240. (x, y, z) = _jAdd(points[pos], x, y, z);
  241. }
  242. }
  243. u1 <<= 2;
  244. u2 <<= 2;
  245. }
  246. }
  247. return _affineFromJacobian(x, y, z);
  248. }
  249. /**
  250. * @dev Precompute a matrice of useful jacobian points associated with a given P. This can be seen as a 4x4 matrix
  251. * that contains combination of P and G (generator) up to 3 times each. See the table below:
  252. *
  253. * ┌────┬─────────────────────┐
  254. * │ i │ 0 1 2 3 │
  255. * ├────┼─────────────────────┤
  256. * │ 0 │ 0 p 2p 3p │
  257. * │ 4 │ g g+p g+2p g+3p │
  258. * │ 8 │ 2g 2g+p 2g+2p 2g+3p │
  259. * │ 12 │ 3g 3g+p 3g+2p 3g+3p │
  260. * └────┴─────────────────────┘
  261. */
  262. function _preComputeJacobianPoints(uint256 px, uint256 py) private pure returns (JPoint[16] memory points) {
  263. points[0x00] = JPoint(0, 0, 0); // 0,0
  264. points[0x01] = JPoint(px, py, 1); // 1,0 (p)
  265. points[0x04] = JPoint(GX, GY, 1); // 0,1 (g)
  266. points[0x02] = _jDoublePoint(points[0x01]); // 2,0 (2p)
  267. points[0x08] = _jDoublePoint(points[0x04]); // 0,2 (2g)
  268. points[0x03] = _jAddPoint(points[0x01], points[0x02]); // 3,0 (3p)
  269. points[0x05] = _jAddPoint(points[0x01], points[0x04]); // 1,1 (p+g)
  270. points[0x06] = _jAddPoint(points[0x02], points[0x04]); // 2,1 (2p+g)
  271. points[0x07] = _jAddPoint(points[0x03], points[0x04]); // 3,1 (3p+g)
  272. points[0x09] = _jAddPoint(points[0x01], points[0x08]); // 1,2 (p+2g)
  273. points[0x0a] = _jAddPoint(points[0x02], points[0x08]); // 2,2 (2p+2g)
  274. points[0x0b] = _jAddPoint(points[0x03], points[0x08]); // 3,2 (3p+2g)
  275. points[0x0c] = _jAddPoint(points[0x04], points[0x08]); // 0,3 (g+2g)
  276. points[0x0d] = _jAddPoint(points[0x01], points[0x0c]); // 1,3 (p+3g)
  277. points[0x0e] = _jAddPoint(points[0x02], points[0x0c]); // 2,3 (2p+3g)
  278. points[0x0f] = _jAddPoint(points[0x03], points[0x0C]); // 3,3 (3p+3g)
  279. }
  280. function _jAddPoint(JPoint memory p1, JPoint memory p2) private pure returns (JPoint memory) {
  281. (uint256 x, uint256 y, uint256 z) = _jAdd(p1, p2.x, p2.y, p2.z);
  282. return JPoint(x, y, z);
  283. }
  284. function _jDoublePoint(JPoint memory p) private pure returns (JPoint memory) {
  285. (uint256 x, uint256 y, uint256 z) = _jDouble(p.x, p.y, p.z);
  286. return JPoint(x, y, z);
  287. }
  288. }