Ver Fonte

Refactor SafeMath to avoid memory leaks (#2462)

Co-authored-by: Francisco Giordano <frangio.1@gmail.com>
Hadrien Croubois há 4 anos atrás
pai
commit
c34211417c

+ 4 - 2
CHANGELOG.md

@@ -7,9 +7,11 @@
  * `ERC20Permit`: added an implementation of the ERC20 permit extension for gasless token approvals. ([#2237](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/2237))
  * Presets: added token presets with preminted fixed supply `ERC20PresetFixedSupply` and `ERC777PresetFixedSupply`. ([#2399](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/2399))
  * `Address`: added `functionDelegateCall`, similar to the existing `functionCall`. ([#2333](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/2333))
- * `Context`: moved from `contracts/GSN` to `contracts/utils`. ([#2453](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/2453)) 
+ * `Context`: moved from `contracts/GSN` to `contracts/utils`. ([#2453](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/2453))
  * `PaymentSplitter`: replace usage of `.transfer()` with `Address.sendValue` for improved compatibility with smart wallets. ([#2455](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/2455))
- * `UpgradeableProxy`: bubble revert reasons from initialization calls. ([#2454](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/2454)) 
+ * `UpgradeableProxy`: bubble revert reasons from initialization calls. ([#2454](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/2454))
+ * `SafeMath`: fix a memory allocation issue by adding new `SafeMath.tryOp(uint,uint)→(bool,uint)` functions. `SafeMath.op(uint,uint,string)→uint` are now deprecated. ([#2462](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/2462))
+ * `EnumerableMap`: fix a memory allocation issue by adding new `EnumerableMap.tryGet(uint)→(bool,address)` functions. `EnumerableMap.get(uint)→string` is now deprecated. ([#2462](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/2462))
 
 ## 3.3.0 (2020-11-26)
 

+ 93 - 48
contracts/math/SafeMath.sol

@@ -16,6 +16,52 @@ pragma solidity >=0.6.0 <0.8.0;
  * class of bugs, so it's recommended to use it always.
  */
 library SafeMath {
+    /**
+     * @dev Returns the addition of two unsigned integers, with an overflow flag.
+     */
+    function tryAdd(uint256 a, uint256 b) internal pure returns (bool, uint256) {
+        uint256 c = a + b;
+        if (c < a) return (false, 0);
+        return (true, c);
+    }
+
+    /**
+     * @dev Returns the substraction of two unsigned integers, with an overflow flag.
+     */
+    function trySub(uint256 a, uint256 b) internal pure returns (bool, uint256) {
+        if (b > a) return (false, 0);
+        return (true, a - b);
+    }
+
+    /**
+     * @dev Returns the multiplication of two unsigned integers, with an overflow flag.
+     */
+    function tryMul(uint256 a, uint256 b) internal pure returns (bool, uint256) {
+        // Gas optimization: this is cheaper than requiring 'a' not being zero, but the
+        // benefit is lost if 'b' is also tested.
+        // See: https://github.com/OpenZeppelin/openzeppelin-contracts/pull/522
+        if (a == 0) return (true, 0);
+        uint256 c = a * b;
+        if (c / a != b) return (false, 0);
+        return (true, c);
+    }
+
+    /**
+     * @dev Returns the division of two unsigned integers, with a division by zero flag.
+     */
+    function tryDiv(uint256 a, uint256 b) internal pure returns (bool, uint256) {
+        if (b == 0) return (false, 0);
+        return (true, a / b);
+    }
+
+    /**
+     * @dev Returns the remainder of dividing two unsigned integers, with a division by zero flag.
+     */
+    function tryMod(uint256 a, uint256 b) internal pure returns (bool, uint256) {
+        if (b == 0) return (false, 0);
+        return (true, a % b);
+    }
+
     /**
      * @dev Returns the addition of two unsigned integers, reverting on
      * overflow.
@@ -29,7 +75,6 @@ library SafeMath {
     function add(uint256 a, uint256 b) internal pure returns (uint256) {
         uint256 c = a + b;
         require(c >= a, "SafeMath: addition overflow");
-
         return c;
     }
 
@@ -44,24 +89,8 @@ library SafeMath {
      * - Subtraction cannot overflow.
      */
     function sub(uint256 a, uint256 b) internal pure returns (uint256) {
-        return sub(a, b, "SafeMath: subtraction overflow");
-    }
-
-    /**
-     * @dev Returns the subtraction of two unsigned integers, reverting with custom message on
-     * overflow (when the result is negative).
-     *
-     * Counterpart to Solidity's `-` operator.
-     *
-     * Requirements:
-     *
-     * - Subtraction cannot overflow.
-     */
-    function sub(uint256 a, uint256 b, string memory errorMessage) internal pure returns (uint256) {
-        require(b <= a, errorMessage);
-        uint256 c = a - b;
-
-        return c;
+        require(b <= a, "SafeMath: subtraction overflow");
+        return a - b;
     }
 
     /**
@@ -75,21 +104,14 @@ library SafeMath {
      * - Multiplication cannot overflow.
      */
     function mul(uint256 a, uint256 b) internal pure returns (uint256) {
-        // Gas optimization: this is cheaper than requiring 'a' not being zero, but the
-        // benefit is lost if 'b' is also tested.
-        // See: https://github.com/OpenZeppelin/openzeppelin-contracts/pull/522
-        if (a == 0) {
-            return 0;
-        }
-
+        if (a == 0) return 0;
         uint256 c = a * b;
         require(c / a == b, "SafeMath: multiplication overflow");
-
         return c;
     }
 
     /**
-     * @dev Returns the integer division of two unsigned integers. Reverts on
+     * @dev Returns the integer division of two unsigned integers, reverting on
      * division by zero. The result is rounded towards zero.
      *
      * Counterpart to Solidity's `/` operator. Note: this function uses a
@@ -101,48 +123,71 @@ library SafeMath {
      * - The divisor cannot be zero.
      */
     function div(uint256 a, uint256 b) internal pure returns (uint256) {
-        return div(a, b, "SafeMath: division by zero");
+        require(b > 0, "SafeMath: division by zero");
+        return a / b;
     }
 
     /**
-     * @dev Returns the integer division of two unsigned integers. Reverts with custom message on
-     * division by zero. The result is rounded towards zero.
+     * @dev Returns the remainder of dividing two unsigned integers. (unsigned integer modulo),
+     * reverting when dividing by zero.
      *
-     * Counterpart to Solidity's `/` operator. Note: this function uses a
-     * `revert` opcode (which leaves remaining gas untouched) while Solidity
-     * uses an invalid opcode to revert (consuming all remaining gas).
+     * Counterpart to Solidity's `%` operator. This function uses a `revert`
+     * opcode (which leaves remaining gas untouched) while Solidity uses an
+     * invalid opcode to revert (consuming all remaining gas).
      *
      * Requirements:
      *
      * - The divisor cannot be zero.
      */
-    function div(uint256 a, uint256 b, string memory errorMessage) internal pure returns (uint256) {
-        require(b > 0, errorMessage);
-        uint256 c = a / b;
-        // assert(a == b * c + a % b); // There is no case in which this doesn't hold
+    function mod(uint256 a, uint256 b) internal pure returns (uint256) {
+        require(b > 0, "SafeMath: modulo by zero");
+        return a % b;
+    }
 
-        return c;
+    /**
+     * @dev Returns the subtraction of two unsigned integers, reverting with custom message on
+     * overflow (when the result is negative).
+     *
+     * CAUTION: This function is deprecated because it requires allocating memory for the error
+     * message unnecessarily. For custom revert reasons use {trySub}.
+     *
+     * Counterpart to Solidity's `-` operator.
+     *
+     * Requirements:
+     *
+     * - Subtraction cannot overflow.
+     */
+    function sub(uint256 a, uint256 b, string memory errorMessage) internal pure returns (uint256) {
+        require(b <= a, errorMessage);
+        return a - b;
     }
 
     /**
-     * @dev Returns the remainder of dividing two unsigned integers. (unsigned integer modulo),
-     * Reverts when dividing by zero.
+     * @dev Returns the integer division of two unsigned integers, reverting with custom message on
+     * division by zero. The result is rounded towards zero.
      *
-     * Counterpart to Solidity's `%` operator. This function uses a `revert`
-     * opcode (which leaves remaining gas untouched) while Solidity uses an
-     * invalid opcode to revert (consuming all remaining gas).
+     * CAUTION: This function is deprecated because it requires allocating memory for the error
+     * message unnecessarily. For custom revert reasons use {tryDiv}.
+     *
+     * Counterpart to Solidity's `/` operator. Note: this function uses a
+     * `revert` opcode (which leaves remaining gas untouched) while Solidity
+     * uses an invalid opcode to revert (consuming all remaining gas).
      *
      * Requirements:
      *
      * - The divisor cannot be zero.
      */
-    function mod(uint256 a, uint256 b) internal pure returns (uint256) {
-        return mod(a, b, "SafeMath: modulo by zero");
+    function div(uint256 a, uint256 b, string memory errorMessage) internal pure returns (uint256) {
+        require(b > 0, errorMessage);
+        return a / b;
     }
 
     /**
      * @dev Returns the remainder of dividing two unsigned integers. (unsigned integer modulo),
-     * Reverts with custom message when dividing by zero.
+     * reverting with custom message when dividing by zero.
+     *
+     * CAUTION: This function is deprecated because it requires allocating memory for the error
+     * message unnecessarily. For custom revert reasons use {tryMod}.
      *
      * Counterpart to Solidity's `%` operator. This function uses a `revert`
      * opcode (which leaves remaining gas untouched) while Solidity uses an
@@ -153,7 +198,7 @@ library SafeMath {
      * - The divisor cannot be zero.
      */
     function mod(uint256 a, uint256 b, string memory errorMessage) internal pure returns (uint256) {
-        require(b != 0, errorMessage);
+        require(b > 0, errorMessage);
         return a % b;
     }
 }

+ 8 - 0
contracts/mocks/EnumerableMapMock.sol

@@ -34,7 +34,15 @@ contract EnumerableMapMock {
     }
 
 
+    function tryGet(uint256 key) public view returns (bool, address) {
+        return _map.tryGet(key);
+    }
+
     function get(uint256 key) public view returns (address) {
         return _map.get(key);
     }
+
+    function getWithMessage(uint256 key, string calldata errorMessage) public view returns (address) {
+        return _map.get(key, errorMessage);
+    }
 }

+ 84 - 6
contracts/mocks/SafeMathMock.sol

@@ -5,23 +5,101 @@ pragma solidity >=0.6.0 <0.8.0;
 import "../math/SafeMath.sol";
 
 contract SafeMathMock {
-    function mul(uint256 a, uint256 b) public pure returns (uint256) {
-        return SafeMath.mul(a, b);
+    function tryAdd(uint256 a, uint256 b) public pure returns (bool flag, uint256 value) {
+        return SafeMath.tryAdd(a, b);
     }
 
-    function div(uint256 a, uint256 b) public pure returns (uint256) {
-        return SafeMath.div(a, b);
+    function trySub(uint256 a, uint256 b) public pure returns (bool flag, uint256 value) {
+        return SafeMath.trySub(a, b);
     }
 
-    function sub(uint256 a, uint256 b) public pure returns (uint256) {
-        return SafeMath.sub(a, b);
+    function tryMul(uint256 a, uint256 b) public pure returns (bool flag, uint256 value) {
+        return SafeMath.tryMul(a, b);
+    }
+
+    function tryDiv(uint256 a, uint256 b) public pure returns (bool flag, uint256 value) {
+        return SafeMath.tryDiv(a, b);
+    }
+
+    function tryMod(uint256 a, uint256 b) public pure returns (bool flag, uint256 value) {
+        return SafeMath.tryMod(a, b);
     }
 
     function add(uint256 a, uint256 b) public pure returns (uint256) {
         return SafeMath.add(a, b);
     }
 
+    function sub(uint256 a, uint256 b) public pure returns (uint256) {
+        return SafeMath.sub(a, b);
+    }
+
+    function mul(uint256 a, uint256 b) public pure returns (uint256) {
+        return SafeMath.mul(a, b);
+    }
+
+    function div(uint256 a, uint256 b) public pure returns (uint256) {
+        return SafeMath.div(a, b);
+    }
+
     function mod(uint256 a, uint256 b) public pure returns (uint256) {
         return SafeMath.mod(a, b);
     }
+
+    function subWithMessage(uint256 a, uint256 b, string memory errorMessage) public pure returns (uint256) {
+        return SafeMath.sub(a, b, errorMessage);
+    }
+
+    function divWithMessage(uint256 a, uint256 b, string memory errorMessage) public pure returns (uint256) {
+        return SafeMath.div(a, b, errorMessage);
+    }
+
+    function modWithMessage(uint256 a, uint256 b, string memory errorMessage) public pure returns (uint256) {
+        return SafeMath.mod(a, b, errorMessage);
+    }
+
+    function addMemoryCheck() public pure returns (uint256 mem) {
+        uint256 length = 32;
+        // solhint-disable-next-line no-inline-assembly
+        assembly { mem := mload(0x40) }
+        for (uint256 i = 0; i < length; ++i) { SafeMath.add(1, 1); }
+        // solhint-disable-next-line no-inline-assembly
+        assembly { mem := sub(mload(0x40), mem) }
+    }
+
+    function subMemoryCheck() public pure returns (uint256 mem) {
+        uint256 length = 32;
+        // solhint-disable-next-line no-inline-assembly
+        assembly { mem := mload(0x40) }
+        for (uint256 i = 0; i < length; ++i) { SafeMath.sub(1, 1); }
+        // solhint-disable-next-line no-inline-assembly
+        assembly { mem := sub(mload(0x40), mem) }
+    }
+
+    function mulMemoryCheck() public pure returns (uint256 mem) {
+        uint256 length = 32;
+        // solhint-disable-next-line no-inline-assembly
+        assembly { mem := mload(0x40) }
+        for (uint256 i = 0; i < length; ++i) { SafeMath.mul(1, 1); }
+        // solhint-disable-next-line no-inline-assembly
+        assembly { mem := sub(mload(0x40), mem) }
+    }
+
+    function divMemoryCheck() public pure returns (uint256 mem) {
+        uint256 length = 32;
+        // solhint-disable-next-line no-inline-assembly
+        assembly { mem := mload(0x40) }
+        for (uint256 i = 0; i < length; ++i) { SafeMath.div(1, 1); }
+        // solhint-disable-next-line no-inline-assembly
+        assembly { mem := sub(mload(0x40), mem) }
+    }
+
+    function modMemoryCheck() public pure returns (uint256 mem) {
+        uint256 length = 32;
+        // solhint-disable-next-line no-inline-assembly
+        assembly { mem := mload(0x40) }
+        for (uint256 i = 0; i < length; ++i) { SafeMath.mod(1, 1); }
+        // solhint-disable-next-line no-inline-assembly
+        assembly { mem := sub(mload(0x40), mem) }
+    }
+
 }

+ 28 - 1
contracts/utils/EnumerableMap.sol

@@ -143,6 +143,16 @@ library EnumerableMap {
         return (entry._key, entry._value);
     }
 
+    /**
+     * @dev Tries to returns the value associated with `key`.  O(1).
+     * Does not revert if `key` is not in the map.
+     */
+    function _tryGet(Map storage map, bytes32 key) private view returns (bool, bytes32) {
+        uint256 keyIndex = map._indexes[key];
+        if (keyIndex == 0) return (false, 0); // Equivalent to contains(map, key)
+        return (true, map._entries[keyIndex - 1]._value); // All indexes are 1-based
+    }
+
     /**
      * @dev Returns the value associated with `key`.  O(1).
      *
@@ -151,11 +161,16 @@ library EnumerableMap {
      * - `key` must be in the map.
      */
     function _get(Map storage map, bytes32 key) private view returns (bytes32) {
-        return _get(map, key, "EnumerableMap: nonexistent key");
+        uint256 keyIndex = map._indexes[key];
+        require(keyIndex != 0, "EnumerableMap: nonexistent key"); // Equivalent to contains(map, key)
+        return map._entries[keyIndex - 1]._value; // All indexes are 1-based
     }
 
     /**
      * @dev Same as {_get}, with a custom error message when `key` is not in the map.
+     *
+     * CAUTION: This function is deprecated because it requires allocating memory for the error
+     * message unnecessarily. For custom revert reasons use {_tryGet}.
      */
     function _get(Map storage map, bytes32 key, string memory errorMessage) private view returns (bytes32) {
         uint256 keyIndex = map._indexes[key];
@@ -217,6 +232,15 @@ library EnumerableMap {
         return (uint256(key), address(uint160(uint256(value))));
     }
 
+    /**
+     * @dev Tries to returns the value associated with `key`.  O(1).
+     * Does not revert if `key` is not in the map.
+     */
+    function tryGet(UintToAddressMap storage map, uint256 key) internal view returns (bool, address) {
+        (bool success, bytes32 value) = _tryGet(map._inner, bytes32(key));
+        return (success, address(uint160(uint256(value))));
+    }
+
     /**
      * @dev Returns the value associated with `key`.  O(1).
      *
@@ -230,6 +254,9 @@ library EnumerableMap {
 
     /**
      * @dev Same as {get}, with a custom error message when `key` is not in the map.
+     *
+     * CAUTION: This function is deprecated because it requires allocating memory for the error
+     * message unnecessarily. For custom revert reasons use {tryGet}.
      */
     function get(UintToAddressMap storage map, uint256 key, string memory errorMessage) internal view returns (address) {
         return address(uint160(uint256(_get(map._inner, bytes32(key), errorMessage))));

+ 327 - 75
test/math/SafeMath.test.js

@@ -5,142 +5,394 @@ const { expect } = require('chai');
 
 const SafeMathMock = artifacts.require('SafeMathMock');
 
+function expectStruct (value, expected) {
+  for (const key in expected) {
+    if (BN.isBN(value[key])) {
+      expect(value[key]).to.be.bignumber.equal(expected[key]);
+    } else {
+      expect(value[key]).to.be.equal(expected[key]);
+    }
+  }
+}
+
 contract('SafeMath', function (accounts) {
   beforeEach(async function () {
     this.safeMath = await SafeMathMock.new();
   });
 
-  async function testCommutative (fn, lhs, rhs, expected) {
-    expect(await fn(lhs, rhs)).to.be.bignumber.equal(expected);
-    expect(await fn(rhs, lhs)).to.be.bignumber.equal(expected);
+  async function testCommutative (fn, lhs, rhs, expected, ...extra) {
+    expect(await fn(lhs, rhs, ...extra)).to.be.bignumber.equal(expected);
+    expect(await fn(rhs, lhs, ...extra)).to.be.bignumber.equal(expected);
   }
 
-  async function testFailsCommutative (fn, lhs, rhs, reason) {
-    await expectRevert(fn(lhs, rhs), reason);
-    await expectRevert(fn(rhs, lhs), reason);
+  async function testFailsCommutative (fn, lhs, rhs, reason, ...extra) {
+    await expectRevert(fn(lhs, rhs, ...extra), reason);
+    await expectRevert(fn(rhs, lhs, ...extra), reason);
   }
 
-  describe('add', function () {
-    it('adds correctly', async function () {
-      const a = new BN('5678');
-      const b = new BN('1234');
+  async function testCommutativeIterable (fn, lhs, rhs, expected, ...extra) {
+    expectStruct(await fn(lhs, rhs, ...extra), expected);
+    expectStruct(await fn(rhs, lhs, ...extra), expected);
+  }
 
-      await testCommutative(this.safeMath.add, a, b, a.add(b));
-    });
+  describe('with flag', function () {
+    describe('add', function () {
+      it('adds correctly', async function () {
+        const a = new BN('5678');
+        const b = new BN('1234');
 
-    it('reverts on addition overflow', async function () {
-      const a = MAX_UINT256;
-      const b = new BN('1');
+        testCommutativeIterable(this.safeMath.tryAdd, a, b, { flag: true, value: a.add(b) });
+      });
+
+      it('reverts on addition overflow', async function () {
+        const a = MAX_UINT256;
+        const b = new BN('1');
 
-      await testFailsCommutative(this.safeMath.add, a, b, 'SafeMath: addition overflow');
+        testCommutativeIterable(this.safeMath.tryAdd, a, b, { flag: false, value: '0' });
+      });
     });
-  });
 
-  describe('sub', function () {
-    it('subtracts correctly', async function () {
-      const a = new BN('5678');
-      const b = new BN('1234');
+    describe('sub', function () {
+      it('subtracts correctly', async function () {
+        const a = new BN('5678');
+        const b = new BN('1234');
 
-      expect(await this.safeMath.sub(a, b)).to.be.bignumber.equal(a.sub(b));
-    });
+        expectStruct(await this.safeMath.trySub(a, b), { flag: true, value: a.sub(b) });
+      });
 
-    it('reverts if subtraction result would be negative', async function () {
-      const a = new BN('1234');
-      const b = new BN('5678');
+      it('reverts if subtraction result would be negative', async function () {
+        const a = new BN('1234');
+        const b = new BN('5678');
 
-      await expectRevert(this.safeMath.sub(a, b), 'SafeMath: subtraction overflow');
+        expectStruct(await this.safeMath.trySub(a, b), { flag: false, value: '0' });
+      });
     });
-  });
 
-  describe('mul', function () {
-    it('multiplies correctly', async function () {
-      const a = new BN('1234');
-      const b = new BN('5678');
+    describe('mul', function () {
+      it('multiplies correctly', async function () {
+        const a = new BN('1234');
+        const b = new BN('5678');
+
+        testCommutativeIterable(this.safeMath.tryMul, a, b, { flag: true, value: a.mul(b) });
+      });
+
+      it('multiplies by zero correctly', async function () {
+        const a = new BN('0');
+        const b = new BN('5678');
+
+        testCommutativeIterable(this.safeMath.tryMul, a, b, { flag: true, value: a.mul(b) });
+      });
 
-      await testCommutative(this.safeMath.mul, a, b, a.mul(b));
+      it('reverts on multiplication overflow', async function () {
+        const a = MAX_UINT256;
+        const b = new BN('2');
+
+        testCommutativeIterable(this.safeMath.tryMul, a, b, { flag: false, value: '0' });
+      });
     });
 
-    it('multiplies by zero correctly', async function () {
-      const a = new BN('0');
-      const b = new BN('5678');
+    describe('div', function () {
+      it('divides correctly', async function () {
+        const a = new BN('5678');
+        const b = new BN('5678');
 
-      await testCommutative(this.safeMath.mul, a, b, '0');
+        expectStruct(await this.safeMath.tryDiv(a, b), { flag: true, value: a.div(b) });
+      });
+
+      it('divides zero correctly', async function () {
+        const a = new BN('0');
+        const b = new BN('5678');
+
+        expectStruct(await this.safeMath.tryDiv(a, b), { flag: true, value: a.div(b) });
+      });
+
+      it('returns complete number result on non-even division', async function () {
+        const a = new BN('7000');
+        const b = new BN('5678');
+
+        expectStruct(await this.safeMath.tryDiv(a, b), { flag: true, value: a.div(b) });
+      });
+
+      it('reverts on division by zero', async function () {
+        const a = new BN('5678');
+        const b = new BN('0');
+
+        expectStruct(await this.safeMath.tryDiv(a, b), { flag: false, value: '0' });
+      });
     });
 
-    it('reverts on multiplication overflow', async function () {
-      const a = MAX_UINT256;
-      const b = new BN('2');
+    describe('mod', function () {
+      describe('modulos correctly', async function () {
+        it('when the dividend is smaller than the divisor', async function () {
+          const a = new BN('284');
+          const b = new BN('5678');
+
+          expectStruct(await this.safeMath.tryMod(a, b), { flag: true, value: a.mod(b) });
+        });
+
+        it('when the dividend is equal to the divisor', async function () {
+          const a = new BN('5678');
+          const b = new BN('5678');
+
+          expectStruct(await this.safeMath.tryMod(a, b), { flag: true, value: a.mod(b) });
+        });
 
-      await testFailsCommutative(this.safeMath.mul, a, b, 'SafeMath: multiplication overflow');
+        it('when the dividend is larger than the divisor', async function () {
+          const a = new BN('7000');
+          const b = new BN('5678');
+
+          expectStruct(await this.safeMath.tryMod(a, b), { flag: true, value: a.mod(b) });
+        });
+
+        it('when the dividend is a multiple of the divisor', async function () {
+          const a = new BN('17034'); // 17034 == 5678 * 3
+          const b = new BN('5678');
+
+          expectStruct(await this.safeMath.tryMod(a, b), { flag: true, value: a.mod(b) });
+        });
+      });
+
+      it('reverts with a 0 divisor', async function () {
+        const a = new BN('5678');
+        const b = new BN('0');
+
+        expectStruct(await this.safeMath.tryMod(a, b), { flag: false, value: '0' });
+      });
     });
   });
 
-  describe('div', function () {
-    it('divides correctly', async function () {
-      const a = new BN('5678');
-      const b = new BN('5678');
+  describe('with default revert message', function () {
+    describe('add', function () {
+      it('adds correctly', async function () {
+        const a = new BN('5678');
+        const b = new BN('1234');
+
+        await testCommutative(this.safeMath.add, a, b, a.add(b));
+      });
+
+      it('reverts on addition overflow', async function () {
+        const a = MAX_UINT256;
+        const b = new BN('1');
+
+        await testFailsCommutative(this.safeMath.add, a, b, 'SafeMath: addition overflow');
+      });
+    });
+
+    describe('sub', function () {
+      it('subtracts correctly', async function () {
+        const a = new BN('5678');
+        const b = new BN('1234');
+
+        expect(await this.safeMath.sub(a, b)).to.be.bignumber.equal(a.sub(b));
+      });
+
+      it('reverts if subtraction result would be negative', async function () {
+        const a = new BN('1234');
+        const b = new BN('5678');
 
-      expect(await this.safeMath.div(a, b)).to.be.bignumber.equal(a.div(b));
+        await expectRevert(this.safeMath.sub(a, b), 'SafeMath: subtraction overflow');
+      });
     });
 
-    it('divides zero correctly', async function () {
-      const a = new BN('0');
-      const b = new BN('5678');
+    describe('mul', function () {
+      it('multiplies correctly', async function () {
+        const a = new BN('1234');
+        const b = new BN('5678');
+
+        await testCommutative(this.safeMath.mul, a, b, a.mul(b));
+      });
+
+      it('multiplies by zero correctly', async function () {
+        const a = new BN('0');
+        const b = new BN('5678');
+
+        await testCommutative(this.safeMath.mul, a, b, '0');
+      });
 
-      expect(await this.safeMath.div(a, b)).to.be.bignumber.equal('0');
+      it('reverts on multiplication overflow', async function () {
+        const a = MAX_UINT256;
+        const b = new BN('2');
+
+        await testFailsCommutative(this.safeMath.mul, a, b, 'SafeMath: multiplication overflow');
+      });
     });
 
-    it('returns complete number result on non-even division', async function () {
-      const a = new BN('7000');
-      const b = new BN('5678');
+    describe('div', function () {
+      it('divides correctly', async function () {
+        const a = new BN('5678');
+        const b = new BN('5678');
+
+        expect(await this.safeMath.div(a, b)).to.be.bignumber.equal(a.div(b));
+      });
+
+      it('divides zero correctly', async function () {
+        const a = new BN('0');
+        const b = new BN('5678');
+
+        expect(await this.safeMath.div(a, b)).to.be.bignumber.equal('0');
+      });
+
+      it('returns complete number result on non-even division', async function () {
+        const a = new BN('7000');
+        const b = new BN('5678');
+
+        expect(await this.safeMath.div(a, b)).to.be.bignumber.equal('1');
+      });
+
+      it('reverts on division by zero', async function () {
+        const a = new BN('5678');
+        const b = new BN('0');
 
-      expect(await this.safeMath.div(a, b)).to.be.bignumber.equal('1');
+        await expectRevert(this.safeMath.div(a, b), 'SafeMath: division by zero');
+      });
     });
 
-    it('reverts on division by zero', async function () {
-      const a = new BN('5678');
-      const b = new BN('0');
+    describe('mod', function () {
+      describe('modulos correctly', async function () {
+        it('when the dividend is smaller than the divisor', async function () {
+          const a = new BN('284');
+          const b = new BN('5678');
+
+          expect(await this.safeMath.mod(a, b)).to.be.bignumber.equal(a.mod(b));
+        });
+
+        it('when the dividend is equal to the divisor', async function () {
+          const a = new BN('5678');
+          const b = new BN('5678');
+
+          expect(await this.safeMath.mod(a, b)).to.be.bignumber.equal(a.mod(b));
+        });
 
-      await expectRevert(this.safeMath.div(a, b), 'SafeMath: division by zero');
+        it('when the dividend is larger than the divisor', async function () {
+          const a = new BN('7000');
+          const b = new BN('5678');
+
+          expect(await this.safeMath.mod(a, b)).to.be.bignumber.equal(a.mod(b));
+        });
+
+        it('when the dividend is a multiple of the divisor', async function () {
+          const a = new BN('17034'); // 17034 == 5678 * 3
+          const b = new BN('5678');
+
+          expect(await this.safeMath.mod(a, b)).to.be.bignumber.equal(a.mod(b));
+        });
+      });
+
+      it('reverts with a 0 divisor', async function () {
+        const a = new BN('5678');
+        const b = new BN('0');
+
+        await expectRevert(this.safeMath.mod(a, b), 'SafeMath: modulo by zero');
+      });
     });
   });
 
-  describe('mod', function () {
-    describe('modulos correctly', async function () {
-      it('when the dividend is smaller than the divisor', async function () {
-        const a = new BN('284');
+  describe('with custom revert message', function () {
+    describe('sub', function () {
+      it('subtracts correctly', async function () {
+        const a = new BN('5678');
+        const b = new BN('1234');
+
+        expect(await this.safeMath.subWithMessage(a, b, 'MyErrorMessage')).to.be.bignumber.equal(a.sub(b));
+      });
+
+      it('reverts if subtraction result would be negative', async function () {
+        const a = new BN('1234');
         const b = new BN('5678');
 
-        expect(await this.safeMath.mod(a, b)).to.be.bignumber.equal(a.mod(b));
+        await expectRevert(this.safeMath.subWithMessage(a, b, 'MyErrorMessage'), 'MyErrorMessage');
       });
+    });
 
-      it('when the dividend is equal to the divisor', async function () {
+    describe('div', function () {
+      it('divides correctly', async function () {
         const a = new BN('5678');
         const b = new BN('5678');
 
-        expect(await this.safeMath.mod(a, b)).to.be.bignumber.equal(a.mod(b));
+        expect(await this.safeMath.divWithMessage(a, b, 'MyErrorMessage')).to.be.bignumber.equal(a.div(b));
       });
 
-      it('when the dividend is larger than the divisor', async function () {
-        const a = new BN('7000');
+      it('divides zero correctly', async function () {
+        const a = new BN('0');
         const b = new BN('5678');
 
-        expect(await this.safeMath.mod(a, b)).to.be.bignumber.equal(a.mod(b));
+        expect(await this.safeMath.divWithMessage(a, b, 'MyErrorMessage')).to.be.bignumber.equal('0');
       });
 
-      it('when the dividend is a multiple of the divisor', async function () {
-        const a = new BN('17034'); // 17034 == 5678 * 3
+      it('returns complete number result on non-even division', async function () {
+        const a = new BN('7000');
         const b = new BN('5678');
 
-        expect(await this.safeMath.mod(a, b)).to.be.bignumber.equal(a.mod(b));
+        expect(await this.safeMath.divWithMessage(a, b, 'MyErrorMessage')).to.be.bignumber.equal('1');
+      });
+
+      it('reverts on division by zero', async function () {
+        const a = new BN('5678');
+        const b = new BN('0');
+
+        await expectRevert(this.safeMath.divWithMessage(a, b, 'MyErrorMessage'), 'MyErrorMessage');
       });
     });
 
-    it('reverts with a 0 divisor', async function () {
-      const a = new BN('5678');
-      const b = new BN('0');
+    describe('mod', function () {
+      describe('modulos correctly', async function () {
+        it('when the dividend is smaller than the divisor', async function () {
+          const a = new BN('284');
+          const b = new BN('5678');
+
+          expect(await this.safeMath.modWithMessage(a, b, 'MyErrorMessage')).to.be.bignumber.equal(a.mod(b));
+        });
+
+        it('when the dividend is equal to the divisor', async function () {
+          const a = new BN('5678');
+          const b = new BN('5678');
+
+          expect(await this.safeMath.modWithMessage(a, b, 'MyErrorMessage')).to.be.bignumber.equal(a.mod(b));
+        });
+
+        it('when the dividend is larger than the divisor', async function () {
+          const a = new BN('7000');
+          const b = new BN('5678');
+
+          expect(await this.safeMath.modWithMessage(a, b, 'MyErrorMessage')).to.be.bignumber.equal(a.mod(b));
+        });
+
+        it('when the dividend is a multiple of the divisor', async function () {
+          const a = new BN('17034'); // 17034 == 5678 * 3
+          const b = new BN('5678');
+
+          expect(await this.safeMath.modWithMessage(a, b, 'MyErrorMessage')).to.be.bignumber.equal(a.mod(b));
+        });
+      });
+
+      it('reverts with a 0 divisor', async function () {
+        const a = new BN('5678');
+        const b = new BN('0');
+
+        await expectRevert(this.safeMath.modWithMessage(a, b, 'MyErrorMessage'), 'MyErrorMessage');
+      });
+    });
+  });
+
+  describe('memory leakage', function () {
+    it('add', async function () {
+      expect(await this.safeMath.addMemoryCheck()).to.be.bignumber.equal('0');
+    });
+
+    it('sub', async function () {
+      expect(await this.safeMath.subMemoryCheck()).to.be.bignumber.equal('0');
+    });
+
+    it('mul', async function () {
+      expect(await this.safeMath.mulMemoryCheck()).to.be.bignumber.equal('0');
+    });
+
+    it('div', async function () {
+      expect(await this.safeMath.divMemoryCheck()).to.be.bignumber.equal('0');
+    });
 
-      await expectRevert(this.safeMath.mod(a, b), 'SafeMath: modulo by zero');
+    it('mod', async function () {
+      expect(await this.safeMath.modMemoryCheck()).to.be.bignumber.equal('0');
     });
   });
 });

+ 40 - 1
test/utils/EnumerableMap.test.js

@@ -1,4 +1,4 @@
-const { BN, expectEvent } = require('@openzeppelin/test-helpers');
+const { BN, constants, expectEvent, expectRevert } = require('@openzeppelin/test-helpers');
 const { expect } = require('chai');
 
 const zip = require('lodash.zip');
@@ -139,4 +139,43 @@ contract('EnumerableMap', function (accounts) {
       expect(await this.map.contains(keyB)).to.equal(false);
     });
   });
+
+  describe('read', function () {
+    beforeEach(async function () {
+      await this.map.set(keyA, accountA);
+    });
+
+    describe('get', function () {
+      it('existing value', async function () {
+        expect(await this.map.get(keyA)).to.be.equal(accountA);
+      });
+      it('missing value', async function () {
+        await expectRevert(this.map.get(keyB), 'EnumerableMap: nonexistent key');
+      });
+    });
+
+    describe('get with message', function () {
+      it('existing value', async function () {
+        expect(await this.map.getWithMessage(keyA, 'custom error string')).to.be.equal(accountA);
+      });
+      it('missing value', async function () {
+        await expectRevert(this.map.getWithMessage(keyB, 'custom error string'), 'custom error string');
+      });
+    });
+
+    describe('tryGet', function () {
+      it('existing value', async function () {
+        expect(await this.map.tryGet(keyA)).to.be.deep.equal({
+          0: true,
+          1: accountA,
+        });
+      });
+      it('missing value', async function () {
+        expect(await this.map.tryGet(keyB)).to.be.deep.equal({
+          0: false,
+          1: constants.ZERO_ADDRESS,
+        });
+      });
+    });
+  });
 });