Переглянути джерело

Make Multicall context-aware

ernestognw 1 рік тому
батько
коміт
9ce0340466

+ 5 - 0
.changeset/rude-weeks-beg.md

@@ -0,0 +1,5 @@
+---
+'openzeppelin-solidity': patch
+---
+
+`ERC2771Context` and `Context`: Introduce a `_contextPrefixLength()` getter, used to trim extra information appended to `msg.data`.

+ 5 - 0
.changeset/strong-points-invent.md

@@ -0,0 +1,5 @@
+---
+'openzeppelin-solidity': patch
+---
+
+`Multicall`: Make aware of non-canonical context (i.e. `msg.sender` is not `_msgSender()`), allowing compatibility with `ERC2771Context`.

+ 20 - 9
contracts/metatx/ERC2771Context.sol

@@ -13,6 +13,10 @@ import {Context} from "../utils/Context.sol";
  * specification adding the address size in bytes (20) to the calldata size. An example of an unexpected
  * behavior could be an unintended fallback (or another function) invocation while trying to invoke the `receive`
  * function only accessible if `msg.data.length == 0`.
+ *
+ * WARNING: The usage of `delegatecall` in this contract is dangerous and may result in context corruption.
+ * Any forwarded request to this contract triggering a `delegatecall` to itself will result in an invalid {_msgSender}
+ * recovery.
  */
 abstract contract ERC2771Context is Context {
     /// @custom:oz-upgrades-unsafe-allow state-variable-immutable
@@ -48,13 +52,11 @@ abstract contract ERC2771Context is Context {
      * a call is not performed by the trusted forwarder or the calldata length is less than
      * 20 bytes (an address length).
      */
-    function _msgSender() internal view virtual override returns (address sender) {
-        if (isTrustedForwarder(msg.sender) && msg.data.length >= 20) {
-            // The assembly code is more direct than the Solidity version using `abi.decode`.
-            /// @solidity memory-safe-assembly
-            assembly {
-                sender := shr(96, calldataload(sub(calldatasize(), 20)))
-            }
+    function _msgSender() internal view virtual override returns (address) {
+        uint256 calldataLength = msg.data.length;
+        uint256 contextSuffixLength = _contextSuffixLength();
+        if (isTrustedForwarder(msg.sender) && calldataLength >= contextSuffixLength) {
+            return address(bytes20(msg.data[calldataLength - contextSuffixLength:]));
         } else {
             return super._msgSender();
         }
@@ -66,10 +68,19 @@ abstract contract ERC2771Context is Context {
      * 20 bytes (an address length).
      */
     function _msgData() internal view virtual override returns (bytes calldata) {
-        if (isTrustedForwarder(msg.sender) && msg.data.length >= 20) {
-            return msg.data[:msg.data.length - 20];
+        uint256 calldataLength = msg.data.length;
+        uint256 contextSuffixLength = _contextSuffixLength();
+        if (isTrustedForwarder(msg.sender) && calldataLength >= contextSuffixLength) {
+            return msg.data[:calldataLength - contextSuffixLength];
         } else {
             return super._msgData();
         }
     }
+
+    /**
+     * @dev ERC-2771 specifies the context as being a single address (20 bytes).
+     */
+    function _contextSuffixLength() internal view virtual override returns (uint256) {
+        return 20;
+    }
 }

+ 6 - 1
contracts/mocks/ERC2771ContextMock.sol

@@ -4,10 +4,11 @@ pragma solidity ^0.8.20;
 
 import {ContextMock} from "./ContextMock.sol";
 import {Context} from "../utils/Context.sol";
+import {Multicall} from "../utils/Multicall.sol";
 import {ERC2771Context} from "../metatx/ERC2771Context.sol";
 
 // By inheriting from ERC2771Context, Context's internal functions are overridden automatically
-contract ERC2771ContextMock is ContextMock, ERC2771Context {
+contract ERC2771ContextMock is ContextMock, ERC2771Context, Multicall {
     /// @custom:oz-upgrades-unsafe-allow constructor
     constructor(address trustedForwarder) ERC2771Context(trustedForwarder) {
         emit Sender(_msgSender()); // _msgSender() should be accessible during construction
@@ -20,4 +21,8 @@ contract ERC2771ContextMock is ContextMock, ERC2771Context {
     function _msgData() internal view override(Context, ERC2771Context) returns (bytes calldata) {
         return ERC2771Context._msgData();
     }
+
+    function _contextSuffixLength() internal view override(Context, ERC2771Context) returns (uint256) {
+        return ERC2771Context._contextSuffixLength();
+    }
 }

+ 4 - 0
contracts/utils/Context.sol

@@ -21,4 +21,8 @@ abstract contract Context {
     function _msgData() internal view virtual returns (bytes calldata) {
         return msg.data;
     }
+
+    function _contextSuffixLength() internal view virtual returns (uint256) {
+        return 0;
+    }
 }

+ 16 - 2
contracts/utils/Multicall.sol

@@ -4,19 +4,33 @@
 pragma solidity ^0.8.20;
 
 import {Address} from "./Address.sol";
+import {Context} from "./Context.sol";
 
 /**
  * @dev Provides a function to batch together multiple calls in a single external call.
+ *
+ * Consider any assumption about calldata validation performed by the sender may be violated if it's not especially
+ * careful about sending transactions invoking {multicall}. For example, a relay address that filters function
+ * selectors won't filter calls nested within a {multicall} operation.
+ *
+ * NOTE: Since 5.0.1 and 4.9.4, this contract identifies non-canonical contexts (i.e. `msg.sender` is not {_msgSender}).
+ * If a non-canonical context is identified, the following self `delegatecall` appends the last bytes of `msg.data`
+ * to the subcall. This makes it safe to use with {ERC2771Context}. Contexts that don't affect the resolution of
+ * {_msgSender} are not propagated to subcalls.
  */
-abstract contract Multicall {
+abstract contract Multicall is Context {
     /**
      * @dev Receives and executes a batch of function calls on this contract.
      * @custom:oz-upgrades-unsafe-allow-reachable delegatecall
      */
     function multicall(bytes[] calldata data) external virtual returns (bytes[] memory results) {
+        bytes memory context = msg.sender == _msgSender()
+            ? new bytes(0)
+            : msg.data[msg.data.length - _contextSuffixLength():];
+
         results = new bytes[](data.length);
         for (uint256 i = 0; i < data.length; i++) {
-            results[i] = Address.functionDelegateCall(address(this), data[i]);
+            results[i] = Address.functionDelegateCall(address(this), bytes.concat(data[i], context));
         }
         return results;
     }

+ 55 - 1
test/metatx/ERC2771Context.test.js

@@ -13,7 +13,7 @@ const ContextMockCaller = artifacts.require('ContextMockCaller');
 const { shouldBehaveLikeRegularContext } = require('../utils/Context.behavior');
 
 contract('ERC2771Context', function (accounts) {
-  const [, trustedForwarder] = accounts;
+  const [, trustedForwarder, other] = accounts;
 
   beforeEach(async function () {
     this.forwarder = await ERC2771Forwarder.new('ERC2771Forwarder');
@@ -131,4 +131,58 @@ contract('ERC2771Context', function (accounts) {
       await expectEvent(receipt, 'DataShort', { data });
     });
   });
+
+  it('multicall poison attack', async function () {
+    const attacker = Wallet.generate();
+    const attackerAddress = attacker.getChecksumAddressString();
+    const nonce = await this.forwarder.nonces(attackerAddress);
+
+    const msgSenderCall = web3.eth.abi.encodeFunctionCall(
+      {
+        name: 'msgSender',
+        type: 'function',
+        inputs: [],
+      },
+      [],
+    );
+
+    const data = web3.eth.abi.encodeFunctionCall(
+      {
+        name: 'multicall',
+        type: 'function',
+        inputs: [
+          {
+            internalType: 'bytes[]',
+            name: 'data',
+            type: 'bytes[]',
+          },
+        ],
+      },
+      [[web3.utils.encodePacked({ value: msgSenderCall, type: 'bytes' }, { value: other, type: 'address' })]],
+    );
+
+    const req = {
+      from: attackerAddress,
+      to: this.recipient.address,
+      value: '0',
+      gas: '100000',
+      data,
+      nonce: Number(nonce),
+      deadline: MAX_UINT48,
+    };
+
+    req.signature = await ethSigUtil.signTypedMessage(attacker.getPrivateKey(), {
+      data: {
+        types: this.types,
+        domain: this.domain,
+        primaryType: 'ForwardRequest',
+        message: req,
+      },
+    });
+
+    expect(await this.forwarder.verify(req)).to.equal(true);
+
+    const receipt = await this.forwarder.execute(req);
+    await expectEvent.inTransaction(receipt.tx, ERC2771ContextMock, 'Sender', { sender: attackerAddress });
+  });
 });