Quellcode durchsuchen

Optimize votes lookups for recent checkpoints (#3673)

Francisco vor 3 Jahren
Ursprung
Commit
e09ccd1449

+ 1 - 0
CHANGELOG.md

@@ -8,6 +8,7 @@
  * `Address`: optimize `functionCall` functions by checking contract size only if there is no returned data. ([#3469](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/3469))
  * `GovernorCompatibilityBravo`: remove unused `using` statements. ([#3506](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/3506))
  * `ERC20`: optimize `_transfer`, `_mint` and `_burn` by using `unchecked` arithmetic when possible. ([#3513](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/3513))
+ * `ERC20Votes`, `ERC721Votes`: optimize `getPastVotes` for looking up recent checkpoints. ([#3673](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/3673))
  * `ERC20FlashMint`: add an internal `_flashFee` function for overriding. ([#3551](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/3551))
  * `ERC4626`: use the same `decimals()` as the underlying asset by default (if available). ([#3639](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/3639))
  * `ERC4626`: add internal `_initialConvertToShares` and `_initialConvertToAssets` functions to customize empty vaults behavior. ([#3639](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/3639))

+ 2 - 2
contracts/governance/utils/Votes.sol

@@ -56,7 +56,7 @@ abstract contract Votes is IVotes, Context, EIP712 {
      * - `blockNumber` must have been already mined
      */
     function getPastVotes(address account, uint256 blockNumber) public view virtual override returns (uint256) {
-        return _delegateCheckpoints[account].getAtBlock(blockNumber);
+        return _delegateCheckpoints[account].getAtProbablyRecentBlock(blockNumber);
     }
 
     /**
@@ -72,7 +72,7 @@ abstract contract Votes is IVotes, Context, EIP712 {
      */
     function getPastTotalSupply(uint256 blockNumber) public view virtual override returns (uint256) {
         require(blockNumber < block.number, "Votes: block not yet mined");
-        return _totalCheckpoints.getAtBlock(blockNumber);
+        return _totalCheckpoints.getAtProbablyRecentBlock(blockNumber);
     }
 
     /**

+ 2 - 10
contracts/mocks/CheckpointsMock.sol

@@ -22,8 +22,8 @@ contract CheckpointsMock {
         return _totalCheckpoints.getAtBlock(blockNumber);
     }
 
-    function getAtRecentBlock(uint256 blockNumber) public view returns (uint256) {
-        return _totalCheckpoints.getAtRecentBlock(blockNumber);
+    function getAtProbablyRecentBlock(uint256 blockNumber) public view returns (uint256) {
+        return _totalCheckpoints.getAtProbablyRecentBlock(blockNumber);
     }
 
     function length() public view returns (uint256) {
@@ -52,10 +52,6 @@ contract Checkpoints224Mock {
         return _totalCheckpoints.upperLookup(key);
     }
 
-    function upperLookupRecent(uint32 key) public view returns (uint224) {
-        return _totalCheckpoints.upperLookupRecent(key);
-    }
-
     function length() public view returns (uint256) {
         return _totalCheckpoints._checkpoints.length;
     }
@@ -82,10 +78,6 @@ contract Checkpoints160Mock {
         return _totalCheckpoints.upperLookup(key);
     }
 
-    function upperLookupRecent(uint96 key) public view returns (uint224) {
-        return _totalCheckpoints.upperLookupRecent(key);
-    }
-
     function length() public view returns (uint256) {
         return _totalCheckpoints._checkpoints.length;
     }

+ 29 - 6
contracts/token/ERC20/extensions/ERC20Votes.sol

@@ -97,6 +97,7 @@ abstract contract ERC20Votes is IVotes, ERC20Permit {
     function _checkpointsLookup(Checkpoint[] storage ckpts, uint256 blockNumber) private view returns (uint256) {
         // We run a binary search to look for the earliest checkpoint taken after `blockNumber`.
         //
+        // Initially we check if the block is recent to narrow the search range.
         // During the loop, the index of the wanted checkpoint remains in the range [low-1, high).
         // With each iteration, either `low` or `high` is moved towards the middle of the range to maintain the invariant.
         // - If the middle checkpoint is after `blockNumber`, we look in [low, mid)
@@ -106,18 +107,30 @@ abstract contract ERC20Votes is IVotes, ERC20Permit {
         // Note that if the latest checkpoint available is exactly for `blockNumber`, we end up with an index that is
         // past the end of the array, so we technically don't find a checkpoint after `blockNumber`, but it works out
         // the same.
-        uint256 high = ckpts.length;
+        uint256 length = ckpts.length;
+
         uint256 low = 0;
+        uint256 high = length;
+
+        if (length > 5) {
+            uint256 mid = length - Math.sqrt(length);
+            if (_unsafeAccess(ckpts, mid).fromBlock > blockNumber) {
+                high = mid;
+            } else {
+                low = mid + 1;
+            }
+        }
+
         while (low < high) {
             uint256 mid = Math.average(low, high);
-            if (ckpts[mid].fromBlock > blockNumber) {
+            if (_unsafeAccess(ckpts, mid).fromBlock > blockNumber) {
                 high = mid;
             } else {
                 low = mid + 1;
             }
         }
 
-        return high == 0 ? 0 : ckpts[high - 1].votes;
+        return high == 0 ? 0 : _unsafeAccess(ckpts, high - 1).votes;
     }
 
     /**
@@ -229,11 +242,14 @@ abstract contract ERC20Votes is IVotes, ERC20Permit {
         uint256 delta
     ) private returns (uint256 oldWeight, uint256 newWeight) {
         uint256 pos = ckpts.length;
-        oldWeight = pos == 0 ? 0 : ckpts[pos - 1].votes;
+
+        Checkpoint memory oldCkpt = pos == 0 ? Checkpoint(0, 0) : _unsafeAccess(ckpts, pos - 1);
+
+        oldWeight = oldCkpt.votes;
         newWeight = op(oldWeight, delta);
 
-        if (pos > 0 && ckpts[pos - 1].fromBlock == block.number) {
-            ckpts[pos - 1].votes = SafeCast.toUint224(newWeight);
+        if (pos > 0 && oldCkpt.fromBlock == block.number) {
+            _unsafeAccess(ckpts, pos - 1).votes = SafeCast.toUint224(newWeight);
         } else {
             ckpts.push(Checkpoint({fromBlock: SafeCast.toUint32(block.number), votes: SafeCast.toUint224(newWeight)}));
         }
@@ -246,4 +262,11 @@ abstract contract ERC20Votes is IVotes, ERC20Permit {
     function _subtract(uint256 a, uint256 b) private pure returns (uint256) {
         return a - b;
     }
+
+    function _unsafeAccess(Checkpoint[] storage ckpts, uint256 pos) private view returns (Checkpoint storage result) {
+        assembly {
+            mstore(0, ckpts.slot)
+            result.slot := add(keccak256(0, 0x20), pos)
+        }
+    }
 }

+ 14 - 46
contracts/utils/Checkpoints.sol

@@ -49,22 +49,28 @@ library Checkpoints {
 
     /**
      * @dev Returns the value at a given block number. If a checkpoint is not available at that block, the closest one
-     * before it is returned, or zero otherwise. Similarly to {upperLookup} but optimized for the case when the search
-     * key is known to be recent.
+     * before it is returned, or zero otherwise. Similar to {upperLookup} but optimized for the case when the searched
+     * checkpoint is probably "recent", defined as being among the last sqrt(N) checkpoints where N is the number of
+     * checkpoints.
      */
-    function getAtRecentBlock(History storage self, uint256 blockNumber) internal view returns (uint256) {
+    function getAtProbablyRecentBlock(History storage self, uint256 blockNumber) internal view returns (uint256) {
         require(blockNumber < block.number, "Checkpoints: block not yet mined");
         uint32 key = SafeCast.toUint32(blockNumber);
 
         uint256 length = self._checkpoints.length;
-        uint256 offset = 1;
 
-        while (offset <= length && _unsafeAccess(self._checkpoints, length - offset)._blockNumber > key) {
-            offset <<= 1;
+        uint256 low = 0;
+        uint256 high = length;
+
+        if (length > 5) {
+            uint256 mid = length - Math.sqrt(length);
+            if (key < _unsafeAccess(self._checkpoints, mid)._blockNumber) {
+                high = mid;
+            } else {
+                low = mid + 1;
+            }
         }
 
-        uint256 low = offset < length ? length - offset : 0;
-        uint256 high = length - (offset >> 1);
         uint256 pos = _upperBinaryLookup(self._checkpoints, key, low, high);
 
         return pos == 0 ? 0 : _unsafeAccess(self._checkpoints, pos - 1)._value;
@@ -225,25 +231,6 @@ library Checkpoints {
         return pos == 0 ? 0 : _unsafeAccess(self._checkpoints, pos - 1)._value;
     }
 
-    /**
-     * @dev Returns the value in the most recent checkpoint with key lower or equal than the search key (similarly to
-     * {upperLookup}), optimized for the case when the search key is known to be recent.
-     */
-    function upperLookupRecent(Trace224 storage self, uint32 key) internal view returns (uint224) {
-        uint256 length = self._checkpoints.length;
-        uint256 offset = 1;
-
-        while (offset <= length && _unsafeAccess(self._checkpoints, length - offset)._key > key) {
-            offset <<= 1;
-        }
-
-        uint256 low = 0 < offset && offset < length ? length - offset : 0;
-        uint256 high = length - (offset >> 1);
-        uint256 pos = _upperBinaryLookup(self._checkpoints, key, low, high);
-
-        return pos == 0 ? 0 : _unsafeAccess(self._checkpoints, pos - 1)._value;
-    }
-
     /**
      * @dev Pushes a (`key`, `value`) pair into an ordered list of checkpoints, either by inserting a new checkpoint,
      * or by updating the last one.
@@ -380,25 +367,6 @@ library Checkpoints {
         return pos == 0 ? 0 : _unsafeAccess(self._checkpoints, pos - 1)._value;
     }
 
-    /**
-     * @dev Returns the value in the most recent checkpoint with key lower or equal than the search key (similarly to
-     * {upperLookup}), optimized for the case when the search key is known to be recent.
-     */
-    function upperLookupRecent(Trace160 storage self, uint96 key) internal view returns (uint160) {
-        uint256 length = self._checkpoints.length;
-        uint256 offset = 1;
-
-        while (offset <= length && _unsafeAccess(self._checkpoints, length - offset)._key > key) {
-            offset <<= 1;
-        }
-
-        uint256 low = 0 < offset && offset < length ? length - offset : 0;
-        uint256 high = length - (offset >> 1);
-        uint256 pos = _upperBinaryLookup(self._checkpoints, key, low, high);
-
-        return pos == 0 ? 0 : _unsafeAccess(self._checkpoints, pos - 1)._value;
-    }
-
     /**
      * @dev Pushes a (`key`, `value`) pair into an ordered list of checkpoints, either by inserting a new checkpoint,
      * or by updating the last one.

+ 14 - 27
scripts/generate/templates/Checkpoints.js

@@ -70,25 +70,6 @@ function upperLookup(${opts.historyTypeName} storage self, ${opts.keyTypeName} k
     uint256 pos = _upperBinaryLookup(self.${opts.checkpointFieldName}, key, 0, length);
     return pos == 0 ? 0 : _unsafeAccess(self.${opts.checkpointFieldName}, pos - 1).${opts.valueFieldName};
 }
-
-/**
- * @dev Returns the value in the most recent checkpoint with key lower or equal than the search key (similarly to
- * {upperLookup}), optimized for the case when the search key is known to be recent.
- */
-function upperLookupRecent(${opts.historyTypeName} storage self, ${opts.keyTypeName} key) internal view returns (${opts.valueTypeName}) {
-    uint256 length = self.${opts.checkpointFieldName}.length;
-    uint256 offset = 1;
-
-    while (offset <= length && _unsafeAccess(self.${opts.checkpointFieldName}, length - offset).${opts.keyFieldName} > key) {
-        offset <<= 1;
-    }
-
-    uint256 low = 0 < offset && offset < length ? length - offset : 0;
-    uint256 high = length - (offset >> 1);
-    uint256 pos = _upperBinaryLookup(self.${opts.checkpointFieldName}, key, low, high);
-
-    return pos == 0 ? 0 : _unsafeAccess(self.${opts.checkpointFieldName}, pos - 1).${opts.valueFieldName};
-}
 `;
 
 const legacyOperations = opts => `\
@@ -115,22 +96,28 @@ function getAtBlock(${opts.historyTypeName} storage self, uint256 blockNumber) i
 
 /**
  * @dev Returns the value at a given block number. If a checkpoint is not available at that block, the closest one
- * before it is returned, or zero otherwise. Similarly to {upperLookup} but optimized for the case when the search
- * key is known to be recent.
+ * before it is returned, or zero otherwise. Similar to {upperLookup} but optimized for the case when the searched
+ * checkpoint is probably "recent", defined as being among the last sqrt(N) checkpoints where N is the number of
+ * checkpoints.
  */
-function getAtRecentBlock(${opts.historyTypeName} storage self, uint256 blockNumber) internal view returns (uint256) {
+function getAtProbablyRecentBlock(${opts.historyTypeName} storage self, uint256 blockNumber) internal view returns (uint256) {
     require(blockNumber < block.number, "Checkpoints: block not yet mined");
     uint32 key = SafeCast.toUint32(blockNumber);
 
     uint256 length = self.${opts.checkpointFieldName}.length;
-    uint256 offset = 1;
 
-    while (offset <= length && _unsafeAccess(self.${opts.checkpointFieldName}, length - offset).${opts.keyFieldName} > key) {
-        offset <<= 1;
+    uint256 low = 0;
+    uint256 high = length;
+
+    if (length > 5) {
+        uint256 mid = length - Math.sqrt(length);
+        if (key < _unsafeAccess(self.${opts.checkpointFieldName}, mid)._blockNumber) {
+            high = mid;
+        } else {
+            low = mid + 1;
+        }
     }
 
-    uint256 low = offset < length ? length - offset : 0;
-    uint256 high = length - (offset >> 1);
     uint256 pos = _upperBinaryLookup(self.${opts.checkpointFieldName}, key, low, high);
 
     return pos == 0 ? 0 : _unsafeAccess(self.${opts.checkpointFieldName}, pos - 1).${opts.valueFieldName};

+ 2 - 6
scripts/generate/templates/CheckpointsMock.js

@@ -26,8 +26,8 @@ contract CheckpointsMock {
         return _totalCheckpoints.getAtBlock(blockNumber);
     }
 
-    function getAtRecentBlock(uint256 blockNumber) public view returns (uint256) {
-        return _totalCheckpoints.getAtRecentBlock(blockNumber);
+    function getAtProbablyRecentBlock(uint256 blockNumber) public view returns (uint256) {
+        return _totalCheckpoints.getAtProbablyRecentBlock(blockNumber);
     }
 
     function length() public view returns (uint256) {
@@ -58,10 +58,6 @@ contract Checkpoints${length}Mock {
         return _totalCheckpoints.upperLookup(key);
     }
 
-    function upperLookupRecent(uint${256 - length} key) public view returns (uint224) {
-        return _totalCheckpoints.upperLookupRecent(key);
-    }
-
     function length() public view returns (uint256) {
         return _totalCheckpoints._checkpoints.length;
     }

+ 13 - 0
test/token/ERC20/extensions/ERC20Votes.test.js

@@ -56,6 +56,19 @@ contract('ERC20Votes', function (accounts) {
     );
   });
 
+  it('recent checkpoints', async function () {
+    await this.token.delegate(holder, { from: holder });
+    for (let i = 0; i < 6; i++) {
+      await this.token.mint(holder, 1);
+    }
+    const block = await web3.eth.getBlockNumber();
+    expect(await this.token.numCheckpoints(holder)).to.be.bignumber.equal('6');
+    // recent
+    expect(await this.token.getPastVotes(holder, block - 1)).to.be.bignumber.equal('5');
+    // non-recent
+    expect(await this.token.getPastVotes(holder, block - 6)).to.be.bignumber.equal('0');
+  });
+
   describe('set delegation', function () {
     describe('call', function () {
       it('delegation with balance', async function () {

+ 17 - 5
test/utils/Checkpoints.test.js

@@ -22,8 +22,10 @@ contract('Checkpoints', function (accounts) {
 
       it('returns zero as past value', async function () {
         await time.advanceBlock();
-        expect(await this.checkpoint.getAtBlock(await web3.eth.getBlockNumber() - 1)).to.be.bignumber.equal('0');
-        expect(await this.checkpoint.getAtRecentBlock(await web3.eth.getBlockNumber() - 1)).to.be.bignumber.equal('0');
+        expect(await this.checkpoint.getAtBlock(await web3.eth.getBlockNumber() - 1))
+          .to.be.bignumber.equal('0');
+        expect(await this.checkpoint.getAtProbablyRecentBlock(await web3.eth.getBlockNumber() - 1))
+          .to.be.bignumber.equal('0');
       });
     });
 
@@ -41,7 +43,7 @@ contract('Checkpoints', function (accounts) {
         expect(await this.checkpoint.latest()).to.be.bignumber.equal('3');
       });
 
-      for (const fn of [ 'getAtBlock(uint256)', 'getAtRecentBlock(uint256)' ]) {
+      for (const fn of [ 'getAtBlock(uint256)', 'getAtProbablyRecentBlock(uint256)' ]) {
         describe(`lookup: ${fn}`, function () {
           it('returns past values', async function () {
             expect(await this.checkpoint.methods[fn](this.tx1.receipt.blockNumber - 1)).to.be.bignumber.equal('0');
@@ -78,6 +80,18 @@ contract('Checkpoints', function (accounts) {
         expect(await this.checkpoint.length()).to.be.bignumber.equal(lengthBefore.addn(1));
         expect(await this.checkpoint.latest()).to.be.bignumber.equal('10');
       });
+
+      it('more than 5 checkpoints', async function () {
+        for (let i = 4; i <= 6; i++) {
+          await this.checkpoint.push(i);
+        }
+        expect(await this.checkpoint.length()).to.be.bignumber.equal('6');
+        const block = await web3.eth.getBlockNumber();
+        // recent
+        expect(await this.checkpoint.getAtProbablyRecentBlock(block - 1)).to.be.bignumber.equal('5');
+        // non-recent
+        expect(await this.checkpoint.getAtProbablyRecentBlock(block - 9)).to.be.bignumber.equal('0');
+      });
     });
   });
 
@@ -95,7 +109,6 @@ contract('Checkpoints', function (accounts) {
         it('lookup returns 0', async function () {
           expect(await this.contract.lowerLookup(0)).to.be.bignumber.equal('0');
           expect(await this.contract.upperLookup(0)).to.be.bignumber.equal('0');
-          expect(await this.contract.upperLookupRecent(0)).to.be.bignumber.equal('0');
         });
       });
 
@@ -149,7 +162,6 @@ contract('Checkpoints', function (accounts) {
             const value = last(this.checkpoints.filter(x => i >= x.key))?.value || '0';
 
             expect(await this.contract.upperLookup(i)).to.be.bignumber.equal(value);
-            expect(await this.contract.upperLookupRecent(i)).to.be.bignumber.equal(value);
           }
         });
       });