bigint.c 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658
  1. #include <stdint.h>
  2. #include <stdbool.h>
/*
    In wasm/bpf, the instruction for multiplying two 64 bit values results in a 64 bit value.
    In other words, the result is truncated. The largest values we can multiply without
    truncation are 32 bit (by casting to 64 bit and doing a 64 bit multiplication). So, we
    divvy the work up into 32 bit multiplications.

    No overflow checking is done.

        0      0      0      r5     r4     r3     r2     r1
        0      0      0      0      l4     l3     l2     l1     *
    ------------------------------------------------------------
        0      0      0      r5*l1  r4*l1  r3*l1  r2*l1  r1*l1
        0      0      r5*l2  r4*l2  r3*l2  r2*l2  r1*l2  0
        0      r5*l3  r4*l3  r3*l3  r2*l3  r1*l3  0      0
        r5*l4  r4*l4  r3*l4  r2*l4  r1*l4  0      0      0      +
    ------------------------------------------------------------
*/
  18. void __mul32(uint32_t left[], uint32_t right[], uint32_t out[], int len)
  19. {
  20. uint64_t val1 = 0, carry = 0;
  21. int left_len = len, right_len = len;
  22. while (left_len > 0 && !left[left_len - 1])
  23. left_len--;
  24. while (right_len > 0 && !right[right_len - 1])
  25. right_len--;
  26. int right_start = 0, right_end = 0;
  27. int left_start = 0;
  28. for (int l = 0; l < len; l++)
  29. {
  30. int i = 0;
  31. if (l >= left_len)
  32. right_start++;
  33. if (l >= right_len)
  34. left_start++;
  35. if (right_end < right_len)
  36. right_end++;
  37. for (int r = right_end - 1; r >= right_start; r--)
  38. {
  39. uint64_t m = (uint64_t)left[left_start + i] * (uint64_t)right[r];
  40. i++;
  41. if (__builtin_add_overflow(val1, m, &val1))
  42. carry += 0x100000000;
  43. }
  44. out[l] = val1;
  45. val1 = (val1 >> 32) | carry;
  46. carry = 0;
  47. }
  48. }
  49. // A version of __mul32 that detects overflow.
  50. bool __mul32_with_builtin_ovf(uint32_t left[], uint32_t right[], uint32_t out[], int len)
  51. {
  52. bool overflow = false;
  53. uint64_t val1 = 0, carry = 0;
  54. int left_len = len, right_len = len;
  55. while (left_len > 0 && !left[left_len - 1])
  56. left_len--;
  57. while (right_len > 0 && !right[right_len - 1])
  58. right_len--;
  59. int right_start = 0, right_end = 0;
  60. int left_start = 0;
  61. // We extend len to check for possible overflow. len = bit_width / 32. Checking for overflow for intN (where N = number of bits) requires checking for any set bits beyond N up to N*2.
  62. len = len * 2;
  63. for (int l = 0; l < len; l++)
  64. {
  65. int i = 0;
  66. if (l >= left_len)
  67. right_start++;
  68. if (l >= right_len)
  69. left_start++;
  70. if (right_end < right_len)
  71. right_end++;
  72. for (int r = right_end - 1; r >= right_start; r--)
  73. {
  74. uint64_t m = (uint64_t)left[left_start + i] * (uint64_t)right[r];
  75. i++;
  76. if (__builtin_add_overflow(val1, m, &val1))
  77. carry += 0x100000000;
  78. }
  79. // If the loop is within the operand bit size, just do the assignment
  80. if (l < len / 2)
  81. {
  82. out[l] = val1;
  83. }
  84. // If the loop extends to more than the bit size, we check for overflow.
  85. else if (l >= len / 2)
  86. {
  87. if (val1 > 0)
  88. {
  89. overflow = true;
  90. break;
  91. }
  92. }
  93. val1 = (val1 >> 32) | carry;
  94. carry = 0;
  95. }
  96. return overflow;
  97. }
  98. // Some compiler runtime builtins we need.
  99. // 128 bit shift left.
  100. typedef union
  101. {
  102. __uint128_t all;
  103. struct
  104. {
  105. uint64_t low;
  106. uint64_t high;
  107. };
  108. } two64;
  109. // 128 bit shift left.
  110. typedef union
  111. {
  112. __int128_t all;
  113. struct
  114. {
  115. uint64_t low;
  116. int64_t high;
  117. };
  118. } two64s;
  119. // This assumes r >= 0 && r <= 127
  120. __uint128_t __ashlti3(__uint128_t val, int r)
  121. {
  122. two64 in;
  123. two64 result;
  124. in.all = val;
  125. if (r == 0)
  126. {
  127. // nothing to do
  128. result.all = in.all;
  129. }
  130. else if (r & 64)
  131. {
  132. // Shift more than or equal 64
  133. result.low = 0;
  134. result.high = in.low << (r & 63);
  135. }
  136. else
  137. {
  138. // Shift less than 64
  139. result.low = in.low << r;
  140. result.high = (in.high << r) | (in.low >> (64 - r));
  141. }
  142. return result.all;
  143. }
  144. // This assumes r >= 0 && r <= 127
  145. __uint128_t __lshrti3(__uint128_t val, int r)
  146. {
  147. two64 in;
  148. two64 result;
  149. in.all = val;
  150. if (r == 0)
  151. {
  152. // nothing to do
  153. result.all = in.all;
  154. }
  155. else if (r & 64)
  156. {
  157. // Shift more than or equal 64
  158. result.low = in.high >> (r & 63);
  159. result.high = 0;
  160. }
  161. else
  162. {
  163. // Shift less than 64
  164. result.low = (in.low >> r) | (in.high << (64 - r));
  165. result.high = in.high >> r;
  166. }
  167. return result.all;
  168. }
  169. __uint128_t __ashrti3(__uint128_t val, int r)
  170. {
  171. two64s in;
  172. two64s result;
  173. in.all = val;
  174. if (r == 0)
  175. {
  176. // nothing to do
  177. result.all = in.all;
  178. }
  179. else if (r & 64)
  180. {
  181. // Shift more than or equal 64
  182. result.high = in.high >> 63;
  183. result.low = in.high >> (r & 63);
  184. }
  185. else
  186. {
  187. // Shift less than 64
  188. result.low = (in.low >> r) | (in.high << (64 - r));
  189. result.high = in.high >> r;
  190. }
  191. return result.all;
  192. }
  193. // Return the highest set bit in v
  194. int bits(uint64_t v)
  195. {
  196. int h = 63;
  197. if (!(v & 0xffffffff00000000))
  198. {
  199. h -= 32;
  200. v <<= 32;
  201. }
  202. if (!(v & 0xffff000000000000))
  203. {
  204. h -= 16;
  205. v <<= 16;
  206. }
  207. if (!(v & 0xff00000000000000))
  208. {
  209. h -= 8;
  210. v <<= 8;
  211. }
  212. if (!(v & 0xf000000000000000))
  213. {
  214. h -= 4;
  215. v <<= 4;
  216. }
  217. if (!(v & 0xc000000000000000))
  218. {
  219. h -= 2;
  220. v <<= 2;
  221. }
  222. if (!(v & 0x8000000000000000))
  223. {
  224. h -= 1;
  225. }
  226. return h;
  227. }
  228. int bits128(__uint128_t v)
  229. {
  230. uint64_t upper = v >> 64;
  231. if (upper)
  232. {
  233. return bits(upper) + 64;
  234. }
  235. else
  236. {
  237. return bits(v);
  238. }
  239. }
  240. __uint128_t shl128(__uint128_t val, int r)
  241. {
  242. if (r == 0)
  243. {
  244. return val;
  245. }
  246. else if (r & 64)
  247. {
  248. // Shift more than or equal 64
  249. uint64_t low = val;
  250. __uint128_t tmp = low << (r & 63);
  251. return tmp << 64;
  252. }
  253. else
  254. {
  255. // Shift less than 64
  256. uint64_t low = val;
  257. uint64_t high = val >> 64;
  258. __uint128_t tmp = (high << r) | (low >> (64 - r));
  259. return (low << r) | (tmp << 64);
  260. }
  261. }
  262. __uint128_t shr128(__uint128_t val, int r)
  263. {
  264. if (r == 0)
  265. {
  266. return val;
  267. }
  268. else if (r & 64)
  269. {
  270. // Shift more than or equal 64
  271. uint64_t high = val >> 64;
  272. high >>= r & 63;
  273. return high;
  274. }
  275. else
  276. {
  277. // Shift less than 64
  278. uint64_t low = val;
  279. uint64_t high = val >> 64;
  280. low >>= r;
  281. high <<= 64 - r;
  282. __uint128_t tmp = high;
  283. return low | (tmp << 64);
  284. }
  285. }
  286. int udivmod128(__uint128_t *pdividend, __uint128_t *pdivisor, __uint128_t *remainder, __uint128_t *quotient)
  287. {
  288. __uint128_t dividend = *pdividend;
  289. __uint128_t divisor = *pdivisor;
  290. if (divisor == 0)
  291. return 1;
  292. if (divisor == 1)
  293. {
  294. *remainder = 0;
  295. *quotient = dividend;
  296. return 0;
  297. }
  298. if (divisor == dividend)
  299. {
  300. *remainder = 0;
  301. *quotient = 1;
  302. return 0;
  303. }
  304. if (dividend == 0 || dividend < divisor)
  305. {
  306. *remainder = dividend;
  307. *quotient = 0;
  308. return 0;
  309. }
  310. __uint128_t q = 0, r = 0;
  311. for (int x = bits128(dividend) + 1; x > 0; x--)
  312. {
  313. q <<= 1;
  314. r <<= 1;
  315. if ((dividend >> (x - 1)) & 1)
  316. {
  317. r++;
  318. }
  319. if (r >= divisor)
  320. {
  321. r -= divisor;
  322. q++;
  323. }
  324. }
  325. *quotient = q;
  326. *remainder = r;
  327. return 0;
  328. }
  329. int sdivmod128(__uint128_t *pdividend, __uint128_t *pdivisor, __uint128_t *remainder, __uint128_t *quotient)
  330. {
  331. bool dividend_negative = ((uint8_t *)pdividend)[15] >= 128;
  332. if (dividend_negative)
  333. {
  334. __uint128_t dividend = *pdividend;
  335. *pdividend = -dividend;
  336. }
  337. bool divisor_negative = ((uint8_t *)pdivisor)[15] >= 128;
  338. if (divisor_negative)
  339. {
  340. __uint128_t divisor = *pdivisor;
  341. *pdivisor = -divisor;
  342. }
  343. if (udivmod128(pdividend, pdivisor, remainder, quotient))
  344. {
  345. return 1;
  346. }
  347. if (dividend_negative != divisor_negative)
  348. {
  349. __uint128_t q = *quotient;
  350. *quotient = -q;
  351. }
  352. if (dividend_negative)
  353. {
  354. __uint128_t r = *remainder;
  355. *remainder = -r;
  356. }
  357. return 0;
  358. }
  359. typedef unsigned _ExtInt(256) uint256_t;
  360. uint256_t const uint256_0 = (uint256_t)0;
  361. uint256_t const uint256_1 = (uint256_t)1;
  362. int bits256(uint256_t *value)
  363. {
  364. // 256 bits values consist of 4 uint64_ts.
  365. uint64_t *v = (uint64_t *)value;
  366. for (int i = 3; i >= 0; i--)
  367. {
  368. if (v[i])
  369. return bits(v[i]) + 64 * i;
  370. }
  371. return 0;
  372. }
  373. int udivmod256(uint256_t *pdividend, uint256_t *pdivisor, uint256_t *remainder, uint256_t *quotient)
  374. {
  375. uint256_t dividend = *pdividend;
  376. uint256_t divisor = *pdivisor;
  377. if (divisor == uint256_0)
  378. return 1;
  379. if (divisor == uint256_1)
  380. {
  381. *remainder = uint256_0;
  382. *quotient = dividend;
  383. return 0;
  384. }
  385. if (divisor == dividend)
  386. {
  387. *remainder = uint256_0;
  388. *quotient = uint256_1;
  389. return 0;
  390. }
  391. if (dividend == uint256_0 || dividend < divisor)
  392. {
  393. *remainder = dividend;
  394. *quotient = uint256_0;
  395. return 0;
  396. }
  397. uint256_t q = uint256_0, r = dividend;
  398. uint256_t copyd = divisor << (bits256(&dividend) - bits256(&divisor));
  399. uint256_t adder = uint256_1 << (bits256(&dividend) - bits256(&divisor));
  400. if (copyd > dividend)
  401. {
  402. copyd >>= 1;
  403. adder >>= 1;
  404. }
  405. while (r >= divisor)
  406. {
  407. if (r >= copyd)
  408. {
  409. r -= copyd;
  410. q |= adder;
  411. }
  412. copyd >>= 1;
  413. adder >>= 1;
  414. }
  415. *quotient = q;
  416. *remainder = r;
  417. return 0;
  418. }
  419. int sdivmod256(uint256_t *pdividend, uint256_t *pdivisor, uint256_t *remainder, uint256_t *quotient)
  420. {
  421. bool dividend_negative = ((uint8_t *)pdividend)[31] >= 128;
  422. if (dividend_negative)
  423. {
  424. uint256_t dividend = *pdividend;
  425. *pdividend = -dividend;
  426. }
  427. bool divisor_negative = ((uint8_t *)pdivisor)[31] >= 128;
  428. if (divisor_negative)
  429. {
  430. uint256_t divisor = *pdivisor;
  431. *pdivisor = -divisor;
  432. }
  433. if (udivmod256(pdividend, pdivisor, remainder, quotient))
  434. {
  435. return 1;
  436. }
  437. if (dividend_negative != divisor_negative)
  438. {
  439. uint256_t q = *quotient;
  440. *quotient = -q;
  441. }
  442. if (dividend_negative)
  443. {
  444. uint256_t r = *remainder;
  445. *remainder = -r;
  446. }
  447. return 0;
  448. }
  449. typedef unsigned _ExtInt(512) uint512_t;
  450. uint512_t const uint512_0 = (uint512_t)0;
  451. uint512_t const uint512_1 = (uint512_t)1;
  452. int bits512(uint512_t *value)
  453. {
  454. // 512 bits values consist of 8 uint64_ts.
  455. uint64_t *v = (uint64_t *)value;
  456. for (int i = 7; i >= 0; i--)
  457. {
  458. if (v[i])
  459. return bits(v[i]) + 64 * i;
  460. }
  461. return 0;
  462. }
  463. int udivmod512(uint512_t *pdividend, uint512_t *pdivisor, uint512_t *remainder, uint512_t *quotient)
  464. {
  465. uint512_t dividend = *pdividend;
  466. uint512_t divisor = *pdivisor;
  467. if (divisor == uint512_0)
  468. return 1;
  469. if (divisor == uint512_1)
  470. {
  471. *remainder = uint512_0;
  472. *quotient = dividend;
  473. return 0;
  474. }
  475. if (divisor == dividend)
  476. {
  477. *remainder = uint512_0;
  478. *quotient = uint512_1;
  479. return 0;
  480. }
  481. if (dividend == uint512_0 || dividend < divisor)
  482. {
  483. *remainder = dividend;
  484. *quotient = uint512_0;
  485. return 0;
  486. }
  487. uint512_t q = uint512_0, r = dividend;
  488. uint512_t copyd = divisor << (bits512(&dividend) - bits512(&divisor));
  489. uint512_t adder = uint512_1 << (bits512(&dividend) - bits512(&divisor));
  490. if (copyd > dividend)
  491. {
  492. copyd >>= 1;
  493. adder >>= 1;
  494. }
  495. while (r >= divisor)
  496. {
  497. if (r >= copyd)
  498. {
  499. r -= copyd;
  500. q |= adder;
  501. }
  502. copyd >>= 1;
  503. adder >>= 1;
  504. }
  505. *quotient = q;
  506. *remainder = r;
  507. return 0;
  508. }
  509. int sdivmod512(uint512_t *pdividend, uint512_t *pdivisor, uint512_t *remainder, uint512_t *quotient)
  510. {
  511. bool dividend_negative = ((uint8_t *)pdividend)[63] >= 128;
  512. if (dividend_negative)
  513. {
  514. uint512_t dividend = *pdividend;
  515. *pdividend = -dividend;
  516. }
  517. bool divisor_negative = ((uint8_t *)pdivisor)[63] >= 128;
  518. if (divisor_negative)
  519. {
  520. uint512_t divisor = *pdivisor;
  521. *pdivisor = -divisor;
  522. }
  523. if (udivmod512(pdividend, pdivisor, remainder, quotient))
  524. {
  525. return 1;
  526. }
  527. if (dividend_negative != divisor_negative)
  528. {
  529. uint512_t q = *quotient;
  530. *quotient = -q;
  531. }
  532. if (dividend_negative)
  533. {
  534. uint512_t r = *remainder;
  535. *remainder = -r;
  536. }
  537. return 0;
  538. }