bigint.c 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657
  1. // SPDX-License-Identifier: Apache-2.0
  2. #include <stdint.h>
  3. #include <stdbool.h>
/*
 In wasm/bpf, the instruction for multiplying two 64 bit values results in a 64 bit value. In
 other words, the result is truncated. The largest values we can multiply without truncation
 are 32 bit (by casting to 64 bit and doing a 64 bit multiplication). So, we divvy the work
 up into 32 bit multiplications.

 No overflow checking is done.

     0      0      0     r5     r4     r3     r2     r1
     0      0      0      0     l4     l3     l2     l1   *
 ------------------------------------------------------------
     0      0      0   r5*l1  r4*l1  r3*l1  r2*l1  r1*l1
     0      0   r5*l2  r4*l2  r3*l2  r2*l2  r1*l2      0
     0   r5*l3  r4*l3  r3*l3  r2*l3  r1*l3      0      0
  r5*l4  r4*l4  r3*l4  r2*l4  r1*l4      0      0      0  +
 ------------------------------------------------------------
*/
  19. void __mul32(uint32_t left[], uint32_t right[], uint32_t out[], int len)
  20. {
  21. uint64_t val1 = 0, carry = 0;
  22. int left_len = len, right_len = len;
  23. while (left_len > 0 && !left[left_len - 1])
  24. left_len--;
  25. while (right_len > 0 && !right[right_len - 1])
  26. right_len--;
  27. int right_start = 0, right_end = 0;
  28. int left_start = 0;
  29. for (int l = 0; l < len; l++)
  30. {
  31. int i = 0;
  32. if (l >= left_len)
  33. right_start++;
  34. if (l >= right_len)
  35. left_start++;
  36. if (right_end < right_len)
  37. right_end++;
  38. for (int r = right_end - 1; r >= right_start; r--)
  39. {
  40. uint64_t m = (uint64_t)left[left_start + i] * (uint64_t)right[r];
  41. i++;
  42. if (__builtin_add_overflow(val1, m, &val1))
  43. carry += 0x100000000;
  44. }
  45. out[l] = val1;
  46. val1 = (val1 >> 32) | carry;
  47. carry = 0;
  48. }
  49. }
  50. // A version of __mul32 that detects overflow.
  51. bool __mul32_with_builtin_ovf(uint32_t left[], uint32_t right[], uint32_t out[], int len)
  52. {
  53. bool overflow = false;
  54. uint64_t val1 = 0, carry = 0;
  55. int left_len = len, right_len = len;
  56. while (left_len > 0 && !left[left_len - 1])
  57. left_len--;
  58. while (right_len > 0 && !right[right_len - 1])
  59. right_len--;
  60. int right_start = 0, right_end = 0;
  61. int left_start = 0;
  62. // We extend len to check for possible overflow. len = bit_width / 32. Checking for overflow for intN (where N = number of bits) requires checking for any set bits beyond N up to N*2.
  63. len = len * 2;
  64. for (int l = 0; l < len; l++)
  65. {
  66. int i = 0;
  67. if (l >= left_len)
  68. right_start++;
  69. if (l >= right_len)
  70. left_start++;
  71. if (right_end < right_len)
  72. right_end++;
  73. for (int r = right_end - 1; r >= right_start; r--)
  74. {
  75. uint64_t m = (uint64_t)left[left_start + i] * (uint64_t)right[r];
  76. i++;
  77. if (__builtin_add_overflow(val1, m, &val1))
  78. carry += 0x100000000;
  79. }
  80. // If the loop is within the operand bit size, just do the assignment
  81. if (l < len / 2)
  82. {
  83. out[l] = val1;
  84. }
  85. // If the loop extends to more than the bit size, we check for overflow.
  86. else if (l >= len / 2)
  87. {
  88. if (val1 > 0)
  89. {
  90. overflow = true;
  91. break;
  92. }
  93. }
  94. val1 = (val1 >> 32) | carry;
  95. carry = 0;
  96. }
  97. return overflow;
  98. }
