Răsfoiți Sursa

Add variants of Array.sort for address[] and bytes32[] (#4883)

Co-authored-by: Ernesto García <ernestognw@gmail.com>
Hadrien Croubois 1 an în urmă
părinte
comite
f8b1ddf591

+ 1 - 1
.changeset/dirty-cobras-smile.md

@@ -2,4 +2,4 @@
 'openzeppelin-solidity': minor
 ---
 
-`Arrays`: add a `sort` function.
+`Arrays`: add a `sort` functions for `address[]`, `bytes32[]` and `uint256[]` memory arrays.

+ 36 - 0
contracts/mocks/ArraysMock.sol

@@ -36,6 +36,18 @@ contract Uint256ArraysMock {
     function unsafeAccess(uint256 pos) external view returns (uint256) {
         return _array.unsafeAccess(pos).value;
     }
+
+    function sort(uint256[] memory array) external pure returns (uint256[] memory) {
+        return array.sort();
+    }
+
+    function sortReverse(uint256[] memory array) external pure returns (uint256[] memory) {
+        return array.sort(_reverse);
+    }
+
+    function _reverse(uint256 a, uint256 b) private pure returns (bool) {
+        return a > b;
+    }
 }
 
 contract AddressArraysMock {
@@ -50,6 +62,18 @@ contract AddressArraysMock {
     function unsafeAccess(uint256 pos) external view returns (address) {
         return _array.unsafeAccess(pos).value;
     }
+
+    function sort(address[] memory array) external pure returns (address[] memory) {
+        return array.sort();
+    }
+
+    function sortReverse(address[] memory array) external pure returns (address[] memory) {
+        return array.sort(_reverse);
+    }
+
+    function _reverse(address a, address b) private pure returns (bool) {
+        return uint160(a) > uint160(b);
+    }
 }
 
 contract Bytes32ArraysMock {
@@ -64,4 +88,16 @@ contract Bytes32ArraysMock {
     function unsafeAccess(uint256 pos) external view returns (bytes32) {
         return _array.unsafeAccess(pos).value;
     }
+
+    function sort(bytes32[] memory array) external pure returns (bytes32[] memory) {
+        return array.sort();
+    }
+
+    function sortReverse(bytes32[] memory array) external pure returns (bytes32[] memory) {
+        return array.sort(_reverse);
+    }
+
+    function _reverse(bytes32 a, bytes32 b) private pure returns (bool) {
+        return uint256(a) > uint256(b);
+    }
 }

+ 141 - 29
contracts/utils/Arrays.sol

@@ -13,7 +13,7 @@ library Arrays {
     using StorageSlot for bytes32;
 
     /**
-     * @dev Sort an array (in memory) in increasing order.
+     * @dev Sort an array of bytes32 (in memory) following the provided comparator function.
      *
      * This function does the sorting "in place", meaning that it overrides the input. The object is returned for
      * convenience, but that returned value can be discarded safely if the caller has a memory pointer to the array.
@@ -23,55 +23,167 @@ library Arrays {
      * when executing this as part of a transaction. If the array being sorted is too large, the sort operation may
      * consume more gas than is available in a block, leading to potential DoS.
      */
+    function sort(
+        bytes32[] memory array,
+        function(bytes32, bytes32) pure returns (bool) comp
+    ) internal pure returns (bytes32[] memory) {
+        _quickSort(_begin(array), _end(array), comp);
+        return array;
+    }
+
+    /**
+     * @dev Variant of {sort} that sorts an array of bytes32 in increasing order.
+     */
+    function sort(bytes32[] memory array) internal pure returns (bytes32[] memory) {
+        return sort(array, _defaultComp);
+    }
+
+    /**
+     * @dev Variant of {sort} that sorts an array of address following a provided comparator function.
+     */
+    function sort(
+        address[] memory array,
+        function(address, address) pure returns (bool) comp
+    ) internal pure returns (address[] memory) {
+        sort(_castToBytes32Array(array), _castToBytes32Comp(comp));
+        return array;
+    }
+
+    /**
+     * @dev Variant of {sort} that sorts an array of address in increasing order.
+     */
+    function sort(address[] memory array) internal pure returns (address[] memory) {
+        sort(_castToBytes32Array(array), _defaultComp);
+        return array;
+    }
+
+    /**
+     * @dev Variant of {sort} that sorts an array of uint256 following a provided comparator function.
+     */
+    function sort(
+        uint256[] memory array,
+        function(uint256, uint256) pure returns (bool) comp
+    ) internal pure returns (uint256[] memory) {
+        sort(_castToBytes32Array(array), _castToBytes32Comp(comp));
+        return array;
+    }
+
+    /**
+     * @dev Variant of {sort} that sorts an array of uint256 in increasing order.
+     */
     function sort(uint256[] memory array) internal pure returns (uint256[] memory) {
-        _quickSort(array, 0, array.length);
+        sort(_castToBytes32Array(array), _defaultComp);
         return array;
     }
 
     /**
-     * @dev Performs a quick sort on an array in memory. The array is sorted in increasing order.
+     * @dev Performs a quick sort of a segment of memory. The segment sorted starts at `begin` (inclusive), and stops
+     * at end (exclusive). Sorting follows the `comp` comparator.
      *
-     * Invariant: `i <= j <= array.length`. This is the case when initially called by {sort} and is preserved in
-     * subcalls.
+     * Invariant: `begin <= end`. This is the case when initially called by {sort} and is preserved in subcalls.
+     *
+     * IMPORTANT: Memory locations between `begin` and `end` are not validated/zeroed. This function should
+     * be used only if the limits are within a memory array.
      */
-    function _quickSort(uint256[] memory array, uint256 i, uint256 j) private pure {
+    function _quickSort(uint256 begin, uint256 end, function(bytes32, bytes32) pure returns (bool) comp) private pure {
         unchecked {
-            // Can't overflow given `i <= j`
-            if (j - i < 2) return;
+            if (end - begin < 0x40) return;
 
             // Use first element as pivot
-            uint256 pivot = unsafeMemoryAccess(array, i);
+            bytes32 pivot = _mload(begin);
             // Position where the pivot should be at the end of the loop
-            uint256 index = i;
-
-            for (uint256 k = i + 1; k < j; ++k) {
-                // Unsafe access is safe given `k < j <= array.length`.
-                if (unsafeMemoryAccess(array, k) < pivot) {
-                    // If array[k] is smaller than the pivot, we increment the index and move array[k] there.
-                    _swap(array, ++index, k);
+            uint256 pos = begin;
+
+            for (uint256 it = begin + 0x20; it < end; it += 0x20) {
+                if (comp(_mload(it), pivot)) {
+                    // If the value stored at the iterator's position comes before the pivot, we increment the
+                    // position of the pivot and move the value there.
+                    pos += 0x20;
+                    _swap(pos, it);
                 }
             }
 
-            // Swap pivot into place
-            _swap(array, i, index);
+            _swap(begin, pos); // Swap pivot into place
+            _quickSort(begin, pos, comp); // Sort the left side of the pivot
+            _quickSort(pos + 0x20, end, comp); // Sort the right side of the pivot
+        }
+    }
+
+    /**
+     * @dev Pointer to the memory location of the first element of `array`.
+     */
+    function _begin(bytes32[] memory array) private pure returns (uint256 ptr) {
+        /// @solidity memory-safe-assembly
+        assembly {
+            ptr := add(array, 0x20)
+        }
+    }
+
+    /**
+     * @dev Pointer to the memory location of the first memory word (32bytes) after `array`. This is the memory word
+     * that comes just after the last element of the array.
+     */
+    function _end(bytes32[] memory array) private pure returns (uint256 ptr) {
+        unchecked {
+            return _begin(array) + array.length * 0x20;
+        }
+    }
 
-            _quickSort(array, i, index); // Sort the left side of the pivot
-            _quickSort(array, index + 1, j); // Sort the right side of the pivot
+    /**
+     * @dev Load memory word (as a bytes32) at location `ptr`.
+     */
+    function _mload(uint256 ptr) private pure returns (bytes32 value) {
+        assembly {
+            value := mload(ptr)
         }
     }
 
     /**
-     * @dev Swaps the elements at positions `i` and `j` in the `arr` array.
+     * @dev Swaps the elements memory location `ptr1` and `ptr2`.
      */
-    function _swap(uint256[] memory arr, uint256 i, uint256 j) private pure {
+    function _swap(uint256 ptr1, uint256 ptr2) private pure {
+        assembly {
+            let value1 := mload(ptr1)
+            let value2 := mload(ptr2)
+            mstore(ptr1, value2)
+            mstore(ptr2, value1)
+        }
+    }
+
+    /// @dev Comparator for sorting arrays in increasing order.
+    function _defaultComp(bytes32 a, bytes32 b) private pure returns (bool) {
+        return a < b;
+    }
+
+    /// @dev Helper: low level cast address memory array to uint256 memory array
+    function _castToBytes32Array(address[] memory input) private pure returns (bytes32[] memory output) {
+        assembly {
+            output := input
+        }
+    }
+
+    /// @dev Helper: low level cast uint256 memory array to uint256 memory array
+    function _castToBytes32Array(uint256[] memory input) private pure returns (bytes32[] memory output) {
+        assembly {
+            output := input
+        }
+    }
+
+    /// @dev Helper: low level cast address comp function to bytes32 comp function
+    function _castToBytes32Comp(
+        function(address, address) pure returns (bool) input
+    ) private pure returns (function(bytes32, bytes32) pure returns (bool) output) {
+        assembly {
+            output := input
+        }
+    }
+
+    /// @dev Helper: low level cast uint256 comp function to bytes32 comp function
+    function _castToBytes32Comp(
+        function(uint256, uint256) pure returns (bool) input
+    ) private pure returns (function(bytes32, bytes32) pure returns (bool) output) {
         assembly {
-            let start := add(arr, 0x20) // Pointer to the first element of the array
-            let pos_i := add(start, mul(i, 0x20))
-            let pos_j := add(start, mul(j, 0x20))
-            let val_i := mload(pos_i)
-            let val_j := mload(pos_j)
-            mstore(pos_i, val_j)
-            mstore(pos_j, val_i)
+            output := input
         }
     }
 

+ 2 - 8
scripts/helpers.js

@@ -7,11 +7,7 @@ function range(start, stop = undefined, step = 1) {
     stop = start;
     start = 0;
   }
-  return start < stop
-    ? Array(Math.ceil((stop - start) / step))
-        .fill()
-        .map((_, i) => start + i * step)
-    : [];
+  return start < stop ? Array.from({ length: Math.ceil((stop - start) / step) }, (_, i) => start + i * step) : [];
 }
 
 function unique(array, op = x => x) {
@@ -19,9 +15,7 @@ function unique(array, op = x => x) {
 }
 
 function zip(...args) {
-  return Array(Math.max(...args.map(arg => arg.length)))
-    .fill(null)
-    .map((_, i) => args.map(arg => arg[i]));
+  return Array.from({ length: Math.max(...args.map(arg => arg.length)) }, (_, i) => args.map(arg => arg[i]));
 }
 
 function capitalize(str) {

+ 1 - 3
test/finance/VestingWallet.test.js

@@ -55,9 +55,7 @@ async function fixture() {
     },
   };
 
-  const schedule = Array(64)
-    .fill()
-    .map((_, i) => (BigInt(i) * duration) / 60n + start);
+  const schedule = Array.from({ length: 64 }, (_, i) => (BigInt(i) * duration) / 60n + start);
 
   const vestingFn = timestamp => min(amount, (amount * (timestamp - start)) / duration);
 

+ 1 - 3
test/helpers/iterate.js

@@ -5,9 +5,7 @@ const mapValues = (obj, fn) => Object.fromEntries(Object.entries(obj).map(([k, v
 const product = (...arrays) => arrays.reduce((a, b) => a.flatMap(ai => b.map(bi => [...ai, bi])), [[]]);
 const unique = (...array) => array.filter((obj, i) => array.indexOf(obj) === i);
 const zip = (...args) =>
-  Array(Math.max(...args.map(array => array.length)))
-    .fill()
-    .map((_, i) => args.map(array => array[i]));
+  Array.from({ length: Math.max(...args.map(array => array.length)) }, (_, i) => args.map(array => array[i]));
 
 module.exports = {
   mapValues,

+ 0 - 3
test/helpers/random.js

@@ -1,7 +1,5 @@
 const { ethers } = require('hardhat');
 
-const randomArray = (generator, arrayLength = 3) => Array(arrayLength).fill().map(generator);
-
 const generators = {
   address: () => ethers.Wallet.createRandom().address,
   bytes32: () => ethers.hexlify(ethers.randomBytes(32)),
@@ -15,6 +13,5 @@ generators.uint256.zero = 0n;
 generators.hexBytes.zero = '0x';
 
 module.exports = {
-  randomArray,
   generators,
 };

+ 73 - 55
test/utils/Arrays.test.js

@@ -2,7 +2,7 @@ const { ethers } = require('hardhat');
 const { expect } = require('chai');
 const { loadFixture } = require('@nomicfoundation/hardhat-network-helpers');
 
-const { randomArray, generators } = require('../helpers/random');
+const { generators } = require('../helpers/random');
 
 // See https://en.cppreference.com/w/cpp/algorithm/lower_bound
 const lowerBound = (array, value) => {
@@ -16,9 +16,7 @@ const upperBound = (array, value) => {
   return i == -1 ? array.length : i;
 };
 
-// By default, js "sort" cast to string and then sort in alphabetical order. Use this to sort numbers.
-const compareNumbers = (a, b) => (a > b ? 1 : a < b ? -1 : 0);
-
+const bigintSign = x => (x > 0n ? 1 : x < 0n ? -1 : 0);
 const hasDuplicates = array => array.some((v, i) => array.indexOf(v) != i);
 
 describe('Arrays', function () {
@@ -30,42 +28,6 @@ describe('Arrays', function () {
     Object.assign(this, await loadFixture(fixture));
   });
 
-  describe('sort', function () {
-    for (const length of [0, 1, 2, 8, 32, 128]) {
-      it(`sort array of length ${length}`, async function () {
-        this.elements = randomArray(generators.uint256, length);
-        this.expected = Array.from(this.elements).sort(compareNumbers);
-      });
-
-      if (length > 1) {
-        it(`sort array of length ${length} (identical elements)`, async function () {
-          this.elements = Array(length).fill(generators.uint256());
-          this.expected = this.elements;
-        });
-
-        it(`sort array of length ${length} (already sorted)`, async function () {
-          this.elements = randomArray(generators.uint256, length).sort(compareNumbers);
-          this.expected = this.elements;
-        });
-
-        it(`sort array of length ${length} (sorted in reverse order)`, async function () {
-          this.elements = randomArray(generators.uint256, length).sort(compareNumbers).reverse();
-          this.expected = Array.from(this.elements).reverse();
-        });
-
-        it(`sort array of length ${length} (almost sorted)`, async function () {
-          this.elements = randomArray(generators.uint256, length).sort(compareNumbers);
-          this.expected = Array.from(this.elements);
-          // rotate (move the last element to the front) for an almost sorted effect
-          this.elements.unshift(this.elements.pop());
-        });
-      }
-    }
-    afterEach(async function () {
-      expect(await this.mock.$sort(this.elements)).to.deep.equal(this.expected);
-    });
-  });
-
   describe('search', function () {
     for (const [title, { array, tests }] of Object.entries({
       'Even number of elements': {
@@ -154,22 +116,78 @@ describe('Arrays', function () {
     }
   });
 
-  describe('unsafeAccess', function () {
-    for (const [type, { artifact, elements }] of Object.entries({
-      address: { artifact: 'AddressArraysMock', elements: randomArray(generators.address, 10) },
-      bytes32: { artifact: 'Bytes32ArraysMock', elements: randomArray(generators.bytes32, 10) },
-      uint256: { artifact: 'Uint256ArraysMock', elements: randomArray(generators.uint256, 10) },
-    })) {
-      describe(type, function () {
-        describe('storage', function () {
-          const fixture = async () => {
-            return { instance: await ethers.deployContract(artifact, [elements]) };
-          };
+  for (const [type, { artifact, elements, comp }] of Object.entries({
+    address: {
+      artifact: 'AddressArraysMock',
+      elements: Array.from({ length: 10 }, generators.address),
+      comp: (a, b) => bigintSign(ethers.toBigInt(a) - ethers.toBigInt(b)),
+    },
+    bytes32: {
+      artifact: 'Bytes32ArraysMock',
+      elements: Array.from({ length: 10 }, generators.bytes32),
+      comp: (a, b) => bigintSign(ethers.toBigInt(a) - ethers.toBigInt(b)),
+    },
+    uint256: {
+      artifact: 'Uint256ArraysMock',
+      elements: Array.from({ length: 10 }, generators.uint256),
+      comp: (a, b) => bigintSign(a - b),
+    },
+  })) {
+    describe(type, function () {
+      const fixture = async () => {
+        return { instance: await ethers.deployContract(artifact, [elements]) };
+      };
+
+      beforeEach(async function () {
+        Object.assign(this, await loadFixture(fixture));
+      });
+
+      describe('sort', function () {
+        for (const length of [0, 1, 2, 8, 32, 128]) {
+          describe(`${type}[] of length ${length}`, function () {
+            beforeEach(async function () {
+              this.elements = Array.from({ length }, generators[type]);
+            });
+
+            afterEach(async function () {
+              const expected = Array.from(this.elements).sort(comp);
+              const reversed = Array.from(expected).reverse();
+              expect(await this.instance.sort(this.elements)).to.deep.equal(expected);
+              expect(await this.instance.sortReverse(this.elements)).to.deep.equal(reversed);
+            });
+
+            it('sort array', async function () {
+              // nothing to do here, beforeEach and afterEach already take care of everything.
+            });
 
-          beforeEach(async function () {
-            Object.assign(this, await loadFixture(fixture));
+            if (length > 1) {
+              it('sort array for identical elements', async function () {
+                // duplicate the first value to all elements
+                this.elements.fill(this.elements.at(0));
+              });
+
+              it('sort already sorted array', async function () {
+                // pre-sort the elements
+                this.elements.sort(comp);
+              });
+
+              it('sort reversed array', async function () {
+                // pre-sort in reverse order
+                this.elements.sort(comp).reverse();
+              });
+
+              it('sort almost sorted array', async function () {
+                // pre-sort + rotate (move the last element to the front) for an almost sorted effect
+                this.elements.sort(comp);
+                this.elements.unshift(this.elements.pop());
+              });
+            }
           });
+        }
+      });
 
+      describe('unsafeAccess', function () {
+        describe('storage', function () {
           for (const i in elements) {
             it(`unsafeAccess within bounds #${i}`, async function () {
               expect(await this.instance.unsafeAccess(i)).to.equal(elements[i]);
@@ -195,6 +213,6 @@ describe('Arrays', function () {
           });
         });
       });
-    }
-  });
+    });
+  }
 });

+ 2 - 2
test/utils/math/Math.test.js

@@ -5,7 +5,7 @@ const { PANIC_CODES } = require('@nomicfoundation/hardhat-chai-matchers/panic');
 
 const { Rounding } = require('../../helpers/enums');
 const { min, max } = require('../../helpers/math');
-const { randomArray, generators } = require('../../helpers/random');
+const { generators } = require('../../helpers/random');
 
 const RoundingDown = [Rounding.Floor, Rounding.Trunc];
 const RoundingUp = [Rounding.Ceil, Rounding.Expand];
@@ -337,7 +337,7 @@ describe('Math', function () {
         });
 
         if (p != 0) {
-          for (const value of randomArray(generators.uint256, 16)) {
+          for (const value of Array.from({ length: 16 }, generators.uint256)) {
             const isInversible = factors.every(f => value % f);
             it(`trying to inverse ${value}`, async function () {
               const result = await this.mock.$invMod(value, p);

+ 1 - 7
test/utils/structs/DoubleEndedQueue.test.js

@@ -8,13 +8,7 @@ async function fixture() {
 
   /** Rebuild the content of the deque as a JS array. */
   const getContent = () =>
-    mock.$length(0).then(length =>
-      Promise.all(
-        Array(Number(length))
-          .fill()
-          .map((_, i) => mock.$at(0, i)),
-      ),
-    );
+    mock.$length(0).then(length => Promise.all(Array.from({ length: Number(length) }, (_, i) => mock.$at(0, i))));
 
   return { mock, getContent };
 }

+ 3 - 3
test/utils/structs/EnumerableMap.test.js

@@ -2,7 +2,7 @@ const { ethers } = require('hardhat');
 const { loadFixture } = require('@nomicfoundation/hardhat-network-helpers');
 
 const { mapValues } = require('../../helpers/iterate');
-const { randomArray, generators } = require('../../helpers/random');
+const { generators } = require('../../helpers/random');
 const { TYPES, formatType } = require('../../../scripts/generate/templates/EnumerableMap.opts');
 
 const { shouldBehaveLikeMap } = require('./EnumerableMap.behavior');
@@ -17,8 +17,8 @@ async function fixture() {
       name,
       {
         keyType,
-        keys: randomArray(generators[keyType]),
-        values: randomArray(generators[valueType]),
+        keys: Array.from({ length: 3 }, generators[keyType]),
+        values: Array.from({ length: 3 }, generators[valueType]),
         zeroValue: generators[valueType].zero,
         methods: mapValues(
           {

+ 2 - 2
test/utils/structs/EnumerableSet.test.js

@@ -2,7 +2,7 @@ const { ethers } = require('hardhat');
 const { loadFixture } = require('@nomicfoundation/hardhat-network-helpers');
 
 const { mapValues } = require('../../helpers/iterate');
-const { randomArray, generators } = require('../../helpers/random');
+const { generators } = require('../../helpers/random');
 const { TYPES } = require('../../../scripts/generate/templates/EnumerableSet.opts');
 
 const { shouldBehaveLikeSet } = require('./EnumerableSet.behavior');
@@ -23,7 +23,7 @@ async function fixture() {
     TYPES.map(({ name, type }) => [
       type,
       {
-        values: randomArray(generators[type]),
+        values: Array.from({ length: 3 }, generators[type]),
         methods: getMethods(mock, {
           add: `$add(uint256,${type})`,
           remove: `$remove(uint256,${type})`,