Forráskód Böngészése

feat: add wrapper function for low level calls (#2264)

* feat: add wrapper function for low level calls

* add error message parameter

* adding unit tests and required mocks

* implement error message on SafeERC20

* fixed variable name in tests

* Add missing tests

* Improve docs.

* Add functionCallWithValue

* Add functionCallWithValue

* Skip balance check on non-value functionCall variants

* Increase out of gas test timeout

* Fix compile errors

* Apply suggestions from code review

Co-authored-by: Francisco Giordano <frangio.1@gmail.com>

* Add missing tests

* Add changelog entry

Co-authored-by: Nicolás Venturo <nicolas.venturo@gmail.com>
Co-authored-by: Francisco Giordano <frangio.1@gmail.com>
Julian M. Rodriguez 5 éve
szülő
commit
8b58fc7191

+ 1 - 0
CHANGELOG.md

@@ -4,6 +4,7 @@
 
 ### New features
  * `SafeCast`: added functions to downcast signed integers (e.g. `toInt32`), improving usability of `SignedSafeMath`. ([#2243](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/2243))
+ * `functionCall`: new helpers that replicate Solidity's function call semantics, reducing the need to rely on `call`. ([#2264](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/2264))
  * `ERC1155`: added support for a base implementation, non-standard extensions and a preset contract. ([#2014](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/2014), [#2230](https://github.com/OpenZeppelin/openzeppelin-contracts/issues/2230))
 
 ### Improvements

+ 14 - 0
contracts/mocks/AddressImpl.sol

@@ -5,6 +5,8 @@ pragma solidity ^0.6.0;
 import "../utils/Address.sol";
 
 contract AddressImpl {
+    event CallReturnValue(string data);
+
     function isContract(address account) external view returns (bool) {
         return Address.isContract(account);
     }
@@ -13,6 +15,18 @@ contract AddressImpl {
         Address.sendValue(receiver, amount);
     }
 
+    function functionCall(address target, bytes calldata data) external {
+        bytes memory returnData = Address.functionCall(target, data);
+
+        emit CallReturnValue(abi.decode(returnData, (string)));
+    }
+
+    function functionCallWithValue(address target, bytes calldata data, uint256 value) external payable {
+        bytes memory returnData = Address.functionCallWithValue(target, data, value);
+
+        emit CallReturnValue(abi.decode(returnData, (string)));
+    }
+
     // sendValue's tests require the contract to hold Ether
     receive () external payable { }
 }

+ 40 - 0
contracts/mocks/CallReceiverMock.sol

@@ -0,0 +1,40 @@
+// SPDX-License-Identifier: MIT
+
+pragma solidity ^0.6.0;
+
+contract CallReceiverMock {
+
+    event MockFunctionCalled();
+
+    uint256[] private _array;
+
+    function mockFunction() public payable returns (string memory) {
+        emit MockFunctionCalled();
+
+        return "0x1234";
+    }
+
+    function mockFunctionNonPayable() public returns (string memory) {
+        emit MockFunctionCalled();
+
+        return "0x1234";
+    }
+
+    function mockFunctionRevertsNoReason() public payable {
+        revert();
+    }
+
+    function mockFunctionRevertsReason() public payable {
+        revert("CallReceiverMock: reverting");
+    }
+
+    function mockFunctionThrows() public payable {
+        assert(false);
+    }
+
+    function mockFunctionOutOfGas() public payable {
+        for (uint256 i = 0; ; ++i) {
+            _array.push(i);
+        }
+    }
+}

+ 3 - 12
contracts/token/ERC20/SafeERC20.sol

@@ -63,19 +63,10 @@ library SafeERC20 {
      */
     function _callOptionalReturn(IERC20 token, bytes memory data) private {
         // We need to perform a low level call here, to bypass Solidity's return data size checking mechanism, since
-        // we're implementing it ourselves.
-
-        // A Solidity high level call has three parts:
-        //  1. The target address is checked to verify it contains contract code
-        //  2. The call itself is made, and success asserted
-        //  3. The return value is decoded, which in turn checks the size of the returned data.
-        // solhint-disable-next-line max-line-length
-        require(address(token).isContract(), "SafeERC20: call to non-contract");
-
-        // solhint-disable-next-line avoid-low-level-calls
-        (bool success, bytes memory returndata) = address(token).call(data);
-        require(success, "SafeERC20: low-level call failed");
+        // we're implementing it ourselves. We use {Address.functionCall} to perform this call, which verifies that
+        // the target address contains contract code and also asserts for success in the low-level call.
 
+        bytes memory returndata = address(token).functionCall(data, "SafeERC20: low-level call failed");
         if (returndata.length > 0) { // Return data is optional
             // solhint-disable-next-line max-line-length
             require(abi.decode(returndata, (bool)), "SafeERC20: ERC20 operation did not succeed");

+ 4 - 17
contracts/token/ERC721/ERC721.sol

@@ -437,28 +437,15 @@ contract ERC721 is Context, ERC165, IERC721, IERC721Metadata, IERC721Enumerable
         if (!to.isContract()) {
             return true;
         }
-        // solhint-disable-next-line avoid-low-level-calls
-        (bool success, bytes memory returndata) = to.call(abi.encodeWithSelector(
+        bytes memory returndata = to.functionCall(abi.encodeWithSelector(
             IERC721Receiver(to).onERC721Received.selector,
             _msgSender(),
             from,
             tokenId,
             _data
-        ));
-        if (!success) {
-            if (returndata.length > 0) {
-                // solhint-disable-next-line no-inline-assembly
-                assembly {
-                    let returndata_size := mload(returndata)
-                    revert(add(32, returndata), returndata_size)
-                }
-            } else {
-                revert("ERC721: transfer to non ERC721Receiver implementer");
-            }
-        } else {
-            bytes4 retval = abi.decode(returndata, (bytes4));
-            return (retval == _ERC721_RECEIVED);
-        }
+        ), "ERC721: transfer to non ERC721Receiver implementer");
+        bytes4 retval = abi.decode(returndata, (bytes4));
+        return (retval == _ERC721_RECEIVED);
     }
 
     function _approve(address to, uint256 tokenId) private {

+ 75 - 0
contracts/utils/Address.sol

@@ -57,4 +57,79 @@ library Address {
         (bool success, ) = recipient.call{ value: amount }("");
         require(success, "Address: unable to send value, recipient may have reverted");
     }
+
+    /**
+     * @dev Performs a Solidity function call using a low level `call`. A
+     * plain`call` is an unsafe replacement for a function call: use this
+     * function instead.
+     *
+     * If `target` reverts with a revert reason, it is bubbled up by this
+     * function (like regular Solidity function calls).
+     *
+     * Requirements:
+     *
+     * - `target` must be a contract.
+     * - calling `target` with `data` must not revert.
+     */
+    function functionCall(address target, bytes memory data) internal returns (bytes memory) {
+      return functionCall(target, data, "Address: low-level call failed");
+    }
+
+    /**
+     * @dev Same as {Address-functionCall-address-bytes-}, but with
+     * `errorMessage` as a fallback revert reason when `target` reverts.
+     */
+    function functionCall(address target, bytes memory data, string memory errorMessage) internal returns (bytes memory) {
+        return _functionCallWithValue(target, data, 0, errorMessage);
+    }
+
+    /**
+     * @dev Performs a Solidity function call using a low level `call`,
+     * transferring `value` wei. A plain`call` is an unsafe replacement for a
+     * function call: use this function instead.
+     *
+     * If `target` reverts with a revert reason, it is bubbled up by this
+     * function (like regular Solidity function calls).
+     *
+     * Requirements:
+     *
+     * - `target` must be a contract.
+     * - the calling contract must have an ETH balance of at least `value`.
+     * - calling `target` with `data` must not revert.
+     */
+    function functionCallWithValue(address target, bytes memory data, uint256 value) internal returns (bytes memory) {
+        return functionCallWithValue(target, data, value, "Address: low-level call with value failed");
+    }
+
+    /**
+     * @dev Same as {Address-functionCallWithValue-address-bytes-uint256-}, but
+     * with `errorMessage` as a fallback revert reason when `target` reverts.
+     */
+    function functionCallWithValue(address target, bytes memory data, uint256 value, string memory errorMessage) internal returns (bytes memory) {
+        require(address(this).balance >= value, "Address: insufficient balance for call");
+        return _functionCallWithValue(target, data, value, errorMessage);
+    }
+
+    function _functionCallWithValue(address target, bytes memory data, uint256 weiValue, string memory errorMessage) private returns (bytes memory) {
+        require(isContract(target), "Address: call to non-contract");
+
+        // solhint-disable-next-line avoid-low-level-calls
+        (bool success, bytes memory returndata) = target.call{ value: weiValue }(data);
+        if (success) {
+            return returndata;
+        } else {
+            // Look for revert reason and bubble it up if present
+            if (returndata.length > 0) {
+                // The easiest way to bubble the revert reason is using memory via assembly
+
+                // solhint-disable-next-line no-inline-assembly
+                assembly {
+                    let returndata_size := mload(returndata)
+                    revert(add(32, returndata), returndata_size)
+                }
+            } else {
+                revert(errorMessage);
+            }
+        }
+    }
 }

+ 1 - 1
test/token/ERC20/SafeERC20.test.js

@@ -15,7 +15,7 @@ describe('SafeERC20', function () {
       this.wrapper = await SafeERC20Wrapper.new(hasNoCode);
     });
 
-    shouldRevertOnAllCalls('SafeERC20: call to non-contract');
+    shouldRevertOnAllCalls('Address: call to non-contract');
   });
 
   describe('with token that returns false on all calls', function () {

+ 191 - 2
test/utils/Address.test.js

@@ -1,10 +1,11 @@
-const { accounts, contract } = require('@openzeppelin/test-environment');
+const { accounts, contract, web3 } = require('@openzeppelin/test-environment');
 
-const { balance, ether, expectRevert, send } = require('@openzeppelin/test-helpers');
+const { balance, ether, expectRevert, send, expectEvent } = require('@openzeppelin/test-helpers');
 const { expect } = require('chai');
 
 const AddressImpl = contract.fromArtifact('AddressImpl');
 const EtherReceiver = contract.fromArtifact('EtherReceiverMock');
+const CallReceiverMock = contract.fromArtifact('CallReceiverMock');
 
 describe('Address', function () {
   const [ recipient, other ] = accounts;
@@ -90,4 +91,192 @@ describe('Address', function () {
       });
     });
   });
+
+  describe('functionCall', function () {
+    beforeEach(async function () {
+      this.contractRecipient = await CallReceiverMock.new();
+    });
+
+    context('with valid contract receiver', function () {
+      it('calls the requested function', async function () {
+        const abiEncodedCall = web3.eth.abi.encodeFunctionCall({
+          name: 'mockFunction',
+          type: 'function',
+          inputs: [],
+        }, []);
+
+        const receipt = await this.mock.functionCall(this.contractRecipient.address, abiEncodedCall);
+
+        expectEvent(receipt, 'CallReturnValue', { data: '0x1234' });
+        await expectEvent.inTransaction(receipt.tx, CallReceiverMock, 'MockFunctionCalled');
+      });
+
+      it('reverts when the called function reverts with no reason', async function () {
+        const abiEncodedCall = web3.eth.abi.encodeFunctionCall({
+          name: 'mockFunctionRevertsNoReason',
+          type: 'function',
+          inputs: [],
+        }, []);
+
+        await expectRevert(
+          this.mock.functionCall(this.contractRecipient.address, abiEncodedCall),
+          'Address: low-level call failed'
+        );
+      });
+
+      it('reverts when the called function reverts, bubbling up the revert reason', async function () {
+        const abiEncodedCall = web3.eth.abi.encodeFunctionCall({
+          name: 'mockFunctionRevertsReason',
+          type: 'function',
+          inputs: [],
+        }, []);
+
+        await expectRevert(
+          this.mock.functionCall(this.contractRecipient.address, abiEncodedCall),
+          'CallReceiverMock: reverting'
+        );
+      });
+
+      it('reverts when the called function runs out of gas', async function () {
+        const abiEncodedCall = web3.eth.abi.encodeFunctionCall({
+          name: 'mockFunctionOutOfGas',
+          type: 'function',
+          inputs: [],
+        }, []);
+
+        await expectRevert(
+          this.mock.functionCall(this.contractRecipient.address, abiEncodedCall),
+          'Address: low-level call failed'
+        );
+      }).timeout(5000);
+
+      it('reverts when the called function throws', async function () {
+        const abiEncodedCall = web3.eth.abi.encodeFunctionCall({
+          name: 'mockFunctionThrows',
+          type: 'function',
+          inputs: [],
+        }, []);
+
+        await expectRevert(
+          this.mock.functionCall(this.contractRecipient.address, abiEncodedCall),
+          'Address: low-level call failed'
+        );
+      });
+
+      it('reverts when function does not exist', async function () {
+        const abiEncodedCall = web3.eth.abi.encodeFunctionCall({
+          name: 'mockFunctionDoesNotExist',
+          type: 'function',
+          inputs: [],
+        }, []);
+
+        await expectRevert(
+          this.mock.functionCall(this.contractRecipient.address, abiEncodedCall),
+          'Address: low-level call failed'
+        );
+      });
+    });
+
+    context('with non-contract receiver', function () {
+      it('reverts when address is not a contract', async function () {
+        const [ recipient ] = accounts;
+        const abiEncodedCall = web3.eth.abi.encodeFunctionCall({
+          name: 'mockFunction',
+          type: 'function',
+          inputs: [],
+        }, []);
+        await expectRevert(this.mock.functionCall(recipient, abiEncodedCall), 'Address: call to non-contract');
+      });
+    });
+  });
+
+  describe('functionCallWithValue', function () {
+    beforeEach(async function () {
+      this.contractRecipient = await CallReceiverMock.new();
+    });
+
+    context('with zero value', function () {
+      it('calls the requested function', async function () {
+        const abiEncodedCall = web3.eth.abi.encodeFunctionCall({
+          name: 'mockFunction',
+          type: 'function',
+          inputs: [],
+        }, []);
+
+        const receipt = await this.mock.functionCallWithValue(this.contractRecipient.address, abiEncodedCall, 0);
+
+        expectEvent(receipt, 'CallReturnValue', { data: '0x1234' });
+        await expectEvent.inTransaction(receipt.tx, CallReceiverMock, 'MockFunctionCalled');
+      });
+    });
+
+    context('with non-zero value', function () {
+      const amount = ether('1.2');
+
+      it('reverts if insufficient sender balance', async function () {
+        const abiEncodedCall = web3.eth.abi.encodeFunctionCall({
+          name: 'mockFunction',
+          type: 'function',
+          inputs: [],
+        }, []);
+
+        await expectRevert(
+          this.mock.functionCallWithValue(this.contractRecipient.address, abiEncodedCall, amount),
+          'Address: insufficient balance for call'
+        );
+      });
+
+      it('calls the requested function with existing value', async function () {
+        const abiEncodedCall = web3.eth.abi.encodeFunctionCall({
+          name: 'mockFunction',
+          type: 'function',
+          inputs: [],
+        }, []);
+
+        const tracker = await balance.tracker(this.contractRecipient.address);
+
+        await send.ether(other, this.mock.address, amount);
+        const receipt = await this.mock.functionCallWithValue(this.contractRecipient.address, abiEncodedCall, amount);
+
+        expect(await tracker.delta()).to.be.bignumber.equal(amount);
+
+        expectEvent(receipt, 'CallReturnValue', { data: '0x1234' });
+        await expectEvent.inTransaction(receipt.tx, CallReceiverMock, 'MockFunctionCalled');
+      });
+
+      it('calls the requested function with transaction funds', async function () {
+        const abiEncodedCall = web3.eth.abi.encodeFunctionCall({
+          name: 'mockFunction',
+          type: 'function',
+          inputs: [],
+        }, []);
+
+        const tracker = await balance.tracker(this.contractRecipient.address);
+
+        expect(await balance.current(this.mock.address)).to.be.bignumber.equal('0');
+        const receipt = await this.mock.functionCallWithValue(
+          this.contractRecipient.address, abiEncodedCall, amount, { from: other, value: amount }
+        );
+
+        expect(await tracker.delta()).to.be.bignumber.equal(amount);
+
+        expectEvent(receipt, 'CallReturnValue', { data: '0x1234' });
+        await expectEvent.inTransaction(receipt.tx, CallReceiverMock, 'MockFunctionCalled');
+      });
+
+      it('reverts when calling non-payable functions', async function () {
+        const abiEncodedCall = web3.eth.abi.encodeFunctionCall({
+          name: 'mockFunctionNonPayable',
+          type: 'function',
+          inputs: [],
+        }, []);
+
+        await send.ether(other, this.mock.address, amount);
+        await expectRevert(
+          this.mock.functionCallWithValue(this.contractRecipient.address, abiEncodedCall, amount),
+          'Address: low-level call with value failed'
+        );
+      });
+    });
+  });
 });