// Some compiler runtime builtins we need.
// View of an unsigned 128 bit value as two 64 bit halves, used by the shift
// builtins to operate on each half with plain 64 bit instructions.
// `low` is declared first, so it only overlays the least significant half of
// `all` on a little-endian target (NOTE(review): this layout is assumed, not
// checked).
typedef union {
    __uint128_t all;
    struct
    {
        uint64_t low;
        uint64_t high;
    };
} two64;
// View of a signed 128 bit value as two 64 bit halves. The upper half is
// signed so that right-shifting it can replicate the sign bit.
// `low` is declared first, so it only overlays the least significant half of
// `all` on a little-endian target (NOTE(review): this layout is assumed, not
// checked).
typedef union {
    __int128_t all;
    struct
    {
        uint64_t low;
        int64_t high;
    };
} two64s;
  118. // This assumes r >= 0 && r <= 127
  119. __uint128_t __ashlti3(__uint128_t val, int r)
  120. {
  121. two64 in;
  122. two64 result;
  123. in.all = val;
  124. if (r == 0)
  125. {
  126. // nothing to do
  127. result.all = in.all;
  128. }
  129. else if (r & 64)
  130. {
  131. // Shift more than or equal 64
  132. result.low = 0;
  133. result.high = in.low << (r & 63);
  134. }
  135. else
  136. {
  137. // Shift less than 64
  138. result.low = in.low << r;
  139. result.high = (in.high << r) | (in.low >> (64 - r));
  140. }
  141. return result.all;
  142. }
  143. // This assumes r >= 0 && r <= 127
  144. __uint128_t __lshrti3(__uint128_t val, int r)
  145. {
  146. two64 in;
  147. two64 result;
  148. in.all = val;
  149. if (r == 0)
  150. {
  151. // nothing to do
  152. result.all = in.all;
  153. }
  154. else if (r & 64)
  155. {
  156. // Shift more than or equal 64
  157. result.low = in.high >> (r & 63);
  158. result.high = 0;
  159. }
  160. else
  161. {
  162. // Shift less than 64
  163. result.low = (in.low >> r) | (in.high << (64 - r));
  164. result.high = in.high >> r;
  165. }
  166. return result.all;
  167. }
  168. __uint128_t __ashrti3(__uint128_t val, int r)
  169. {
  170. two64s in;
  171. two64s result;
  172. in.all = val;
  173. if (r == 0)
  174. {
  175. // nothing to do
  176. result.all = in.all;
  177. }
  178. else if (r & 64)
  179. {
  180. // Shift more than or equal 64
  181. result.high = in.high >> 63;
  182. result.low = in.high >> (r & 63);
  183. }
  184. else
  185. {
  186. // Shift less than 64
  187. result.low = (in.low >> r) | (in.high << (64 - r));
  188. result.high = in.high >> r;
  189. }
  190. return result.all;
  191. }
  192. // Return the highest set bit in v
  193. int bits(uint64_t v)
  194. {
  195. int h = 63;
  196. if (!(v & 0xffffffff00000000))
  197. {
  198. h -= 32;
  199. v <<= 32;
  200. }
  201. if (!(v & 0xffff000000000000))
  202. {
  203. h -= 16;
  204. v <<= 16;
  205. }
  206. if (!(v & 0xff00000000000000))
  207. {
  208. h -= 8;
  209. v <<= 8;
  210. }
  211. if (!(v & 0xf000000000000000))
  212. {
  213. h -= 4;
  214. v <<= 4;
  215. }
  216. if (!(v & 0xc000000000000000))
  217. {
  218. h -= 2;
  219. v <<= 2;
  220. }
  221. if (!(v & 0x8000000000000000))
  222. {
  223. h -= 1;
  224. }
  225. return h;
  226. }
  227. int bits128(__uint128_t v)
  228. {
  229. uint64_t upper = v >> 64;
  230. if (upper)
  231. {
  232. return bits(upper) + 64;
  233. }
  234. else
  235. {
  236. return bits(v);
  237. }
  238. }
  239. __uint128_t shl128(__uint128_t val, int r)
  240. {
  241. if (r == 0)
  242. {
  243. return val;
  244. }
  245. else if (r & 64)
  246. {
  247. // Shift more than or equal 64
  248. uint64_t low = val;
  249. __uint128_t tmp = low << (r & 63);
  250. return tmp << 64;
  251. }
  252. else
  253. {
  254. // Shift less than 64
  255. uint64_t low = val;
  256. uint64_t high = val >> 64;
  257. __uint128_t tmp = (high << r) | (low >> (64 - r));
  258. return (low << r) | (tmp << 64);
  259. }
  260. }
  261. __uint128_t shr128(__uint128_t val, int r)
  262. {
  263. if (r == 0)
  264. {
  265. return val;
  266. }
  267. else if (r & 64)
  268. {
  269. // Shift more than or equal 64
  270. uint64_t high = val >> 64;
  271. high >>= r & 63;
  272. return high;
  273. }
  274. else
  275. {
  276. // Shift less than 64
  277. uint64_t low = val;
  278. uint64_t high = val >> 64;
  279. low >>= r;
  280. high <<= 64 - r;
  281. __uint128_t tmp = high;
  282. return low | (tmp << 64);
  283. }
  284. }
  285. int udivmod128(__uint128_t *pdividend, __uint128_t *pdivisor, __uint128_t *remainder, __uint128_t *quotient)
  286. {
  287. __uint128_t dividend = *pdividend;
  288. __uint128_t divisor = *pdivisor;
  289. if (divisor == 0)
  290. return 1;
  291. if (divisor == 1)
  292. {
  293. *remainder = 0;
  294. *quotient = dividend;
  295. return 0;
  296. }
  297. if (divisor == dividend)
  298. {
  299. *remainder = 0;
  300. *quotient = 1;
  301. return 0;
  302. }
  303. if (dividend == 0 || dividend < divisor)
  304. {
  305. *remainder = dividend;
  306. *quotient = 0;
  307. return 0;
  308. }
  309. __uint128_t q = 0, r = 0;
  310. for (int x = bits128(dividend) + 1; x > 0; x--)
  311. {
  312. q <<= 1;
  313. r <<= 1;
  314. if ((dividend >> (x - 1)) & 1)
  315. {
  316. r++;
  317. }
  318. if (r >= divisor)
  319. {
  320. r -= divisor;
  321. q++;
  322. }
  323. }
  324. *quotient = q;
  325. *remainder = r;
  326. return 0;
  327. }
  328. int sdivmod128(__uint128_t *pdividend, __uint128_t *pdivisor, __uint128_t *remainder, __uint128_t *quotient)
  329. {
  330. bool dividend_negative = ((uint8_t *)pdividend)[15] >= 128;
  331. if (dividend_negative)
  332. {
  333. __uint128_t dividend = *pdividend;
  334. *pdividend = -dividend;
  335. }
  336. bool divisor_negative = ((uint8_t *)pdivisor)[15] >= 128;
  337. if (divisor_negative)
  338. {
  339. __uint128_t divisor = *pdivisor;
  340. *pdivisor = -divisor;
  341. }
  342. if (udivmod128(pdividend, pdivisor, remainder, quotient))
  343. {
  344. return 1;
  345. }
  346. if (dividend_negative != divisor_negative)
  347. {
  348. __uint128_t q = *quotient;
  349. *quotient = -q;
  350. }
  351. if (dividend_negative)
  352. {
  353. __uint128_t r = *remainder;
  354. *remainder = -r;
  355. }
  356. return 0;
  357. }
