ERC2771Forwarder.t.sol 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279
  1. // SPDX-License-Identifier: MIT
  2. pragma solidity ^0.8.20;
  3. import {Test} from "forge-std/Test.sol";
  4. import {ERC2771Forwarder} from "@openzeppelin/contracts/metatx/ERC2771Forwarder.sol";
  5. import {CallReceiverMockTrustingForwarder, CallReceiverMock} from "@openzeppelin/contracts/mocks/CallReceiverMock.sol";
  6. import {ECDSA} from "@openzeppelin/contracts/utils/cryptography/ECDSA.sol";
  7. import {SafeCast} from "@openzeppelin/contracts/utils/math/SafeCast.sol";
  /// @dev Selects which field of a signed `ERC2771Forwarder.ForwardRequestData` a test overwrites
  /// after signing. Member order matters: tests map fuzzed integers onto the range FROM..SIGNATURE.
  enum TamperType { FROM, TO, VALUE, DATA, SIGNATURE }
  /// @dev ERC2771Forwarder that additionally exposes the signing digest of a forward request for an
  /// explicitly supplied nonce, so tests can produce signatures with `vm.sign`.
  contract ERC2771ForwarderMock is ERC2771Forwarder {
      constructor(string memory name) ERC2771Forwarder(name) {}

      /// @notice Typed-data digest (`_hashTypedDataV4` of the request struct hash) for `request` and `nonce`.
      /// @dev Despite its name this returns the full digest, not the bare struct hash; it is the value
      /// tests sign. `request.signature` is not part of the hash, and `request.data` is hashed on its own.
      /// @param request Forward request whose fields are hashed.
      /// @param nonce Nonce to embed in the hash.
      /// @return The digest to be signed by `request.from`.
      function forwardRequestStructHash(
          ERC2771Forwarder.ForwardRequestData calldata request,
          uint256 nonce
      ) external view returns (bytes32) {
          bytes32 dataHash = keccak256(request.data);
          bytes32 structHash = keccak256(
              abi.encode(
                  _FORWARD_REQUEST_TYPEHASH,
                  request.from,
                  request.to,
                  request.value,
                  request.gas,
                  nonce,
                  request.deadline,
                  dataHash
              )
          );
          return _hashTypedDataV4(structHash);
      }
  }
  /// @title ERC2771ForwarderTest
  /// @notice Fuzz tests for ERC2771Forwarder. ETH sent alongside requests must never stay stuck in the
  /// forwarder. Requests altered after signing must fail `verify`, revert `execute`, revert `executeBatch`
  /// when the refund receiver is zero, and otherwise be skipped by `executeBatch`.
  contract ERC2771ForwarderTest is Test {
      using ECDSA for bytes32;

      ERC2771ForwarderMock internal _erc2771Forwarder;
      CallReceiverMockTrustingForwarder internal _receiver;

      // Fixed test key; `_signer` is the address the `vm.addr` cheatcode derives from it.
      uint256 internal _signerPrivateKey = 0xA11CE;
      address internal _signer = vm.addr(_signerPrivateKey);

      // Upper bound for fuzzed ETH amounts, keeping balances and `value * batchSize` far from overflow.
      uint256 internal constant _MAX_ETHER = 10_000_000;

      function setUp() public {
          _erc2771Forwarder = new ERC2771ForwarderMock("ERC2771Forwarder");
          // The receiver is constructed with the forwarder's address (a target that trusts it).
          _receiver = new CallReceiverMockTrustingForwarder(address(_erc2771Forwarder));
      }

      /// @dev Forge an unsigned default request. It goes from `_signer` to `_receiver`, carries no ETH,
      /// uses deadline `block.timestamp + 1`, and calls `CallReceiverMock.mockFunction`.
      function _forgeRequestData() private view returns (ERC2771Forwarder.ForwardRequestData memory) {
          return
              _forgeRequestData({
                  value: 0,
                  deadline: uint48(block.timestamp + 1),
                  data: abi.encodeCall(CallReceiverMock.mockFunction, ())
              });
      }

      /// @dev Forge an unsigned request from `_signer` to `_receiver` with a 30000 gas limit.
      /// @param value ETH (wei) to forward with the call.
      /// @param deadline Request expiry timestamp.
      /// @param data Calldata for `_receiver`.
      function _forgeRequestData(
          uint256 value,
          uint48 deadline,
          bytes memory data
      ) private view returns (ERC2771Forwarder.ForwardRequestData memory) {
          return
              ERC2771Forwarder.ForwardRequestData({
                  from: _signer,
                  to: address(_receiver),
                  value: value,
                  gas: 30000,
                  deadline: deadline,
                  data: data,
                  signature: ""
              });
      }

      /// @dev Forge `size` default requests (as `_forgeRequestData()` does), each signed with a consecutive
      /// nonce starting at `firstNonce` (request `i` uses `firstNonce + i`).
      function _forgeSignedRequests(
          uint256 size,
          uint256 firstNonce
      ) private view returns (ERC2771Forwarder.ForwardRequestData[] memory requests) {
          requests = new ERC2771Forwarder.ForwardRequestData[](size);
          for (uint256 i = 0; i < size; ++i) {
              requests[i] = _signRequestData(_forgeRequestData(), firstNonce + i);
          }
      }

      /// @dev Sign `request` as `_signer` for `nonce`, in place. The signature is stored in
      /// `request.signature` as the 65-byte packed `r || s || v`. Also returns the request for convenience.
      function _signRequestData(
          ERC2771Forwarder.ForwardRequestData memory request,
          uint256 nonce
      ) private view returns (ERC2771Forwarder.ForwardRequestData memory) {
          bytes32 digest = _erc2771Forwarder.forwardRequestStructHash(request, nonce);
          (uint8 v, bytes32 r, bytes32 s) = vm.sign(_signerPrivateKey, digest);
          request.signature = abi.encodePacked(r, s, v);
          return request;
      }

      /// @dev Overwrite the field selected by `tamper` with a random value, in place. This invalidates a
      /// previously signed request. Also returns the request for convenience.
      function _tamperRequestData(
          ERC2771Forwarder.ForwardRequestData memory request,
          TamperType tamper
      ) private returns (ERC2771Forwarder.ForwardRequestData memory) {
          if (tamper == TamperType.FROM) request.from = vm.randomAddress();
          else if (tamper == TamperType.TO) request.to = vm.randomAddress();
          else if (tamper == TamperType.VALUE) request.value = vm.randomUint();
          else if (tamper == TamperType.DATA) request.data = vm.randomBytes(4);
          else if (tamper == TamperType.SIGNATURE) request.signature = vm.randomBytes(65);
          return request;
      }

      /// @dev Predict the custom error a tampered request triggers and register it with `vm.expectRevert`.
      /// A random `to` is expected to be an untrusting target, giving `ERC2771UntrustfulTarget`. Any other
      /// tampering is expected to give `ERC2771ForwarderInvalidSigner(recovered, request.from)`, where
      /// `recovered` is the address recovered from the tampered request's digest and signature.
      /// @param nonce Nonce the request was signed with. It is replaced by the current nonce of the new
      /// `request.from` when `from` was tampered, because the digest is then computed for that account.
      function _tamperedExpectRevert(
          ERC2771Forwarder.ForwardRequestData memory request,
          TamperType tamper,
          uint256 nonce
      ) private returns (ERC2771Forwarder.ForwardRequestData memory) {
          if (tamper == TamperType.FROM) nonce = _erc2771Forwarder.nonces(request.from);

          // predict revert
          if (tamper == TamperType.TO) {
              vm.expectRevert(
                  abi.encodeWithSelector(
                      ERC2771Forwarder.ERC2771UntrustfulTarget.selector,
                      request.to,
                      address(_erc2771Forwarder)
                  )
              );
          } else {
              (address recovered, , ) = _erc2771Forwarder.forwardRequestStructHash(request, nonce).tryRecover(
                  request.signature
              );
              vm.expectRevert(
                  abi.encodeWithSelector(ERC2771Forwarder.ERC2771ForwarderInvalidSigner.selector, recovered, request.from)
              );
          }
          return request;
      }

      /// @dev `execute` must leave the forwarder's pre-existing balance untouched, whether or not the target
      /// call reverts.
      function testExecuteAvoidsETHStuck(uint256 initialBalance, uint256 value, bool targetReverts) public {
          initialBalance = bound(initialBalance, 0, _MAX_ETHER);
          value = bound(value, 0, _MAX_ETHER);

          // create and sign request
          ERC2771Forwarder.ForwardRequestData memory request = _forgeRequestData({
              value: value,
              deadline: uint48(block.timestamp + 1),
              data: targetReverts
                  ? abi.encodeCall(CallReceiverMock.mockFunctionRevertsNoReason, ())
                  : abi.encodeCall(CallReceiverMock.mockFunction, ())
          });
          _signRequestData(request, _erc2771Forwarder.nonces(_signer));

          // pre-fund the forwarder with unrelated ETH; give this contract exactly the request value
          vm.deal(address(_erc2771Forwarder), initialBalance);
          vm.deal(address(this), request.value);

          if (targetReverts) vm.expectRevert();
          _erc2771Forwarder.execute{value: request.value}(request);

          assertEq(address(_erc2771Forwarder).balance, initialBalance);
      }

      /// @dev `executeBatch` with a non-zero refund receiver must leave the forwarder's pre-existing balance
      /// untouched. It must also send the value of every failed request to the refund receiver.
      function testExecuteBatchAvoidsETHStuck(uint256 initialBalance, uint256 batchSize, uint256 value) public {
          // derive the per-request failure bits from the raw (not yet bounded) fuzz inputs
          uint256 seed = uint256(keccak256(abi.encodePacked(initialBalance, batchSize, value)));

          batchSize = bound(batchSize, 1, 10);
          initialBalance = bound(initialBalance, 0, _MAX_ETHER);
          value = bound(value, 0, _MAX_ETHER);

          address refundReceiver = address(0xebe);
          uint256 refundExpected = 0;
          uint256 nonce = _erc2771Forwarder.nonces(_signer);

          // create and sign an array of requests; bit `i` of `seed` decides whether request `i` reverts
          ERC2771Forwarder.ForwardRequestData[] memory requests = new ERC2771Forwarder.ForwardRequestData[](batchSize);
          for (uint256 i = 0; i < batchSize; ++i) {
              bool failure = ((seed >> i) & 0x1) == 0x1;
              requests[i] = _forgeRequestData({
                  value: value,
                  deadline: uint48(block.timestamp + 1),
                  data: failure
                      ? abi.encodeCall(CallReceiverMock.mockFunctionRevertsNoReason, ())
                      : abi.encodeCall(CallReceiverMock.mockFunction, ())
              });
              _signRequestData(requests[i], nonce + i);
              refundExpected += SafeCast.toUint(failure) * value;
          }

          // distribute ether
          vm.deal(address(_erc2771Forwarder), initialBalance);
          vm.deal(address(this), value * batchSize);

          // execute batch
          _erc2771Forwarder.executeBatch{value: value * batchSize}(requests, payable(refundReceiver));

          // check balances
          assertEq(address(_erc2771Forwarder).balance, initialBalance);
          assertEq(refundReceiver.balance, refundExpected);
      }

      /// @dev A request tampered after signing must not pass `verify`.
      function testVerifyTamperedValues(uint8 _tamper) public {
          TamperType tamper = _asTamper(_tamper);

          // create request, sign, tamper
          ERC2771Forwarder.ForwardRequestData memory request = _forgeRequestData();
          _signRequestData(request, _erc2771Forwarder.nonces(_signer));
          _tamperRequestData(request, tamper);

          // should not pass verification
          assertFalse(_erc2771Forwarder.verify(request));
      }

      /// @dev `execute` must revert with the predicted error for a request tampered after signing.
      function testExecuteTamperedValues(uint8 _tamper) public {
          TamperType tamper = _asTamper(_tamper);
          uint256 nonce = _erc2771Forwarder.nonces(_signer);

          // create request, sign, tamper, expect execution revert
          ERC2771Forwarder.ForwardRequestData memory request = _forgeRequestData();
          _signRequestData(request, nonce);
          _tamperRequestData(request, tamper);
          _tamperedExpectRevert(request, tamper, nonce);

          vm.deal(address(this), request.value);
          _erc2771Forwarder.execute{value: request.value}(request);
      }

      /// @dev With a zero refund receiver, one tampered request must make the whole `executeBatch` revert
      /// with the predicted error.
      function testExecuteBatchTamperedValuesZeroReceiver(uint8 _tamper) public {
          TamperType tamper = _asTamper(_tamper);
          uint256 nonce = _erc2771Forwarder.nonces(_signer);

          // create and sign an array of requests
          ERC2771Forwarder.ForwardRequestData[] memory requests = _forgeSignedRequests(3, nonce);

          // tamper with requests[1] and expect execution revert
          _tamperRequestData(requests[1], tamper);
          _tamperedExpectRevert(requests[1], tamper, nonce + 1);

          vm.deal(address(this), requests[1].value);
          _erc2771Forwarder.executeBatch{value: requests[1].value}(requests, payable(address(0)));
      }

      /// @dev With a non-zero refund receiver, a tampered request must not make `executeBatch` revert.
      function testExecuteBatchTamperedValues(uint8 _tamper) public {
          TamperType tamper = _asTamper(_tamper);
          uint256 nonce = _erc2771Forwarder.nonces(_signer);

          // create and sign an array of requests
          ERC2771Forwarder.ForwardRequestData[] memory requests = _forgeSignedRequests(3, nonce);

          // tamper with requests[1]
          _tamperRequestData(requests[1], tamper);

          // should not revert. Exactly one call reaches the receiver (requests[0]). Presumably requests[2]
          // also fails, because it was signed for nonce + 2 and the skipped requests[1] never consumed nonce + 1.
          vm.expectCall(address(_receiver), abi.encodeCall(CallReceiverMock.mockFunction, ()), 1);

          vm.deal(address(this), requests[1].value);
          _erc2771Forwarder.executeBatch{value: requests[1].value}(requests, payable(address(0xebe)));
      }

      /// @dev Map an arbitrary fuzzed `uint8` onto a valid `TamperType` (FROM..SIGNATURE inclusive).
      function _asTamper(uint8 _tamper) private pure returns (TamperType) {
          return TamperType(bound(_tamper, uint8(TamperType.FROM), uint8(TamperType.SIGNATURE)));
      }
  }