Bladeren bron

Add `sort` in memory to Arrays library (#4846)

Co-authored-by: RenanSouza2 <renan.rodrigues.souza1@gmail.com>
Co-authored-by: Ernesto García <ernestognw@gmail.com>
Hadrien Croubois 1 jaar geleden
bovenliggende
commit
0a757ec463

+ 5 - 0
.changeset/dirty-cobras-smile.md

@@ -0,0 +1,5 @@
+---
+'openzeppelin-solidity': minor
+---
+
+`Arrays`: add a `sort` function.

+ 76 - 2
contracts/utils/Arrays.sol

@@ -12,6 +12,69 @@ import {Math} from "./math/Math.sol";
 library Arrays {
     using StorageSlot for bytes32;
 
+    /**
+     * @dev Sort an array (in memory) in increasing order.
+     *
+     * This function does the sorting "in place", meaning that it overrides the input. The object is returned for
+     * convenience, but that returned value can be discarded safely if the caller has a memory pointer to the array.
+     *
+     * NOTE: this function's cost is `O(n · log(n))` in average and `O(n²)` in the worst case, with n the length of the
+     * array. Using it in view functions that are executed through `eth_call` is safe, but one should be very careful
+     * when executing this as part of a transaction. If the array being sorted is too large, the sort operation may
+     * consume more gas than is available in a block, leading to potential DoS.
+     */
+    function sort(uint256[] memory array) internal pure returns (uint256[] memory) {
+        _quickSort(array, 0, array.length);
+        return array;
+    }
+
+    /**
+     * @dev Performs a quick sort on an array in memory. The array is sorted in increasing order.
+     *
+     * Invariant: `i <= j <= array.length`. This is the case when initially called by {sort} and is preserved in
+     * subcalls.
+     */
+    function _quickSort(uint256[] memory array, uint256 i, uint256 j) private pure {
+        unchecked {
+            // Can't overflow given `i <= j`
+            if (j - i < 2) return;
+
+            // Use first element as pivot
+            uint256 pivot = unsafeMemoryAccess(array, i);
+            // Position where the pivot should be at the end of the loop
+            uint256 index = i;
+
+            for (uint256 k = i + 1; k < j; ++k) {
+                // Unsafe access is safe given `k < j <= array.length`.
+                if (unsafeMemoryAccess(array, k) < pivot) {
+                    // If array[k] is smaller than the pivot, we increment the index and move array[k] there.
+                    _swap(array, ++index, k);
+                }
+            }
+
+            // Swap pivot into place
+            _swap(array, i, index);
+
+            _quickSort(array, i, index); // Sort the left side of the pivot
+            _quickSort(array, index + 1, j); // Sort the right side of the pivot
+        }
+    }
+
+    /**
+     * @dev Swaps the elements at positions `i` and `j` in the `arr` array.
+     */
+    function _swap(uint256[] memory arr, uint256 i, uint256 j) private pure {
+        assembly {
+            let start := add(arr, 0x20) // Pointer to the first element of the array
+            let pos_i := add(start, mul(i, 0x20))
+            let pos_j := add(start, mul(j, 0x20))
+            let val_i := mload(pos_i)
+            let val_j := mload(pos_j)
+            mstore(pos_i, val_j)
+            mstore(pos_j, val_i)
+        }
+    }
+
     /**
      * @dev Searches a sorted `array` and returns the first index that contains
      * a value greater or equal to `element`. If no such index exists (i.e. all
@@ -238,7 +301,7 @@ library Arrays {
      *
      * WARNING: Only use if you are certain `pos` is lower than the array length.
      */
-    function unsafeMemoryAccess(uint256[] memory arr, uint256 pos) internal pure returns (uint256 res) {
+    function unsafeMemoryAccess(address[] memory arr, uint256 pos) internal pure returns (address res) {
         assembly {
             res := mload(add(add(arr, 0x20), mul(pos, 0x20)))
         }
@@ -249,7 +312,18 @@ library Arrays {
      *
      * WARNING: Only use if you are certain `pos` is lower than the array length.
      */
-    function unsafeMemoryAccess(address[] memory arr, uint256 pos) internal pure returns (address res) {
+    function unsafeMemoryAccess(bytes32[] memory arr, uint256 pos) internal pure returns (bytes32 res) {
+        assembly {
+            res := mload(add(add(arr, 0x20), mul(pos, 0x20)))
+        }
+    }
+
+    /**
+     * @dev Access an array in an "unsafe" way. Skips solidity "index-out-of-range" check.
+     *
+     * WARNING: Only use if you are certain `pos` is lower than the array length.
+     */
+    function unsafeMemoryAccess(uint256[] memory arr, uint256 pos) internal pure returns (uint256 res) {
         assembly {
             res := mload(add(add(arr, 0x20), mul(pos, 0x20)))
         }

+ 2 - 2
scripts/generate/templates/Checkpoints.t.js

@@ -7,8 +7,8 @@ const header = `\
 pragma solidity ^0.8.20;
 
 import {Test} from "forge-std/Test.sol";
-import {SafeCast} from "../../../contracts/utils/math/SafeCast.sol";
-import {Checkpoints} from "../../../contracts/utils/structs/Checkpoints.sol";
+import {SafeCast} from "@openzeppelin/contracts/utils/math/SafeCast.sol";
+import {Checkpoints} from "@openzeppelin/contracts/utils/structs/Checkpoints.sol";
 `;
 
 /* eslint-disable max-len */

+ 15 - 0
test/utils/Arrays.t.sol

@@ -0,0 +1,15 @@
+// SPDX-License-Identifier: MIT
+
+pragma solidity ^0.8.20;
+
+import {Test} from "forge-std/Test.sol";
+import {Arrays} from "@openzeppelin/contracts/utils/Arrays.sol";
+
+contract ArraysTest is Test {
+    function testSort(uint256[] memory values) public {
+        Arrays.sort(values);
+        for (uint256 i = 1; i < values.length; ++i) {
+            assertLe(values[i - 1], values[i]);
+        }
+    }
+}

+ 84 - 21
test/utils/Arrays.test.js

@@ -16,9 +16,56 @@ const upperBound = (array, value) => {
   return i == -1 ? array.length : i;
 };
 
+// By default, js "sort" cast to string and then sort in alphabetical order. Use this to sort numbers.
+const compareNumbers = (a, b) => (a > b ? 1 : a < b ? -1 : 0);
+
 const hasDuplicates = array => array.some((v, i) => array.indexOf(v) != i);
 
 describe('Arrays', function () {
+  const fixture = async () => {
+    return { mock: await ethers.deployContract('$Arrays') };
+  };
+
+  beforeEach(async function () {
+    Object.assign(this, await loadFixture(fixture));
+  });
+
+  describe('sort', function () {
+    for (const length of [0, 1, 2, 8, 32, 128]) {
+      it(`sort array of length ${length}`, async function () {
+        this.elements = randomArray(generators.uint256, length);
+        this.expected = Array.from(this.elements).sort(compareNumbers);
+      });
+
+      if (length > 1) {
+        it(`sort array of length ${length} (identical elements)`, async function () {
+          this.elements = Array(length).fill(generators.uint256());
+          this.expected = this.elements;
+        });
+
+        it(`sort array of length ${length} (already sorted)`, async function () {
+          this.elements = randomArray(generators.uint256, length).sort(compareNumbers);
+          this.expected = this.elements;
+        });
+
+        it(`sort array of length ${length} (sorted in reverse order)`, async function () {
+          this.elements = randomArray(generators.uint256, length).sort(compareNumbers).reverse();
+          this.expected = Array.from(this.elements).reverse();
+        });
+
+        it(`sort array of length ${length} (almost sorted)`, async function () {
+          this.elements = randomArray(generators.uint256, length).sort(compareNumbers);
+          this.expected = Array.from(this.elements);
+          // rotate (move the last element to the front) for an almost sorted effect
+          this.elements.unshift(this.elements.pop());
+        });
+      }
+    }
+    afterEach(async function () {
+      expect(await this.mock.$sort(this.elements)).to.deep.equal(this.expected);
+    });
+  });
+
   describe('search', function () {
     for (const [title, { array, tests }] of Object.entries({
       'Even number of elements': {
@@ -74,7 +121,7 @@ describe('Arrays', function () {
     })) {
       describe(title, function () {
         const fixture = async () => {
-          return { mock: await ethers.deployContract('Uint256ArraysMock', [array]) };
+          return { instance: await ethers.deployContract('Uint256ArraysMock', [array]) };
         };
 
         beforeEach(async function () {
@@ -86,20 +133,20 @@ describe('Arrays', function () {
             it('[deprecated] findUpperBound', async function () {
               // findUpperBound does not support duplicated
               if (hasDuplicates(array)) {
-                expect(await this.mock.findUpperBound(input)).to.be.equal(upperBound(array, input) - 1);
+                expect(await this.instance.findUpperBound(input)).to.equal(upperBound(array, input) - 1);
               } else {
-                expect(await this.mock.findUpperBound(input)).to.be.equal(lowerBound(array, input));
+                expect(await this.instance.findUpperBound(input)).to.equal(lowerBound(array, input));
               }
             });
 
             it('lowerBound', async function () {
-              expect(await this.mock.lowerBound(input)).to.be.equal(lowerBound(array, input));
-              expect(await this.mock.lowerBoundMemory(array, input)).to.be.equal(lowerBound(array, input));
+              expect(await this.instance.lowerBound(input)).to.equal(lowerBound(array, input));
+              expect(await this.instance.lowerBoundMemory(array, input)).to.equal(lowerBound(array, input));
             });
 
             it('upperBound', async function () {
-              expect(await this.mock.upperBound(input)).to.be.equal(upperBound(array, input));
-              expect(await this.mock.upperBoundMemory(array, input)).to.be.equal(upperBound(array, input));
+              expect(await this.instance.upperBound(input)).to.equal(upperBound(array, input));
+              expect(await this.instance.upperBoundMemory(array, input)).to.equal(upperBound(array, input));
             });
           });
         }
@@ -108,28 +155,44 @@ describe('Arrays', function () {
   });
 
   describe('unsafeAccess', function () {
-    for (const [title, { artifact, elements }] of Object.entries({
+    for (const [type, { artifact, elements }] of Object.entries({
       address: { artifact: 'AddressArraysMock', elements: randomArray(generators.address, 10) },
       bytes32: { artifact: 'Bytes32ArraysMock', elements: randomArray(generators.bytes32, 10) },
       uint256: { artifact: 'Uint256ArraysMock', elements: randomArray(generators.uint256, 10) },
     })) {
-      describe(title, function () {
-        const fixture = async () => {
-          return { mock: await ethers.deployContract(artifact, [elements]) };
-        };
+      describe(type, function () {
+        describe('storage', function () {
+          const fixture = async () => {
+            return { instance: await ethers.deployContract(artifact, [elements]) };
+          };
 
-        beforeEach(async function () {
-          Object.assign(this, await loadFixture(fixture));
-        });
+          beforeEach(async function () {
+            Object.assign(this, await loadFixture(fixture));
+          });
 
-        for (const i in elements) {
-          it(`unsafeAccess within bounds #${i}`, async function () {
-            expect(await this.mock.unsafeAccess(i)).to.equal(elements[i]);
+          for (const i in elements) {
+            it(`unsafeAccess within bounds #${i}`, async function () {
+              expect(await this.instance.unsafeAccess(i)).to.equal(elements[i]);
+            });
+          }
+
+          it('unsafeAccess outside bounds', async function () {
+            await expect(this.instance.unsafeAccess(elements.length)).to.not.be.rejected;
           });
-        }
+        });
+
+        describe('memory', function () {
+          const fragment = `$unsafeMemoryAccess(${type}[] arr, uint256 pos)`;
 
-        it('unsafeAccess outside bounds', async function () {
-          await expect(this.mock.unsafeAccess(elements.length)).to.not.be.rejected;
+          for (const i in elements) {
+            it(`unsafeMemoryAccess within bounds #${i}`, async function () {
+              expect(await this.mock[fragment](elements, i)).to.equal(elements[i]);
+            });
+          }
+
+          it('unsafeMemoryAccess outside bounds', async function () {
+            await expect(this.mock[fragment](elements, elements.length)).to.not.be.rejected;
+          });
         });
       });
     }

+ 0 - 1
test/utils/Base64.t.sol

@@ -3,7 +3,6 @@
 pragma solidity ^0.8.20;
 
 import {Test} from "forge-std/Test.sol";
-
 import {Base64} from "@openzeppelin/contracts/utils/Base64.sol";
 
 contract Base64Test is Test {

+ 2 - 2
test/utils/structs/Checkpoints.t.sol

@@ -4,8 +4,8 @@
 pragma solidity ^0.8.20;
 
 import {Test} from "forge-std/Test.sol";
-import {SafeCast} from "../../../contracts/utils/math/SafeCast.sol";
-import {Checkpoints} from "../../../contracts/utils/structs/Checkpoints.sol";
+import {SafeCast} from "@openzeppelin/contracts/utils/math/SafeCast.sol";
+import {Checkpoints} from "@openzeppelin/contracts/utils/structs/Checkpoints.sol";
 
 contract CheckpointsTrace224Test is Test {
     using Checkpoints for Checkpoints.Trace224;