// Unsigned 256 bit integer type, declared with _BitInt.
typedef unsigned _BitInt(256) uint256_t;
// The values zero and one as 256 bit constants, used by the 256 bit division
// routine for comparisons, results and as the seed of the quotient bit mask.
uint256_t const uint256_0 = (uint256_t)0;
uint256_t const uint256_1 = (uint256_t)1;
  361. int bits256(uint256_t *value)
  362. {
  363. // 256 bits values consist of 4 uint64_ts.
  364. uint64_t *v = (uint64_t *)value;
  365. for (int i = 3; i >= 0; i--)
  366. {
  367. if (v[i])
  368. return bits(v[i]) + 64 * i;
  369. }
  370. return 0;
  371. }
  372. int udivmod256(uint256_t *pdividend, uint256_t *pdivisor, uint256_t *remainder, uint256_t *quotient)
  373. {
  374. uint256_t dividend = *pdividend;
  375. uint256_t divisor = *pdivisor;
  376. if (divisor == uint256_0)
  377. return 1;
  378. if (divisor == uint256_1)
  379. {
  380. *remainder = uint256_0;
  381. *quotient = dividend;
  382. return 0;
  383. }
  384. if (divisor == dividend)
  385. {
  386. *remainder = uint256_0;
  387. *quotient = uint256_1;
  388. return 0;
  389. }
  390. if (dividend == uint256_0 || dividend < divisor)
  391. {
  392. *remainder = dividend;
  393. *quotient = uint256_0;
  394. return 0;
  395. }
  396. uint256_t q = uint256_0, r = dividend;
  397. uint256_t copyd = divisor << (bits256(&dividend) - bits256(&divisor));
  398. uint256_t adder = uint256_1 << (bits256(&dividend) - bits256(&divisor));
  399. if (copyd > dividend)
  400. {
  401. copyd >>= 1;
  402. adder >>= 1;
  403. }
  404. while (r >= divisor)
  405. {
  406. if (r >= copyd)
  407. {
  408. r -= copyd;
  409. q |= adder;
  410. }
  411. copyd >>= 1;
  412. adder >>= 1;
  413. }
  414. *quotient = q;
  415. *remainder = r;
  416. return 0;
  417. }
  418. int sdivmod256(uint256_t *pdividend, uint256_t *pdivisor, uint256_t *remainder, uint256_t *quotient)
  419. {
  420. bool dividend_negative = ((uint8_t *)pdividend)[31] >= 128;
  421. if (dividend_negative)
  422. {
  423. uint256_t dividend = *pdividend;
  424. *pdividend = -dividend;
  425. }
  426. bool divisor_negative = ((uint8_t *)pdivisor)[31] >= 128;
  427. if (divisor_negative)
  428. {
  429. uint256_t divisor = *pdivisor;
  430. *pdivisor = -divisor;
  431. }
  432. if (udivmod256(pdividend, pdivisor, remainder, quotient))
  433. {
  434. return 1;
  435. }
  436. if (dividend_negative != divisor_negative)
  437. {
  438. uint256_t q = *quotient;
  439. *quotient = -q;
  440. }
  441. if (dividend_negative)
  442. {
  443. uint256_t r = *remainder;
  444. *remainder = -r;
  445. }
  446. return 0;
  447. }
