浏览代码

Procedural SafeCast.sol generation (#3245)

Hadrien Croubois 3 年之前
父节点
当前提交
b61faf8368

+ 1 - 0
.github/workflows/test.yml

@@ -30,6 +30,7 @@ jobs:
           FORCE_COLOR: 1
           ENABLE_GAS_REPORT: true
       - run: npm run test:inheritance
+      - run: npm run test:generation
       - name: Print gas report
         run: cat gas-report.txt
 

+ 1 - 0
CHANGELOG.md

@@ -8,6 +8,7 @@
  * `ERC20FlashMint`: Add customizable flash fee receiver. ([#3327](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/3327))
  * `Strings`: add a new overloaded function `toHexString` that converts an `address` with fixed length of 20 bytes to its not checksummed ASCII `string` hexadecimal representation. ([#3403](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/3403))
  * `EnumerableMap`: add new `UintToUintMap` map type. ([#3338](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/3338))
+ * `SafeCast`: add support for many more types, using procedural code generation. ([#3245](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/3245))
 
 ## 4.6.0 (2022-04-26)
 

+ 200 - 0
contracts/mocks/SafeCastMock.sol

@@ -12,26 +12,122 @@ contract SafeCastMock {
         return a.toUint256();
     }
 
+    function toUint248(uint256 a) public pure returns (uint248) {
+        return a.toUint248();
+    }
+
+    function toUint240(uint256 a) public pure returns (uint240) {
+        return a.toUint240();
+    }
+
+    function toUint232(uint256 a) public pure returns (uint232) {
+        return a.toUint232();
+    }
+
     function toUint224(uint256 a) public pure returns (uint224) {
         return a.toUint224();
     }
 
+    function toUint216(uint256 a) public pure returns (uint216) {
+        return a.toUint216();
+    }
+
+    function toUint208(uint256 a) public pure returns (uint208) {
+        return a.toUint208();
+    }
+
+    function toUint200(uint256 a) public pure returns (uint200) {
+        return a.toUint200();
+    }
+
+    function toUint192(uint256 a) public pure returns (uint192) {
+        return a.toUint192();
+    }
+
+    function toUint184(uint256 a) public pure returns (uint184) {
+        return a.toUint184();
+    }
+
+    function toUint176(uint256 a) public pure returns (uint176) {
+        return a.toUint176();
+    }
+
+    function toUint168(uint256 a) public pure returns (uint168) {
+        return a.toUint168();
+    }
+
+    function toUint160(uint256 a) public pure returns (uint160) {
+        return a.toUint160();
+    }
+
+    function toUint152(uint256 a) public pure returns (uint152) {
+        return a.toUint152();
+    }
+
+    function toUint144(uint256 a) public pure returns (uint144) {
+        return a.toUint144();
+    }
+
+    function toUint136(uint256 a) public pure returns (uint136) {
+        return a.toUint136();
+    }
+
     function toUint128(uint256 a) public pure returns (uint128) {
         return a.toUint128();
     }
 
+    function toUint120(uint256 a) public pure returns (uint120) {
+        return a.toUint120();
+    }
+
+    function toUint112(uint256 a) public pure returns (uint112) {
+        return a.toUint112();
+    }
+
+    function toUint104(uint256 a) public pure returns (uint104) {
+        return a.toUint104();
+    }
+
     function toUint96(uint256 a) public pure returns (uint96) {
         return a.toUint96();
     }
 
+    function toUint88(uint256 a) public pure returns (uint88) {
+        return a.toUint88();
+    }
+
+    function toUint80(uint256 a) public pure returns (uint80) {
+        return a.toUint80();
+    }
+
+    function toUint72(uint256 a) public pure returns (uint72) {
+        return a.toUint72();
+    }
+
     function toUint64(uint256 a) public pure returns (uint64) {
         return a.toUint64();
     }
 
+    function toUint56(uint256 a) public pure returns (uint56) {
+        return a.toUint56();
+    }
+
+    function toUint48(uint256 a) public pure returns (uint48) {
+        return a.toUint48();
+    }
+
+    function toUint40(uint256 a) public pure returns (uint40) {
+        return a.toUint40();
+    }
+
     function toUint32(uint256 a) public pure returns (uint32) {
         return a.toUint32();
     }
 
+    function toUint24(uint256 a) public pure returns (uint24) {
+        return a.toUint24();
+    }
+
     function toUint16(uint256 a) public pure returns (uint16) {
         return a.toUint16();
     }
@@ -44,18 +140,122 @@ contract SafeCastMock {
         return a.toInt256();
     }
 
+    function toInt248(int256 a) public pure returns (int248) {
+        return a.toInt248();
+    }
+
+    function toInt240(int256 a) public pure returns (int240) {
+        return a.toInt240();
+    }
+
+    function toInt232(int256 a) public pure returns (int232) {
+        return a.toInt232();
+    }
+
+    function toInt224(int256 a) public pure returns (int224) {
+        return a.toInt224();
+    }
+
+    function toInt216(int256 a) public pure returns (int216) {
+        return a.toInt216();
+    }
+
+    function toInt208(int256 a) public pure returns (int208) {
+        return a.toInt208();
+    }
+
+    function toInt200(int256 a) public pure returns (int200) {
+        return a.toInt200();
+    }
+
+    function toInt192(int256 a) public pure returns (int192) {
+        return a.toInt192();
+    }
+
+    function toInt184(int256 a) public pure returns (int184) {
+        return a.toInt184();
+    }
+
+    function toInt176(int256 a) public pure returns (int176) {
+        return a.toInt176();
+    }
+
+    function toInt168(int256 a) public pure returns (int168) {
+        return a.toInt168();
+    }
+
+    function toInt160(int256 a) public pure returns (int160) {
+        return a.toInt160();
+    }
+
+    function toInt152(int256 a) public pure returns (int152) {
+        return a.toInt152();
+    }
+
+    function toInt144(int256 a) public pure returns (int144) {
+        return a.toInt144();
+    }
+
+    function toInt136(int256 a) public pure returns (int136) {
+        return a.toInt136();
+    }
+
     function toInt128(int256 a) public pure returns (int128) {
         return a.toInt128();
     }
 
+    function toInt120(int256 a) public pure returns (int120) {
+        return a.toInt120();
+    }
+
+    function toInt112(int256 a) public pure returns (int112) {
+        return a.toInt112();
+    }
+
+    function toInt104(int256 a) public pure returns (int104) {
+        return a.toInt104();
+    }
+
+    function toInt96(int256 a) public pure returns (int96) {
+        return a.toInt96();
+    }
+
+    function toInt88(int256 a) public pure returns (int88) {
+        return a.toInt88();
+    }
+
+    function toInt80(int256 a) public pure returns (int80) {
+        return a.toInt80();
+    }
+
+    function toInt72(int256 a) public pure returns (int72) {
+        return a.toInt72();
+    }
+
     function toInt64(int256 a) public pure returns (int64) {
         return a.toInt64();
     }
 
+    function toInt56(int256 a) public pure returns (int56) {
+        return a.toInt56();
+    }
+
+    function toInt48(int256 a) public pure returns (int48) {
+        return a.toInt48();
+    }
+
+    function toInt40(int256 a) public pure returns (int40) {
+        return a.toInt40();
+    }
+
     function toInt32(int256 a) public pure returns (int32) {
         return a.toInt32();
     }
 
+    function toInt24(int256 a) public pure returns (int24) {
+        return a.toInt24();
+    }
+
     function toInt16(int256 a) public pure returns (int16) {
         return a.toInt16();
     }

文件差异内容过多而无法显示
+ 894 - 2
contracts/utils/math/SafeCast.sol


+ 5 - 2
package.json

@@ -25,11 +25,14 @@
     "clean": "hardhat clean && rimraf build contracts/build",
     "prepare": "npm run clean && env COMPILE_MODE=production npm run compile",
     "prepack": "scripts/prepack.sh",
+    "generate": "scripts/generate/run.js",
     "release": "scripts/release/release.sh",
     "version": "scripts/release/version.sh",
     "test": "hardhat test",
-    "test:inheritance": "node scripts/inheritanceOrdering artifacts/build-info/*",
-    "gas-report": "env ENABLE_GAS_REPORT=true npm run test"
+    "test:inheritance": "scripts/checks/inheritanceOrdering.js artifacts/build-info/*",
+    "test:generation": "scripts/checks/generation.sh",
+    "gas-report": "env ENABLE_GAS_REPORT=true npm run test",
+    "slither": "npm run clean && slither . --detect reentrancy-eth,reentrancy-no-eth,reentrancy-unlimited-gas"
   },
   "repository": {
     "type": "git",

+ 6 - 0
scripts/checks/generation.sh

@@ -0,0 +1,6 @@
+#!/usr/bin/env bash
+
+set -euo pipefail
+
+npm run generate
+git diff --quiet --exit-code

+ 3 - 1
scripts/inheritanceOrdering.js → scripts/checks/inheritanceOrdering.js

@@ -1,10 +1,12 @@
+#!/usr/bin/env node
+
 const path = require('path');
 const graphlib = require('graphlib');
 const { findAll } = require('solidity-ast/utils');
 const { _: artifacts } = require('yargs').argv;
 
 for (const artifact of artifacts) {
-  const { output: solcOutput } = require(path.resolve(__dirname, '..', artifact));
+  const { output: solcOutput } = require(path.resolve(__dirname, '../..', artifact));
 
   const graph = new graphlib.Graph({ directed: true });
   const names = {};

+ 16 - 0
scripts/generate/format-lines.js

@@ -0,0 +1,16 @@
+function formatLines (...lines) {
+  return [...indentEach(0, lines)].join('\n') + '\n';
+}
+
+function *indentEach (indent, lines) {
+  for (const line of lines) {
+    if (Array.isArray(line)) {
+      yield * indentEach(indent + 1, line);
+    } else {
+      const padding = '    '.repeat(indent);
+      yield * line.split('\n').map(subline => subline === '' ? '' : padding + subline);
+    }
+  }
+}
+
+module.exports = formatLines;

+ 29 - 0
scripts/generate/run.js

@@ -0,0 +1,29 @@
+#!/usr/bin/env node
+
+const fs = require('fs');
+const format = require('./format-lines');
+
+function getVersion (path) {
+  try {
+    return fs
+      .readFileSync(path, 'utf8')
+      .match(/\/\/ OpenZeppelin Contracts \(last updated v\d+\.\d+\.\d+\)/)[0];
+  } catch (err) {
+    return null;
+  }
+}
+
+for (const [ file, template ] of Object.entries({
+  'utils/math/SafeCast.sol': './templates/SafeCast',
+  'mocks/SafeCastMock.sol': './templates/SafeCastMock',
+})) {
+  const path = `./contracts/${file}`;
+  const version = getVersion(path);
+  const content = format(
+    '// SPDX-License-Identifier: MIT',
+    (version ? version + ` (${file})\n` : ''),
+    require(template).trimEnd(),
+  );
+
+  fs.writeFileSync(path, content);
+}

+ 168 - 0
scripts/generate/templates/SafeCast.js

@@ -0,0 +1,168 @@
+const assert = require('assert');
+const format = require('../format-lines');
+const { range } = require('../../helpers');
+
+const LENGTHS = range(8, 256, 8).reverse(); // 248 → 8 (in steps of 8)
+
+// Returns the version of OpenZeppelin Contracts in which a particular function was introduced.
+// This is used in the docs for each function.
+const version = (selector, length) => {
+  switch (selector) {
+  case 'toUint(uint)': {
+    switch (length) {
+    case 8:
+    case 16:
+    case 32:
+    case 64:
+    case 128:
+      return '2.5';
+    case 96:
+    case 224:
+      return '4.2';
+    default:
+      assert(LENGTHS.includes(length));
+      return '4.7';
+    }
+  }
+  case 'toInt(int)': {
+    switch (length) {
+    case 8:
+    case 16:
+    case 32:
+    case 64:
+    case 128:
+      return '3.1';
+    default:
+      assert(LENGTHS.includes(length));
+      return '4.7';
+    }
+  }
+  case 'toUint(int)': {
+    switch (length) {
+    case 256:
+      return '3.0';
+    default:
+      assert(false);
+      return;
+    }
+  }
+  case 'toInt(uint)': {
+    switch (length) {
+    case 256:
+      return '3.0';
+    default:
+      assert(false);
+      return;
+    }
+  }
+  default:
+    assert(false);
+  }
+};
+
+const header = `\
+pragma solidity ^0.8.0;
+
+/**
+ * @dev Wrappers over Solidity's uintXX/intXX casting operators with added overflow
+ * checks.
+ *
+ * Downcasting from uint256/int256 in Solidity does not revert on overflow. This can
+ * easily result in undesired exploitation or bugs, since developers usually
+ * assume that overflows raise errors. \`SafeCast\` restores this intuition by
+ * reverting the transaction when such an operation overflows.
+ *
+ * Using this library instead of the unchecked operations eliminates an entire
+ * class of bugs, so it's recommended to use it always.
+ *
+ * Can be combined with {SafeMath} and {SignedSafeMath} to extend it to smaller types, by performing
+ * all math on \`uint256\` and \`int256\` and then downcasting.
+ */
+`;
+
+const toUintDownCast = length => `\
+/**
+ * @dev Returns the downcasted uint${length} from uint256, reverting on
+ * overflow (when the input is greater than largest uint${length}).
+ *
+ * Counterpart to Solidity's \`uint${length}\` operator.
+ *
+ * Requirements:
+ *
+ * - input must fit into ${length} bits
+ *
+ * _Available since v${version('toUint(uint)', length)}._
+ */
+function toUint${length}(uint256 value) internal pure returns (uint${length}) {
+    require(value <= type(uint${length}).max, "SafeCast: value doesn't fit in ${length} bits");
+    return uint${length}(value);
+}
+`;
+
+/* eslint-disable max-len */
+const toIntDownCast = length => `\
+/**
+ * @dev Returns the downcasted int${length} from int256, reverting on
+ * overflow (when the input is less than smallest int${length} or
+ * greater than largest int${length}).
+ *
+ * Counterpart to Solidity's \`int${length}\` operator.
+ *
+ * Requirements:
+ *
+ * - input must fit into ${length} bits
+ *
+ * _Available since v${version('toInt(int)', length)}._
+ */
+function toInt${length}(int256 value) internal pure returns (int${length}) {
+    require(value >= type(int${length}).min && value <= type(int${length}).max, "SafeCast: value doesn't fit in ${length} bits");
+    return int${length}(value);
+}
+`;
+/* eslint-enable max-len */
+
+const toInt = length => `\
+/**
+ * @dev Converts an unsigned uint${length} into a signed int${length}.
+ *
+ * Requirements:
+ *
+ * - input must be less than or equal to maxInt${length}.
+ *
+ * _Available since v${version('toInt(uint)', length)}._
+ */
+function toInt${length}(uint${length} value) internal pure returns (int${length}) {
+    // Note: Unsafe cast below is okay because \`type(int${length}).max\` is guaranteed to be positive
+    require(value <= uint${length}(type(int${length}).max), "SafeCast: value doesn't fit in an int${length}");
+    return int${length}(value);
+}
+`;
+
+const toUint = length => `\
+/**
+ * @dev Converts a signed int${length} into an unsigned uint${length}.
+ *
+ * Requirements:
+ *
+ * - input must be greater than or equal to 0.
+ *
+ * _Available since v${version('toUint(int)', length)}._
+ */
+function toUint${length}(int${length} value) internal pure returns (uint${length}) {
+    require(value >= 0, "SafeCast: value must be positive");
+    return uint${length}(value);
+}
+`;
+
+// GENERATE
+module.exports = format(
+  header.trimEnd(),
+  'library SafeCast {',
+  [
+    ...LENGTHS.map(size => toUintDownCast(size)),
+    toUint(256),
+    ...LENGTHS.map(size => toIntDownCast(size)),
+    toInt(256).trimEnd(),
+  ],
+  '}',
+);

+ 50 - 0
scripts/generate/templates/SafeCastMock.js

@@ -0,0 +1,50 @@
+const format = require('../format-lines');
+const { range } = require('../../helpers');
+
+const LENGTHS = range(8, 256, 8).reverse(); // 248 → 8 (in steps of 8)
+
+const header = `\
+pragma solidity ^0.8.0;
+
+import "../utils/math/SafeCast.sol";
+`;
+
+const toInt = length => `\
+function toInt${length}(uint${length} a) public pure returns (int${length}) {
+    return a.toInt${length}();
+}
+`;
+
+const toUint = length => `\
+function toUint${length}(int${length} a) public pure returns (uint${length}) {
+    return a.toUint${length}();
+}
+`;
+
+const toIntDownCast = length => `\
+function toInt${length}(int256 a) public pure returns (int${length}) {
+    return a.toInt${length}();
+}
+`;
+
+const toUintDownCast = length => `\
+function toUint${length}(uint256 a) public pure returns (uint${length}) {
+    return a.toUint${length}();
+}
+`;
+
+// GENERATE
+module.exports = format(
+  header,
+  'contract SafeCastMock {',
+  [
+    'using SafeCast for uint256;',
+    'using SafeCast for int256;',
+    '',
+    toUint(256),
+    ...LENGTHS.map(size => toUintDownCast(size)),
+    toInt(256),
+    ...LENGTHS.map(size => toIntDownCast(size)),
+  ].flatMap(fn => fn.split('\n')).slice(0, -1),
+  '}',
+);

+ 23 - 0
scripts/helpers.js

@@ -0,0 +1,23 @@
+function chunk (array, size = 1) {
+  return Array.range(Math.ceil(array.length / size)).map(i => array.slice(i * size, i * size + size));
+}
+
+function range (start, stop = undefined, step = 1) {
+  if (!stop) { stop = start; start = 0; }
+  return start < stop ? Array(Math.ceil((stop - start) / step)).fill().map((_, i) => start + i * step) : [];
+}
+
+function unique (array, op = x => x) {
+  return array.filter((obj, i) => array.findIndex(entry => op(obj) === op(entry)) === i);
+}
+
+function zip (...args) {
+  return Array(Math.max(...args.map(arg => arg.length))).fill(null).map((_, i) => args.map(arg => arg[i]));
+}
+
+module.exports = {
+  chunk,
+  range,
+  unique,
+  zip,
+};

+ 3 - 3
test/utils/math/SafeCast.test.js

@@ -1,6 +1,6 @@
 const { BN, expectRevert } = require('@openzeppelin/test-helpers');
-
 const { expect } = require('chai');
+const { range } = require('../../../scripts/helpers');
 
 const SafeCastMock = artifacts.require('SafeCastMock');
 
@@ -41,7 +41,7 @@ contract('SafeCast', async (accounts) => {
     });
   }
 
-  [8, 16, 32, 64, 96, 128, 224].forEach(bits => testToUint(bits));
+  range(8, 256, 8).forEach(bits => testToUint(bits));
 
   describe('toUint256', () => {
     const maxInt256 = new BN('2').pow(new BN(255)).subn(1);
@@ -129,7 +129,7 @@ contract('SafeCast', async (accounts) => {
     });
   }
 
-  [8, 16, 32, 64, 128].forEach(bits => testToInt(bits));
+  range(8, 256, 8).forEach(bits => testToInt(bits));
 
   describe('toInt256', () => {
     const maxUint256 = new BN('2').pow(new BN(256)).subn(1);

部分文件因为文件数量过多而无法显示