Эх сурвалжийг харах

Add Calldata variants of ECDSA.recover, ECDSA.tryRecover and SignatureChecker.isValidSignatureNow (#5788)

Hadrien Croubois 2 сар өмнө
parent
commit
bc8f775df2

+ 5 - 0
.changeset/violet-turtles-like.md

@@ -0,0 +1,5 @@
+---
+'openzeppelin-solidity': minor
+---
+
+`ECDSA`: Add `recoverCalldata` and `tryRecoverCalldata`, variants of `recover` and `tryRecover` that are more efficient when signatures are in calldata.

+ 5 - 0
.changeset/whole-plums-speak.md

@@ -0,0 +1,5 @@
+---
+'openzeppelin-solidity': minor
+---
+
+`SignatureChecker`: Add `isValidSignatureNowCalldata(address,bytes32,bytes calldata)` for efficient processing of calldata signatures.

+ 33 - 0
contracts/utils/cryptography/ECDSA.sol

@@ -74,6 +74,30 @@ library ECDSA {
         }
     }
 
+    /**
+     * @dev Variant of {tryRecover} that takes a signature in calldata
+     */
+    function tryRecoverCalldata(
+        bytes32 hash,
+        bytes calldata signature
+    ) internal pure returns (address recovered, RecoverError err, bytes32 errArg) {
+        if (signature.length == 65) {
+            bytes32 r;
+            bytes32 s;
+            uint8 v;
+            // ecrecover takes the signature parameters, calldata slices would work here, but are
+            // significantly more expensive (length check) than using calldataload in assembly.
+            assembly ("memory-safe") {
+                r := calldataload(signature.offset)
+                s := calldataload(add(signature.offset, 0x20))
+                v := byte(0, calldataload(add(signature.offset, 0x40)))
+            }
+            return tryRecover(hash, v, r, s);
+        } else {
+            return (address(0), RecoverError.InvalidSignatureLength, bytes32(signature.length));
+        }
+    }
+
     /**
      * @dev Returns the address that signed a hashed message (`hash`) with
      * `signature`. This address can then be used for verification purposes.
@@ -94,6 +118,15 @@ library ECDSA {
         return recovered;
     }
 
+    /**
+     * @dev Variant of {recover} that takes a signature in calldata
+     */
+    function recoverCalldata(bytes32 hash, bytes calldata signature) internal pure returns (address) {
+        (address recovered, RecoverError error, bytes32 errorArg) = tryRecoverCalldata(hash, signature);
+        _throwError(error, errorArg);
+        return recovered;
+    }
+
     /**
      * @dev Overload of {ECDSA-tryRecover} that receives the `r` and `vs` short-signature fields separately.
      *

+ 36 - 7
contracts/utils/cryptography/SignatureChecker.sol

@@ -38,6 +38,22 @@ library SignatureChecker {
         }
     }
 
+    /**
+     * @dev Variant of {isValidSignatureNow} that takes a signature in calldata
+     */
+    function isValidSignatureNowCalldata(
+        address signer,
+        bytes32 hash,
+        bytes calldata signature
+    ) internal view returns (bool) {
+        if (signer.code.length == 0) {
+            (address recovered, ECDSA.RecoverError err, ) = ECDSA.tryRecoverCalldata(hash, signature);
+            return err == ECDSA.RecoverError.NoError && recovered == signer;
+        } else {
+            return isValidERC1271SignatureNow(signer, hash, signature);
+        }
+    }
+
     /**
      * @dev Checks if a signature is valid for a given signer and data hash. The signature is validated
      * against the signer smart contract using ERC-1271.
@@ -49,13 +65,26 @@ library SignatureChecker {
         address signer,
         bytes32 hash,
         bytes memory signature
-    ) internal view returns (bool) {
-        (bool success, bytes memory result) = signer.staticcall(
-            abi.encodeCall(IERC1271.isValidSignature, (hash, signature))
-        );
-        return (success &&
-            result.length >= 32 &&
-            abi.decode(result, (bytes32)) == bytes32(IERC1271.isValidSignature.selector));
+    ) internal view returns (bool result) {
+        bytes4 selector = IERC1271.isValidSignature.selector;
+        uint256 length = signature.length;
+
+        assembly ("memory-safe") {
+            // Encoded calldata is :
+            // [ 0x00 - 0x03 ] <selector>
+            // [ 0x04 - 0x23 ] <hash>
+            // [ 0x24 - 0x44 ] <signature offset> (0x40)
+            // [ 0x44 - 0x64 ] <signature length>
+            // [ 0x64 - ...  ] <signature data>
+            let ptr := mload(0x40)
+            mstore(ptr, selector)
+            mstore(add(ptr, 0x04), hash)
+            mstore(add(ptr, 0x24), 0x40)
+            mcopy(add(ptr, 0x44), signature, add(length, 0x20))
+
+            let success := staticcall(gas(), signer, ptr, add(length, 0x64), 0, 0x20)
+            result := and(success, and(gt(returndatasize(), 0x19), eq(mload(0x00), selector)))
+        }
     }
 
     /**

+ 28 - 0
test/utils/cryptography/ECDSA.test.js

@@ -44,6 +44,7 @@ describe('ECDSA', function () {
 
         // Recover the signer address from the generated message and signature.
         expect(await this.mock.$recover(ethers.hashMessage(TEST_MESSAGE), signature)).to.equal(this.signer);
+        expect(await this.mock.$recoverCalldata(ethers.hashMessage(TEST_MESSAGE), signature)).to.equal(this.signer);
       });
 
       it('returns signer address with correct signature for arbitrary length message', async function () {
@@ -52,11 +53,13 @@ describe('ECDSA', function () {
 
         // Recover the signer address from the generated message and signature.
         expect(await this.mock.$recover(ethers.hashMessage(NON_HASH_MESSAGE), signature)).to.equal(this.signer);
+        expect(await this.mock.$recoverCalldata(ethers.hashMessage(NON_HASH_MESSAGE), signature)).to.equal(this.signer);
       });
 
       it('returns a different address', async function () {
         const signature = await this.signer.signMessage(TEST_MESSAGE);
         expect(await this.mock.$recover(WRONG_MESSAGE, signature)).to.not.be.equal(this.signer);
+        expect(await this.mock.$recoverCalldata(WRONG_MESSAGE, signature)).to.not.be.equal(this.signer);
       });
 
       it('reverts with invalid signature', async function () {
@@ -66,6 +69,10 @@ describe('ECDSA', function () {
           this.mock,
           'ECDSAInvalidSignature',
         );
+        await expect(this.mock.$recoverCalldata(TEST_MESSAGE, signature)).to.be.revertedWithCustomError(
+          this.mock,
+          'ECDSAInvalidSignature',
+        );
       });
     });
 
@@ -79,6 +86,7 @@ describe('ECDSA', function () {
         const v = '0x1b'; // 27 = 1b.
         const signature = ethers.concat([signatureWithoutV, v]);
         expect(await this.mock.$recover(TEST_MESSAGE, signature)).to.equal(signer);
+        expect(await this.mock.$recoverCalldata(TEST_MESSAGE, signature)).to.equal(signer);
 
         const { r, s, yParityAndS: vs } = ethers.Signature.from(signature);
         expect(await this.mock.getFunction('$recover(bytes32,uint8,bytes32,bytes32)')(TEST_MESSAGE, v, r, s)).to.equal(
@@ -92,6 +100,7 @@ describe('ECDSA', function () {
         const v = '0x1c'; // 28 = 1c.
         const signature = ethers.concat([signatureWithoutV, v]);
         expect(await this.mock.$recover(TEST_MESSAGE, signature)).to.not.equal(signer);
+        expect(await this.mock.$recoverCalldata(TEST_MESSAGE, signature)).to.not.equal(signer);
 
         const { r, s, yParityAndS: vs } = ethers.Signature.from(signature);
         expect(
@@ -110,6 +119,10 @@ describe('ECDSA', function () {
             this.mock,
             'ECDSAInvalidSignature',
           );
+          await expect(this.mock.$recoverCalldata(TEST_MESSAGE, signature)).to.be.revertedWithCustomError(
+            this.mock,
+            'ECDSAInvalidSignature',
+          );
 
           const { r, s } = ethers.Signature.from(signature);
           await expect(
@@ -126,6 +139,9 @@ describe('ECDSA', function () {
         await expect(this.mock.$recover(TEST_MESSAGE, compactSerialized))
           .to.be.revertedWithCustomError(this.mock, 'ECDSAInvalidSignatureLength')
           .withArgs(64);
+        await expect(this.mock.$recoverCalldata(TEST_MESSAGE, compactSerialized))
+          .to.be.revertedWithCustomError(this.mock, 'ECDSAInvalidSignatureLength')
+          .withArgs(64);
       });
     });
 
@@ -139,6 +155,7 @@ describe('ECDSA', function () {
         const v = '0x1c'; // 28 = 1c.
         const signature = ethers.concat([signatureWithoutV, v]);
         expect(await this.mock.$recover(TEST_MESSAGE, signature)).to.equal(signer);
+        expect(await this.mock.$recoverCalldata(TEST_MESSAGE, signature)).to.equal(signer);
 
         const { r, s, yParityAndS: vs } = ethers.Signature.from(signature);
         expect(await this.mock.getFunction('$recover(bytes32,uint8,bytes32,bytes32)')(TEST_MESSAGE, v, r, s)).to.equal(
@@ -152,6 +169,7 @@ describe('ECDSA', function () {
         const v = '0x1b'; // 27 = 1b.
         const signature = ethers.concat([signatureWithoutV, v]);
         expect(await this.mock.$recover(TEST_MESSAGE, signature)).to.not.equal(signer);
+        expect(await this.mock.$recoverCalldata(TEST_MESSAGE, signature)).to.not.equal(signer);
 
         const { r, s, yParityAndS: vs } = ethers.Signature.from(signature);
         expect(
@@ -170,6 +188,10 @@ describe('ECDSA', function () {
             this.mock,
             'ECDSAInvalidSignature',
           );
+          await expect(this.mock.$recoverCalldata(TEST_MESSAGE, signature)).to.be.revertedWithCustomError(
+            this.mock,
+            'ECDSAInvalidSignature',
+          );
 
           const { r, s } = ethers.Signature.from(signature);
           await expect(
@@ -186,6 +208,9 @@ describe('ECDSA', function () {
         await expect(this.mock.$recover(TEST_MESSAGE, compactSerialized))
           .to.be.revertedWithCustomError(this.mock, 'ECDSAInvalidSignatureLength')
           .withArgs(64);
+        await expect(this.mock.$recoverCalldata(TEST_MESSAGE, compactSerialized))
+          .to.be.revertedWithCustomError(this.mock, 'ECDSAInvalidSignatureLength')
+          .withArgs(64);
       });
     });
 
@@ -202,6 +227,9 @@ describe('ECDSA', function () {
       await expect(this.mock.$recover(message, highSSignature))
         .to.be.revertedWithCustomError(this.mock, 'ECDSAInvalidSignatureS')
         .withArgs(s);
+      await expect(this.mock.$recoverCalldata(message, highSSignature))
+        .to.be.revertedWithCustomError(this.mock, 'ECDSAInvalidSignatureS')
+        .withArgs(s);
       await expect(this.mock.getFunction('$recover(bytes32,uint8,bytes32,bytes32)')(TEST_MESSAGE, v, r, s))
         .to.be.revertedWithCustomError(this.mock, 'ECDSAInvalidSignatureS')
         .withArgs(s);

+ 7 - 1
test/utils/cryptography/SignatureChecker.test.js

@@ -36,23 +36,29 @@ describe('SignatureChecker (ERC1271)', function () {
       await expect(
         this.mock.$isValidSignatureNow(ethers.Typed.address(this.signer.address), TEST_MESSAGE_HASH, this.signature),
       ).to.eventually.be.true;
+      await expect(this.mock.$isValidSignatureNowCalldata(this.signer.address, TEST_MESSAGE_HASH, this.signature)).to
+        .eventually.be.true;
     });
 
     it('with invalid signer', async function () {
       await expect(
         this.mock.$isValidSignatureNow(ethers.Typed.address(this.other.address), TEST_MESSAGE_HASH, this.signature),
       ).to.eventually.be.false;
+      await expect(this.mock.$isValidSignatureNowCalldata(this.other.address, TEST_MESSAGE_HASH, this.signature)).to
+        .eventually.be.false;
     });
 
     it('with invalid signature', async function () {
       await expect(
         this.mock.$isValidSignatureNow(ethers.Typed.address(this.signer.address), WRONG_MESSAGE_HASH, this.signature),
       ).to.eventually.be.false;
+      await expect(this.mock.$isValidSignatureNowCalldata(this.signer.address, WRONG_MESSAGE_HASH, this.signature)).to
+        .eventually.be.false;
     });
   });
 
   describe('ERC1271 wallet', function () {
-    for (const fn of ['isValidERC1271SignatureNow', 'isValidSignatureNow']) {
+    for (const fn of ['isValidERC1271SignatureNow', 'isValidSignatureNow', 'isValidSignatureNowCalldata']) {
       describe(fn, function () {
         it('with matching signer and signature', async function () {
           await expect(