// Unsigned 512 bit integer type, declared with _BitInt.
typedef unsigned _BitInt(512) uint512_t;
// The values zero and one as 512 bit constants, used by the 512 bit division
// routine for comparisons, results and as the seed of the quotient bit mask.
uint512_t const uint512_0 = (uint512_t)0;
uint512_t const uint512_1 = (uint512_t)1;
  451. int bits512(uint512_t *value)
  452. {
  453. // 512 bits values consist of 8 uint64_ts.
  454. uint64_t *v = (uint64_t *)value;
  455. for (int i = 7; i >= 0; i--)
  456. {
  457. if (v[i])
  458. return bits(v[i]) + 64 * i;
  459. }
  460. return 0;
  461. }
  462. int udivmod512(uint512_t *pdividend, uint512_t *pdivisor, uint512_t *remainder, uint512_t *quotient)
  463. {
  464. uint512_t dividend = *pdividend;
  465. uint512_t divisor = *pdivisor;
  466. if (divisor == uint512_0)
  467. return 1;
  468. if (divisor == uint512_1)
  469. {
  470. *remainder = uint512_0;
  471. *quotient = dividend;
  472. return 0;
  473. }
  474. if (divisor == dividend)
  475. {
  476. *remainder = uint512_0;
  477. *quotient = uint512_1;
  478. return 0;
  479. }
  480. if (dividend == uint512_0 || dividend < divisor)
  481. {
  482. *remainder = dividend;
  483. *quotient = uint512_0;
  484. return 0;
  485. }
  486. uint512_t q = uint512_0, r = dividend;
  487. uint512_t copyd = divisor << (bits512(&dividend) - bits512(&divisor));
  488. uint512_t adder = uint512_1 << (bits512(&dividend) - bits512(&divisor));
  489. if (copyd > dividend)
  490. {
  491. copyd >>= 1;
  492. adder >>= 1;
  493. }
  494. while (r >= divisor)
  495. {
  496. if (r >= copyd)
  497. {
  498. r -= copyd;
  499. q |= adder;
  500. }
  501. copyd >>= 1;
  502. adder >>= 1;
  503. }
  504. *quotient = q;
  505. *remainder = r;
  506. return 0;
  507. }
  508. int sdivmod512(uint512_t *pdividend, uint512_t *pdivisor, uint512_t *remainder, uint512_t *quotient)
  509. {
  510. bool dividend_negative = ((uint8_t *)pdividend)[63] >= 128;
  511. if (dividend_negative)
  512. {
  513. uint512_t dividend = *pdividend;
  514. *pdividend = -dividend;
  515. }
  516. bool divisor_negative = ((uint8_t *)pdivisor)[63] >= 128;
  517. if (divisor_negative)
  518. {
  519. uint512_t divisor = *pdivisor;
  520. *pdivisor = -divisor;
  521. }
  522. if (udivmod512(pdividend, pdivisor, remainder, quotient))
  523. {
  524. return 1;
  525. }
  526. if (dividend_negative != divisor_negative)
  527. {
  528. uint512_t q = *quotient;
  529. *quotient = -q;
  530. }
  531. if (dividend_negative)
  532. {
  533. uint512_t r = *remainder;
  534. *remainder = -r;
  535. }
  536. return 0;
  537. }