bigint.c 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659
  1. // SPDX-License-Identifier: Apache-2.0
  2. #include <stdint.h>
  3. #include <stdbool.h>
  4. /*
In wasm/bpf, the instruction for multiplying two 64 bit values produces a 64 bit value. In
other words, the result is truncated. The largest values we can multiply without truncation
are 32 bit (by casting to 64 bit and doing a 64 bit multiplication). So, we divide the work
up into 32 bit multiplications.
No overflow checking is done.
  10. 0 0 0 r5 r4 r3 r2 r1
  11. 0 0 0 0 l4 l3 l2 l1 *
  12. ------------------------------------------------------------
  13. 0 0 0 r5*l1 r4*l1 r3*l1 r2*l1 r1*l1
  14. 0 0 r5*l2 r4*l2 r3*l2 r2*l2 r1*l2 0
  15. 0 r5*l3 r4*l3 r3*l3 r2*l3 r1*l3 0 0
  16. r5*l4 r4*l4 r3*l4 r2*l4 r1*l4 0 0 0 +
  17. ------------------------------------------------------------
  18. */
  19. void __mul32(uint32_t left[], uint32_t right[], uint32_t out[], int len)
  20. {
  21. uint64_t val1 = 0, carry = 0;
  22. int left_len = len, right_len = len;
  23. while (left_len > 0 && !left[left_len - 1])
  24. left_len--;
  25. while (right_len > 0 && !right[right_len - 1])
  26. right_len--;
  27. int right_start = 0, right_end = 0;
  28. int left_start = 0;
  29. for (int l = 0; l < len; l++)
  30. {
  31. int i = 0;
  32. if (l >= left_len)
  33. right_start++;
  34. if (l >= right_len)
  35. left_start++;
  36. if (right_end < right_len)
  37. right_end++;
  38. for (int r = right_end - 1; r >= right_start; r--)
  39. {
  40. uint64_t m = (uint64_t)left[left_start + i] * (uint64_t)right[r];
  41. i++;
  42. if (__builtin_add_overflow(val1, m, &val1))
  43. carry += 0x100000000;
  44. }
  45. out[l] = val1;
  46. val1 = (val1 >> 32) | carry;
  47. carry = 0;
  48. }
  49. }
  50. // A version of __mul32 that detects overflow.
  51. bool __mul32_with_builtin_ovf(uint32_t left[], uint32_t right[], uint32_t out[], int len)
  52. {
  53. bool overflow = false;
  54. uint64_t val1 = 0, carry = 0;
  55. int left_len = len, right_len = len;
  56. while (left_len > 0 && !left[left_len - 1])
  57. left_len--;
  58. while (right_len > 0 && !right[right_len - 1])
  59. right_len--;
  60. int right_start = 0, right_end = 0;
  61. int left_start = 0;
  62. // We extend len to check for possible overflow. len = bit_width / 32. Checking for overflow for intN (where N = number of bits) requires checking for any set bits beyond N up to N*2.
  63. len = len * 2;
  64. for (int l = 0; l < len; l++)
  65. {
  66. int i = 0;
  67. if (l >= left_len)
  68. right_start++;
  69. if (l >= right_len)
  70. left_start++;
  71. if (right_end < right_len)
  72. right_end++;
  73. for (int r = right_end - 1; r >= right_start; r--)
  74. {
  75. uint64_t m = (uint64_t)left[left_start + i] * (uint64_t)right[r];
  76. i++;
  77. if (__builtin_add_overflow(val1, m, &val1))
  78. carry += 0x100000000;
  79. }
  80. // If the loop is within the operand bit size, just do the assignment
  81. if (l < len / 2)
  82. {
  83. out[l] = val1;
  84. }
  85. // If the loop extends to more than the bit size, we check for overflow.
  86. else if (l >= len / 2)
  87. {
  88. if (val1 > 0)
  89. {
  90. overflow = true;
  91. break;
  92. }
  93. }
  94. val1 = (val1 >> 32) | carry;
  95. carry = 0;
  96. }
  97. return overflow;
  98. }
  99. // Some compiler runtime builtins we need.
  100. // 128 bit shift left.
  101. typedef union
  102. {
  103. __uint128_t all;
  104. struct
  105. {
  106. uint64_t low;
  107. uint64_t high;
  108. };
  109. } two64;
  110. // 128 bit shift left.
  111. typedef union
  112. {
  113. __int128_t all;
  114. struct
  115. {
  116. uint64_t low;
  117. int64_t high;
  118. };
  119. } two64s;
  120. // This assumes r >= 0 && r <= 127
  121. __uint128_t __ashlti3(__uint128_t val, int r)
  122. {
  123. two64 in;
  124. two64 result;
  125. in.all = val;
  126. if (r == 0)
  127. {
  128. // nothing to do
  129. result.all = in.all;
  130. }
  131. else if (r & 64)
  132. {
  133. // Shift more than or equal 64
  134. result.low = 0;
  135. result.high = in.low << (r & 63);
  136. }
  137. else
  138. {
  139. // Shift less than 64
  140. result.low = in.low << r;
  141. result.high = (in.high << r) | (in.low >> (64 - r));
  142. }
  143. return result.all;
  144. }
  145. // This assumes r >= 0 && r <= 127
  146. __uint128_t __lshrti3(__uint128_t val, int r)
  147. {
  148. two64 in;
  149. two64 result;
  150. in.all = val;
  151. if (r == 0)
  152. {
  153. // nothing to do
  154. result.all = in.all;
  155. }
  156. else if (r & 64)
  157. {
  158. // Shift more than or equal 64
  159. result.low = in.high >> (r & 63);
  160. result.high = 0;
  161. }
  162. else
  163. {
  164. // Shift less than 64
  165. result.low = (in.low >> r) | (in.high << (64 - r));
  166. result.high = in.high >> r;
  167. }
  168. return result.all;
  169. }
  170. __uint128_t __ashrti3(__uint128_t val, int r)
  171. {
  172. two64s in;
  173. two64s result;
  174. in.all = val;
  175. if (r == 0)
  176. {
  177. // nothing to do
  178. result.all = in.all;
  179. }
  180. else if (r & 64)
  181. {
  182. // Shift more than or equal 64
  183. result.high = in.high >> 63;
  184. result.low = in.high >> (r & 63);
  185. }
  186. else
  187. {
  188. // Shift less than 64
  189. result.low = (in.low >> r) | (in.high << (64 - r));
  190. result.high = in.high >> r;
  191. }
  192. return result.all;
  193. }
  194. // Return the highest set bit in v
  195. int bits(uint64_t v)
  196. {
  197. int h = 63;
  198. if (!(v & 0xffffffff00000000))
  199. {
  200. h -= 32;
  201. v <<= 32;
  202. }
  203. if (!(v & 0xffff000000000000))
  204. {
  205. h -= 16;
  206. v <<= 16;
  207. }
  208. if (!(v & 0xff00000000000000))
  209. {
  210. h -= 8;
  211. v <<= 8;
  212. }
  213. if (!(v & 0xf000000000000000))
  214. {
  215. h -= 4;
  216. v <<= 4;
  217. }
  218. if (!(v & 0xc000000000000000))
  219. {
  220. h -= 2;
  221. v <<= 2;
  222. }
  223. if (!(v & 0x8000000000000000))
  224. {
  225. h -= 1;
  226. }
  227. return h;
  228. }
  229. int bits128(__uint128_t v)
  230. {
  231. uint64_t upper = v >> 64;
  232. if (upper)
  233. {
  234. return bits(upper) + 64;
  235. }
  236. else
  237. {
  238. return bits(v);
  239. }
  240. }
  241. __uint128_t shl128(__uint128_t val, int r)
  242. {
  243. if (r == 0)
  244. {
  245. return val;
  246. }
  247. else if (r & 64)
  248. {
  249. // Shift more than or equal 64
  250. uint64_t low = val;
  251. __uint128_t tmp = low << (r & 63);
  252. return tmp << 64;
  253. }
  254. else
  255. {
  256. // Shift less than 64
  257. uint64_t low = val;
  258. uint64_t high = val >> 64;
  259. __uint128_t tmp = (high << r) | (low >> (64 - r));
  260. return (low << r) | (tmp << 64);
  261. }
  262. }
  263. __uint128_t shr128(__uint128_t val, int r)
  264. {
  265. if (r == 0)
  266. {
  267. return val;
  268. }
  269. else if (r & 64)
  270. {
  271. // Shift more than or equal 64
  272. uint64_t high = val >> 64;
  273. high >>= r & 63;
  274. return high;
  275. }
  276. else
  277. {
  278. // Shift less than 64
  279. uint64_t low = val;
  280. uint64_t high = val >> 64;
  281. low >>= r;
  282. high <<= 64 - r;
  283. __uint128_t tmp = high;
  284. return low | (tmp << 64);
  285. }
  286. }
  287. int udivmod128(__uint128_t *pdividend, __uint128_t *pdivisor, __uint128_t *remainder, __uint128_t *quotient)
  288. {
  289. __uint128_t dividend = *pdividend;
  290. __uint128_t divisor = *pdivisor;
  291. if (divisor == 0)
  292. return 1;
  293. if (divisor == 1)
  294. {
  295. *remainder = 0;
  296. *quotient = dividend;
  297. return 0;
  298. }
  299. if (divisor == dividend)
  300. {
  301. *remainder = 0;
  302. *quotient = 1;
  303. return 0;
  304. }
  305. if (dividend == 0 || dividend < divisor)
  306. {
  307. *remainder = dividend;
  308. *quotient = 0;
  309. return 0;
  310. }
  311. __uint128_t q = 0, r = 0;
  312. for (int x = bits128(dividend) + 1; x > 0; x--)
  313. {
  314. q <<= 1;
  315. r <<= 1;
  316. if ((dividend >> (x - 1)) & 1)
  317. {
  318. r++;
  319. }
  320. if (r >= divisor)
  321. {
  322. r -= divisor;
  323. q++;
  324. }
  325. }
  326. *quotient = q;
  327. *remainder = r;
  328. return 0;
  329. }
  330. int sdivmod128(__uint128_t *pdividend, __uint128_t *pdivisor, __uint128_t *remainder, __uint128_t *quotient)
  331. {
  332. bool dividend_negative = ((uint8_t *)pdividend)[15] >= 128;
  333. if (dividend_negative)
  334. {
  335. __uint128_t dividend = *pdividend;
  336. *pdividend = -dividend;
  337. }
  338. bool divisor_negative = ((uint8_t *)pdivisor)[15] >= 128;
  339. if (divisor_negative)
  340. {
  341. __uint128_t divisor = *pdivisor;
  342. *pdivisor = -divisor;
  343. }
  344. if (udivmod128(pdividend, pdivisor, remainder, quotient))
  345. {
  346. return 1;
  347. }
  348. if (dividend_negative != divisor_negative)
  349. {
  350. __uint128_t q = *quotient;
  351. *quotient = -q;
  352. }
  353. if (dividend_negative)
  354. {
  355. __uint128_t r = *remainder;
  356. *remainder = -r;
  357. }
  358. return 0;
  359. }
  360. typedef unsigned _BitInt(256) uint256_t;
  361. uint256_t const uint256_0 = (uint256_t)0;
  362. uint256_t const uint256_1 = (uint256_t)1;
  363. int bits256(uint256_t *value)
  364. {
  365. // 256 bits values consist of 4 uint64_ts.
  366. uint64_t *v = (uint64_t *)value;
  367. for (int i = 3; i >= 0; i--)
  368. {
  369. if (v[i])
  370. return bits(v[i]) + 64 * i;
  371. }
  372. return 0;
  373. }
  374. int udivmod256(uint256_t *pdividend, uint256_t *pdivisor, uint256_t *remainder, uint256_t *quotient)
  375. {
  376. uint256_t dividend = *pdividend;
  377. uint256_t divisor = *pdivisor;
  378. if (divisor == uint256_0)
  379. return 1;
  380. if (divisor == uint256_1)
  381. {
  382. *remainder = uint256_0;
  383. *quotient = dividend;
  384. return 0;
  385. }
  386. if (divisor == dividend)
  387. {
  388. *remainder = uint256_0;
  389. *quotient = uint256_1;
  390. return 0;
  391. }
  392. if (dividend == uint256_0 || dividend < divisor)
  393. {
  394. *remainder = dividend;
  395. *quotient = uint256_0;
  396. return 0;
  397. }
  398. uint256_t q = uint256_0, r = dividend;
  399. uint256_t copyd = divisor << (bits256(&dividend) - bits256(&divisor));
  400. uint256_t adder = uint256_1 << (bits256(&dividend) - bits256(&divisor));
  401. if (copyd > dividend)
  402. {
  403. copyd >>= 1;
  404. adder >>= 1;
  405. }
  406. while (r >= divisor)
  407. {
  408. if (r >= copyd)
  409. {
  410. r -= copyd;
  411. q |= adder;
  412. }
  413. copyd >>= 1;
  414. adder >>= 1;
  415. }
  416. *quotient = q;
  417. *remainder = r;
  418. return 0;
  419. }
  420. int sdivmod256(uint256_t *pdividend, uint256_t *pdivisor, uint256_t *remainder, uint256_t *quotient)
  421. {
  422. bool dividend_negative = ((uint8_t *)pdividend)[31] >= 128;
  423. if (dividend_negative)
  424. {
  425. uint256_t dividend = *pdividend;
  426. *pdividend = -dividend;
  427. }
  428. bool divisor_negative = ((uint8_t *)pdivisor)[31] >= 128;
  429. if (divisor_negative)
  430. {
  431. uint256_t divisor = *pdivisor;
  432. *pdivisor = -divisor;
  433. }
  434. if (udivmod256(pdividend, pdivisor, remainder, quotient))
  435. {
  436. return 1;
  437. }
  438. if (dividend_negative != divisor_negative)
  439. {
  440. uint256_t q = *quotient;
  441. *quotient = -q;
  442. }
  443. if (dividend_negative)
  444. {
  445. uint256_t r = *remainder;
  446. *remainder = -r;
  447. }
  448. return 0;
  449. }
  450. typedef unsigned _BitInt(512) uint512_t;
  451. uint512_t const uint512_0 = (uint512_t)0;
  452. uint512_t const uint512_1 = (uint512_t)1;
  453. int bits512(uint512_t *value)
  454. {
  455. // 512 bits values consist of 8 uint64_ts.
  456. uint64_t *v = (uint64_t *)value;
  457. for (int i = 7; i >= 0; i--)
  458. {
  459. if (v[i])
  460. return bits(v[i]) + 64 * i;
  461. }
  462. return 0;
  463. }
  464. int udivmod512(uint512_t *pdividend, uint512_t *pdivisor, uint512_t *remainder, uint512_t *quotient)
  465. {
  466. uint512_t dividend = *pdividend;
  467. uint512_t divisor = *pdivisor;
  468. if (divisor == uint512_0)
  469. return 1;
  470. if (divisor == uint512_1)
  471. {
  472. *remainder = uint512_0;
  473. *quotient = dividend;
  474. return 0;
  475. }
  476. if (divisor == dividend)
  477. {
  478. *remainder = uint512_0;
  479. *quotient = uint512_1;
  480. return 0;
  481. }
  482. if (dividend == uint512_0 || dividend < divisor)
  483. {
  484. *remainder = dividend;
  485. *quotient = uint512_0;
  486. return 0;
  487. }
  488. uint512_t q = uint512_0, r = dividend;
  489. uint512_t copyd = divisor << (bits512(&dividend) - bits512(&divisor));
  490. uint512_t adder = uint512_1 << (bits512(&dividend) - bits512(&divisor));
  491. if (copyd > dividend)
  492. {
  493. copyd >>= 1;
  494. adder >>= 1;
  495. }
  496. while (r >= divisor)
  497. {
  498. if (r >= copyd)
  499. {
  500. r -= copyd;
  501. q |= adder;
  502. }
  503. copyd >>= 1;
  504. adder >>= 1;
  505. }
  506. *quotient = q;
  507. *remainder = r;
  508. return 0;
  509. }
  510. int sdivmod512(uint512_t *pdividend, uint512_t *pdivisor, uint512_t *remainder, uint512_t *quotient)
  511. {
  512. bool dividend_negative = ((uint8_t *)pdividend)[63] >= 128;
  513. if (dividend_negative)
  514. {
  515. uint512_t dividend = *pdividend;
  516. *pdividend = -dividend;
  517. }
  518. bool divisor_negative = ((uint8_t *)pdivisor)[63] >= 128;
  519. if (divisor_negative)
  520. {
  521. uint512_t divisor = *pdivisor;
  522. *pdivisor = -divisor;
  523. }
  524. if (udivmod512(pdividend, pdivisor, remainder, quotient))
  525. {
  526. return 1;
  527. }
  528. if (dividend_negative != divisor_negative)
  529. {
  530. uint512_t q = *quotient;
  531. *quotient = -q;
  532. }
  533. if (dividend_negative)
  534. {
  535. uint512_t r = *remainder;
  536. *remainder = -r;
  537. }
  538. return 0;
  539. }