Selaa lähdekoodia

Bound lookup in arrays with duplicate (#4842)

Co-authored-by: RenanSouza2 <renan.rodrigues.souza1@gmail.com>
Co-authored-by: ernestognw <ernestognw@gmail.com>
Hadrien Croubois 1 vuosi sitten
vanhempi
sitoutus
61117c4db8

+ 5 - 0
.changeset/flat-turtles-repeat.md

@@ -0,0 +1,5 @@
+---
+'openzeppelin-solidity': minor
+---
+
+`Arrays`: deprecate `findUpperBound` in favor of the new `lowerBound`.

+ 5 - 0
.changeset/thick-pumpkins-report.md

@@ -0,0 +1,5 @@
+---
+'openzeppelin-solidity': minor
+---
+
+`Arrays`: add new functions `lowerBound`, `upperBound`, `lowerBoundMemory` and `upperBoundMemory` for lookups in sorted arrays with potential duplicates.

+ 18 - 2
contracts/mocks/ArraysMock.sol

@@ -13,8 +13,24 @@ contract Uint256ArraysMock {
         _array = array;
     }
 
-    function findUpperBound(uint256 element) external view returns (uint256) {
-        return _array.findUpperBound(element);
+    function findUpperBound(uint256 value) external view returns (uint256) {
+        return _array.findUpperBound(value);
+    }
+
+    function lowerBound(uint256 value) external view returns (uint256) {
+        return _array.lowerBound(value);
+    }
+
+    function upperBound(uint256 value) external view returns (uint256) {
+        return _array.upperBound(value);
+    }
+
+    function lowerBoundMemory(uint256[] memory array, uint256 value) external pure returns (uint256) {
+        return array.lowerBoundMemory(value);
+    }
+
+    function upperBoundMemory(uint256[] memory array, uint256 value) external pure returns (uint256) {
+        return array.upperBoundMemory(value);
     }
 
     function unsafeAccess(uint256 pos) external view returns (uint256) {

+ 132 - 2
contracts/utils/Arrays.sol

@@ -18,8 +18,12 @@ library Arrays {
      * values in the array are strictly less than `element`), the array length is
      * returned. Time complexity O(log n).
      *
-     * `array` is expected to be sorted in ascending order, and to contain no
-     * repeated elements.
+     * NOTE: The `array` is expected to be sorted in ascending order, and to
+     * contain no repeated elements.
+     *
+     * IMPORTANT: Deprecated. This implementation behaves as {lowerBound} but lacks
+     * support for repeated elements in the array. The {lowerBound} function should
+     * be used instead.
      */
     function findUpperBound(uint256[] storage array, uint256 element) internal view returns (uint256) {
         uint256 low = 0;
@@ -49,6 +53,132 @@ library Arrays {
         }
     }
 
+    /**
+     * @dev Searches an `array` sorted in ascending order and returns the first
+     * index that contains a value greater or equal than `element`. If no such index
+     * exists (i.e. all values in the array are strictly less than `element`), the array
+     * length is returned. Time complexity O(log n).
+     *
+     * See C++'s https://en.cppreference.com/w/cpp/algorithm/lower_bound[lower_bound].
+     */
+    function lowerBound(uint256[] storage array, uint256 element) internal view returns (uint256) {
+        uint256 low = 0;
+        uint256 high = array.length;
+
+        if (high == 0) {
+            return 0;
+        }
+
+        while (low < high) {
+            uint256 mid = Math.average(low, high);
+
+            // Note that mid will always be strictly less than high (i.e. it will be a valid array index)
+            // because Math.average rounds towards zero (it does integer division with truncation).
+            if (unsafeAccess(array, mid).value < element) {
+                // this cannot overflow because mid < high
+                unchecked {
+                    low = mid + 1;
+                }
+            } else {
+                high = mid;
+            }
+        }
+
+        return low;
+    }
+
+    /**
+     * @dev Searches an `array` sorted in ascending order and returns the first
+     * index that contains a value strictly greater than `element`. If no such index
+     * exists (i.e. all values in the array are strictly less than `element`), the array
+     * length is returned. Time complexity O(log n).
+     *
+     * See C++'s https://en.cppreference.com/w/cpp/algorithm/upper_bound[upper_bound].
+     */
+    function upperBound(uint256[] storage array, uint256 element) internal view returns (uint256) {
+        uint256 low = 0;
+        uint256 high = array.length;
+
+        if (high == 0) {
+            return 0;
+        }
+
+        while (low < high) {
+            uint256 mid = Math.average(low, high);
+
+            // Note that mid will always be strictly less than high (i.e. it will be a valid array index)
+            // because Math.average rounds towards zero (it does integer division with truncation).
+            if (unsafeAccess(array, mid).value > element) {
+                high = mid;
+            } else {
+                // this cannot overflow because mid < high
+                unchecked {
+                    low = mid + 1;
+                }
+            }
+        }
+
+        return low;
+    }
+
+    /**
+     * @dev Same as {lowerBound}, but with an array in memory.
+     */
+    function lowerBoundMemory(uint256[] memory array, uint256 element) internal pure returns (uint256) {
+        uint256 low = 0;
+        uint256 high = array.length;
+
+        if (high == 0) {
+            return 0;
+        }
+
+        while (low < high) {
+            uint256 mid = Math.average(low, high);
+
+            // Note that mid will always be strictly less than high (i.e. it will be a valid array index)
+            // because Math.average rounds towards zero (it does integer division with truncation).
+            if (unsafeMemoryAccess(array, mid) < element) {
+                // this cannot overflow because mid < high
+                unchecked {
+                    low = mid + 1;
+                }
+            } else {
+                high = mid;
+            }
+        }
+
+        return low;
+    }
+
+    /**
+     * @dev Same as {upperBound}, but with an array in memory.
+     */
+    function upperBoundMemory(uint256[] memory array, uint256 element) internal pure returns (uint256) {
+        uint256 low = 0;
+        uint256 high = array.length;
+
+        if (high == 0) {
+            return 0;
+        }
+
+        while (low < high) {
+            uint256 mid = Math.average(low, high);
+
+            // Note that mid will always be strictly less than high (i.e. it will be a valid array index)
+            // because Math.average rounds towards zero (it does integer division with truncation).
+            if (unsafeMemoryAccess(array, mid) > element) {
+                high = mid;
+            } else {
+                // this cannot overflow because mid < high
+                unchecked {
+                    low = mid + 1;
+                }
+            }
+        }
+
+        return low;
+    }
+
     /**
      * @dev Access an array in an "unsafe" way. Skips solidity "index-out-of-range" check.
      *

+ 41 - 26
test/utils/Arrays.test.js

@@ -4,22 +4,22 @@ const { loadFixture } = require('@nomicfoundation/hardhat-network-helpers');
 
 const { randomArray, generators } = require('../helpers/random');
 
-// See https://en.cppreference.com/w/cpp/algorithm/ranges/lower_bound
+// See https://en.cppreference.com/w/cpp/algorithm/lower_bound
 const lowerBound = (array, value) => {
   const i = array.findIndex(element => value <= element);
   return i == -1 ? array.length : i;
 };
 
 // See https://en.cppreference.com/w/cpp/algorithm/upper_bound
-// const upperBound = (array, value) => {
-//   const i = array.findIndex(element => value < element);
-//   return i == -1 ? array.length : i;
-// };
+const upperBound = (array, value) => {
+  const i = array.findIndex(element => value < element);
+  return i == -1 ? array.length : i;
+};
 
 const hasDuplicates = array => array.some((v, i) => array.indexOf(v) != i);
 
 describe('Arrays', function () {
-  describe('findUpperBound', function () {
+  describe('search', function () {
     for (const [title, { array, tests }] of Object.entries({
       'Even number of elements': {
         array: [11n, 12n, 13n, 14n, 15n, 16n, 17n, 18n, 19n, 20n],
@@ -82,10 +82,25 @@ describe('Arrays', function () {
         });
 
         for (const [name, input] of Object.entries(tests)) {
-          it(name, async function () {
-            // findUpperBound does not support duplicated
-            if (hasDuplicates(array)) this.skip();
-            expect(await this.mock.findUpperBound(input)).to.equal(lowerBound(array, input));
+          describe(name, function () {
+            it('[deprecated] findUpperBound', async function () {
+              // findUpperBound does not support duplicated
+              if (hasDuplicates(array)) {
+                expect(await this.mock.findUpperBound(input)).to.be.equal(upperBound(array, input) - 1);
+              } else {
+                expect(await this.mock.findUpperBound(input)).to.be.equal(lowerBound(array, input));
+              }
+            });
+
+            it('lowerBound', async function () {
+              expect(await this.mock.lowerBound(input)).to.be.equal(lowerBound(array, input));
+              expect(await this.mock.lowerBoundMemory(array, input)).to.be.equal(lowerBound(array, input));
+            });
+
+            it('upperBound', async function () {
+              expect(await this.mock.upperBound(input)).to.be.equal(upperBound(array, input));
+              expect(await this.mock.upperBoundMemory(array, input)).to.be.equal(upperBound(array, input));
+            });
           });
         }
       });
@@ -93,29 +108,29 @@ describe('Arrays', function () {
   });
 
   describe('unsafeAccess', function () {
-    const contractCases = {
+    for (const [title, { artifact, elements }] of Object.entries({
       address: { artifact: 'AddressArraysMock', elements: randomArray(generators.address, 10) },
       bytes32: { artifact: 'Bytes32ArraysMock', elements: randomArray(generators.bytes32, 10) },
       uint256: { artifact: 'Uint256ArraysMock', elements: randomArray(generators.uint256, 10) },
-    };
-
-    const fixture = async () => {
-      const contracts = {};
-      for (const [name, { artifact, elements }] of Object.entries(contractCases)) {
-        contracts[name] = await ethers.deployContract(artifact, [elements]);
-      }
-      return { contracts };
-    };
+    })) {
+      describe(title, function () {
+        const fixture = async () => {
+          return { mock: await ethers.deployContract(artifact, [elements]) };
+        };
 
-    beforeEach(async function () {
-      Object.assign(this, await loadFixture(fixture));
-    });
+        beforeEach(async function () {
+          Object.assign(this, await loadFixture(fixture));
+        });
 
-    for (const [name, { elements }] of Object.entries(contractCases)) {
-      it(name, async function () {
         for (const i in elements) {
-          expect(await this.contracts[name].unsafeAccess(i)).to.equal(elements[i]);
+          it(`unsafeAccess within bounds #${i}`, async function () {
+            expect(await this.mock.unsafeAccess(i)).to.equal(elements[i]);
+          });
         }
+
+        it('unsafeAccess outside bounds', async function () {
+          await expect(this.mock.unsafeAccess(elements.length)).to.not.be.rejected;
+        });
       });
     }
   });

+ 3 - 5
test/utils/structs/Checkpoints.test.js

@@ -4,8 +4,6 @@ const { loadFixture } = require('@nomicfoundation/hardhat-network-helpers');
 
 const { VALUE_SIZES } = require('../../../scripts/generate/templates/Checkpoints.opts');
 
-const last = array => (array.length ? array[array.length - 1] : undefined);
-
 describe('Checkpoints', function () {
   for (const length of VALUE_SIZES) {
     describe(`Trace${length}`, function () {
@@ -81,7 +79,7 @@ describe('Checkpoints', function () {
         it('returns latest value', async function () {
           const latest = this.checkpoints.at(-1);
           expect(await this.methods.latest()).to.equal(latest.value);
-          expect(await this.methods.latestCheckpoint()).to.have.ordered.members([true, latest.key, latest.value]);
+          expect(await this.methods.latestCheckpoint()).to.deep.equal([true, latest.key, latest.value]);
         });
 
         it('cannot push values in the past', async function () {
@@ -115,7 +113,7 @@ describe('Checkpoints', function () {
 
         it('upper lookup & upperLookupRecent', async function () {
           for (let i = 0; i < 14; ++i) {
-            const value = last(this.checkpoints.filter(x => i >= x.key))?.value || 0n;
+            const value = this.checkpoints.findLast(x => i >= x.key)?.value || 0n;
 
             expect(await this.methods.upperLookup(i)).to.equal(value);
             expect(await this.methods.upperLookupRecent(i)).to.equal(value);
@@ -137,7 +135,7 @@ describe('Checkpoints', function () {
           }
 
           for (let i = 0; i < 25; ++i) {
-            const value = last(allCheckpoints.filter(x => i >= x.key))?.value || 0n;
+            const value = allCheckpoints.findLast(x => i >= x.key)?.value || 0n;
             expect(await this.methods.upperLookup(i)).to.equal(value);
             expect(await this.methods.upperLookupRecent(i)).to.equal(value);
           }