Эх сурвалжийг харах

Add support for more types in Arrays.sol (#5568)

Hadrien Croubois 6 сар өмнө
parent
commit
8a4eadea51

+ 5 - 0
.changeset/rare-shirts-unite.md

@@ -0,0 +1,5 @@
+---
+'openzeppelin-solidity': minor
+---
+
+`Arrays`: Add `unsafeAccess`, `unsafeMemoryAccess` and `unsafeSetLength` for `bytes[]` and `string[]`.

+ 44 - 0
contracts/mocks/ArraysMock.sol

@@ -125,3 +125,47 @@ contract Bytes32ArraysMock {
         return _array.length;
     }
 }
+
+contract BytesArraysMock {
+    using Arrays for bytes[];
+
+    bytes[] private _array;
+
+    constructor(bytes[] memory array) {
+        _array = array;
+    }
+
+    function unsafeAccess(uint256 pos) external view returns (bytes memory) {
+        return _array.unsafeAccess(pos).value;
+    }
+
+    function unsafeSetLength(uint256 newLength) external {
+        _array.unsafeSetLength(newLength);
+    }
+
+    function length() external view returns (uint256) {
+        return _array.length;
+    }
+}
+
+contract StringArraysMock {
+    using Arrays for string[];
+
+    string[] private _array;
+
+    constructor(string[] memory array) {
+        _array = array;
+    }
+
+    function unsafeAccess(uint256 pos) external view returns (string memory) {
+        return _array.unsafeAccess(pos).value;
+    }
+
+    function unsafeSetLength(uint256 newLength) external {
+        _array.unsafeSetLength(newLength);
+    }
+
+    function length() external view returns (uint256) {
+        return _array.length;
+    }
+}

+ 70 - 0
contracts/utils/Arrays.sol

@@ -414,6 +414,32 @@ library Arrays {
         return slot.deriveArray().offset(pos).getUint256Slot();
     }
 
+    /**
+     * @dev Access an array in an "unsafe" way. Skips solidity "index-out-of-range" check.
+     *
+     * WARNING: Only use if you are certain `pos` is lower than the array length.
+     */
+    function unsafeAccess(bytes[] storage arr, uint256 pos) internal pure returns (StorageSlot.BytesSlot storage) {
+        bytes32 slot;
+        assembly ("memory-safe") {
+            slot := arr.slot
+        }
+        return slot.deriveArray().offset(pos).getBytesSlot();
+    }
+
+    /**
+     * @dev Access an array in an "unsafe" way. Skips solidity "index-out-of-range" check.
+     *
+     * WARNING: Only use if you are certain `pos` is lower than the array length.
+     */
+    function unsafeAccess(string[] storage arr, uint256 pos) internal pure returns (StorageSlot.StringSlot storage) {
+        bytes32 slot;
+        assembly ("memory-safe") {
+            slot := arr.slot
+        }
+        return slot.deriveArray().offset(pos).getStringSlot();
+    }
+
     /**
      * @dev Access an array in an "unsafe" way. Skips solidity "index-out-of-range" check.
      *
@@ -447,6 +473,28 @@ library Arrays {
         }
     }
 
+    /**
+     * @dev Access an array in an "unsafe" way. Skips solidity "index-out-of-range" check.
+     *
+     * WARNING: Only use if you are certain `pos` is lower than the array length.
+     */
+    function unsafeMemoryAccess(bytes[] memory arr, uint256 pos) internal pure returns (bytes memory res) {
+        assembly {
+            res := mload(add(add(arr, 0x20), mul(pos, 0x20)))
+        }
+    }
+
+    /**
+     * @dev Access an array in an "unsafe" way. Skips solidity "index-out-of-range" check.
+     *
+     * WARNING: Only use if you are certain `pos` is lower than the array length.
+     */
+    function unsafeMemoryAccess(string[] memory arr, uint256 pos) internal pure returns (string memory res) {
+        assembly {
+            res := mload(add(add(arr, 0x20), mul(pos, 0x20)))
+        }
+    }
+
     /**
      * @dev Helper to set the length of a dynamic array. Directly writing to `.length` is forbidden.
      *
@@ -479,4 +527,26 @@ library Arrays {
             sstore(array.slot, len)
         }
     }
+
+    /**
+     * @dev Helper to set the length of a dynamic array. Directly writing to `.length` is forbidden.
+     *
+     * WARNING: this does not clear elements if length is reduced, of initialize elements if length is increased.
+     */
+    function unsafeSetLength(bytes[] storage array, uint256 len) internal {
+        assembly ("memory-safe") {
+            sstore(array.slot, len)
+        }
+    }
+
+    /**
+     * @dev Helper to set the length of a dynamic array. Directly writing to `.length` is forbidden.
+     *
+     * WARNING: this does not clear elements if length is reduced, of initialize elements if length is increased.
+     */
+    function unsafeSetLength(string[] storage array, uint256 len) internal {
+        assembly ("memory-safe") {
+            sstore(array.slot, len)
+        }
+    }
 }

+ 23 - 21
scripts/generate/templates/Arrays.js

@@ -17,7 +17,7 @@ import {Math} from "./math/Math.sol";
 
 const sort = type => `\
 /**
- * @dev Sort an array of ${type} (in memory) following the provided comparator function.
+ * @dev Sort an array of ${type.name} (in memory) following the provided comparator function.
  *
  * This function does the sorting "in place", meaning that it overrides the input. The object is returned for
  * convenience, but that returned value can be discarded safely if the caller has a memory pointer to the array.
@@ -30,11 +30,11 @@ const sort = type => `\
  * IMPORTANT: Consider memory side-effects when using custom comparator functions that access memory in an unsafe way.
  */
 function sort(
-    ${type}[] memory array,
-    function(${type}, ${type}) pure returns (bool) comp
-) internal pure returns (${type}[] memory) {
+    ${type.name}[] memory array,
+    function(${type.name}, ${type.name}) pure returns (bool) comp
+) internal pure returns (${type.name}[] memory) {
     ${
-      type === 'uint256'
+      type.name === 'uint256'
         ? '_quickSort(_begin(array), _end(array), comp);'
         : 'sort(_castToUint256Array(array), _castToUint256Comp(comp));'
     }
@@ -42,10 +42,10 @@ function sort(
 }
 
 /**
- * @dev Variant of {sort} that sorts an array of ${type} in increasing order.
+ * @dev Variant of {sort} that sorts an array of ${type.name} in increasing order.
  */
-function sort(${type}[] memory array) internal pure returns (${type}[] memory) {
-    ${type === 'uint256' ? 'sort(array, Comparators.lt);' : 'sort(_castToUint256Array(array), Comparators.lt);'}
+function sort(${type.name}[] memory array) internal pure returns (${type.name}[] memory) {
+    ${type.name === 'uint256' ? 'sort(array, Comparators.lt);' : 'sort(_castToUint256Array(array), Comparators.lt);'}
     return array;
 }
 `;
@@ -126,8 +126,8 @@ function _swap(uint256 ptr1, uint256 ptr2) private pure {
 `;
 
 const castArray = type => `\
-/// @dev Helper: low level cast ${type} memory array to uint256 memory array
-function _castToUint256Array(${type}[] memory input) private pure returns (uint256[] memory output) {
+/// @dev Helper: low level cast ${type.name} memory array to uint256 memory array
+function _castToUint256Array(${type.name}[] memory input) private pure returns (uint256[] memory output) {
     assembly {
         output := input
     }
@@ -135,9 +135,9 @@ function _castToUint256Array(${type}[] memory input) private pure returns (uint2
 `;
 
 const castComparator = type => `\
-/// @dev Helper: low level cast ${type} comp function to uint256 comp function
+/// @dev Helper: low level cast ${type.name} comp function to uint256 comp function
 function _castToUint256Comp(
-    function(${type}, ${type}) pure returns (bool) input
+    function(${type.name}, ${type.name}) pure returns (bool) input
 ) private pure returns (function(uint256, uint256) pure returns (bool) output) {
     assembly {
         output := input
@@ -320,14 +320,14 @@ const unsafeAccessStorage = type => `\
  *
  * WARNING: Only use if you are certain \`pos\` is lower than the array length.
  */
-function unsafeAccess(${type}[] storage arr, uint256 pos) internal pure returns (StorageSlot.${capitalize(
-  type,
+function unsafeAccess(${type.name}[] storage arr, uint256 pos) internal pure returns (StorageSlot.${capitalize(
+  type.name,
 )}Slot storage) {
     bytes32 slot;
     assembly ("memory-safe") {
         slot := arr.slot
     }
-    return slot.deriveArray().offset(pos).get${capitalize(type)}Slot();
+    return slot.deriveArray().offset(pos).get${capitalize(type.name)}Slot();
 }
 `;
 
@@ -337,7 +337,9 @@ const unsafeAccessMemory = type => `\
  *
  * WARNING: Only use if you are certain \`pos\` is lower than the array length.
  */
-function unsafeMemoryAccess(${type}[] memory arr, uint256 pos) internal pure returns (${type} res) {
+function unsafeMemoryAccess(${type.name}[] memory arr, uint256 pos) internal pure returns (${type.name}${
+  type.isValueType ? '' : ' memory'
+} res) {
     assembly {
         res := mload(add(add(arr, 0x20), mul(pos, 0x20)))
     }
@@ -350,7 +352,7 @@ const unsafeSetLength = type => `\
  *
  * WARNING: this does not clear elements if length is reduced, of initialize elements if length is increased.
  */
-function unsafeSetLength(${type}[] storage array, uint256 len) internal {
+function unsafeSetLength(${type.name}[] storage array, uint256 len) internal {
     assembly ("memory-safe") {
         sstore(array.slot, len)
     }
@@ -367,11 +369,11 @@ module.exports = format(
       'using StorageSlot for bytes32;',
       '',
       // sorting, comparator, helpers and internal
-      sort('uint256'),
-      TYPES.filter(type => type !== 'uint256').map(sort),
+      sort({ name: 'uint256' }),
+      TYPES.filter(type => type.isValueType && type.name !== 'uint256').map(sort),
       quickSort,
-      TYPES.filter(type => type !== 'uint256').map(castArray),
-      TYPES.filter(type => type !== 'uint256').map(castComparator),
+      TYPES.filter(type => type.isValueType && type.name !== 'uint256').map(castArray),
+      TYPES.filter(type => type.isValueType && type.name !== 'uint256').map(castComparator),
       // lookup
       search,
       // unsafe (direct) storage and memory access

+ 7 - 1
scripts/generate/templates/Arrays.opts.js

@@ -1,3 +1,9 @@
-const TYPES = ['address', 'bytes32', 'uint256'];
+const TYPES = [
+  { name: 'address', isValueType: true },
+  { name: 'bytes32', isValueType: true },
+  { name: 'uint256', isValueType: true },
+  { name: 'bytes', isValueType: false },
+  { name: 'string', isValueType: false },
+];
 
 module.exports = { TYPES };

+ 7 - 2
test/helpers/random.js

@@ -5,14 +5,19 @@ const generators = {
   bytes32: () => ethers.hexlify(ethers.randomBytes(32)),
   uint256: () => ethers.toBigInt(ethers.randomBytes(32)),
   int256: () => ethers.toBigInt(ethers.randomBytes(32)) + ethers.MinInt256,
-  hexBytes: length => ethers.hexlify(ethers.randomBytes(length)),
+  bytes: (length = 32) => ethers.hexlify(ethers.randomBytes(length)),
+  string: () => ethers.uuidV4(ethers.randomBytes(32)),
 };
 
 generators.address.zero = ethers.ZeroAddress;
 generators.bytes32.zero = ethers.ZeroHash;
 generators.uint256.zero = 0n;
 generators.int256.zero = 0n;
-generators.hexBytes.zero = '0x';
+generators.bytes.zero = '0x';
+generators.string.zero = '';
+
+// alias hexBytes -> bytes
+generators.hexBytes = generators.bytes;
 
 module.exports = {
   generators,

+ 47 - 43
test/utils/Arrays.test.js

@@ -119,61 +119,63 @@ describe('Arrays', function () {
     }
   });
 
-  for (const type of TYPES) {
-    const elements = Array.from({ length: 10 }, generators[type]);
+  for (const { name, isValueType } of TYPES) {
+    const elements = Array.from({ length: 10 }, generators[name]);
 
-    describe(type, function () {
+    describe(name, function () {
       const fixture = async () => {
-        return { instance: await ethers.deployContract(`${capitalize(type)}ArraysMock`, [elements]) };
+        return { instance: await ethers.deployContract(`${capitalize(name)}ArraysMock`, [elements]) };
       };
 
       beforeEach(async function () {
         Object.assign(this, await loadFixture(fixture));
       });
 
-      describe('sort', function () {
-        for (const length of [0, 1, 2, 8, 32, 128]) {
-          describe(`${type}[] of length ${length}`, function () {
-            beforeEach(async function () {
-              this.array = Array.from({ length }, generators[type]);
-            });
-
-            afterEach(async function () {
-              const expected = Array.from(this.array).sort(comparator);
-              const reversed = Array.from(expected).reverse();
-              expect(await this.instance.sort(this.array)).to.deep.equal(expected);
-              expect(await this.instance.sortReverse(this.array)).to.deep.equal(reversed);
-            });
-
-            it('sort array', async function () {
-              // nothing to do here, beforeEach and afterEach already take care of everything.
-            });
-
-            if (length > 1) {
-              it('sort array for identical elements', async function () {
-                // duplicate the first value to all elements
-                this.array.fill(this.array.at(0));
+      if (isValueType) {
+        describe('sort', function () {
+          for (const length of [0, 1, 2, 8, 32, 128]) {
+            describe(`${name}[] of length ${length}`, function () {
+              beforeEach(async function () {
+                this.array = Array.from({ length }, generators[name]);
               });
 
-              it('sort already sorted array', async function () {
-                // pre-sort the elements
-                this.array.sort(comparator);
+              afterEach(async function () {
+                const expected = Array.from(this.array).sort(comparator);
+                const reversed = Array.from(expected).reverse();
+                expect(await this.instance.sort(this.array)).to.deep.equal(expected);
+                expect(await this.instance.sortReverse(this.array)).to.deep.equal(reversed);
               });
 
-              it('sort reversed array', async function () {
-                // pre-sort in reverse order
-                this.array.sort(comparator).reverse();
+              it('sort array', async function () {
+                // nothing to do here, beforeEach and afterEach already take care of everything.
               });
 
-              it('sort almost sorted array', async function () {
-                // pre-sort + rotate (move the last element to the front) for an almost sorted effect
-                this.array.sort(comparator);
-                this.array.unshift(this.array.pop());
-              });
-            }
-          });
-        }
-      });
+              if (length > 1) {
+                it('sort array for identical elements', async function () {
+                  // duplicate the first value to all elements
+                  this.array.fill(this.array.at(0));
+                });
+
+                it('sort already sorted array', async function () {
+                  // pre-sort the elements
+                  this.array.sort(comparator);
+                });
+
+                it('sort reversed array', async function () {
+                  // pre-sort in reverse order
+                  this.array.sort(comparator).reverse();
+                });
+
+                it('sort almost sorted array', async function () {
+                  // pre-sort + rotate (move the last element to the front) for an almost sorted effect
+                  this.array.sort(comparator);
+                  this.array.unshift(this.array.pop());
+                });
+              }
+            });
+          }
+        });
+      }
 
       describe('unsafeAccess', function () {
         describe('storage', function () {
@@ -197,7 +199,7 @@ describe('Arrays', function () {
         });
 
         describe('memory', function () {
-          const fragment = `$unsafeMemoryAccess(${type}[] arr, uint256 pos)`;
+          const fragment = `$unsafeMemoryAccess(${name}[] arr, uint256 pos)`;
 
           for (const i in elements) {
             it(`unsafeMemoryAccess within bounds #${i}`, async function () {
@@ -211,7 +213,9 @@ describe('Arrays', function () {
 
           it('unsafeMemoryAccess loop around', async function () {
             for (let i = 251n; i < 256n; ++i) {
-              expect(await this.mock[fragment](elements, 2n ** i - 1n)).to.equal(BigInt(elements.length));
+              expect(await this.mock[fragment](elements, 2n ** i - 1n)).to.equal(
+                isValueType ? BigInt(elements.length) : generators[name].zero,
+              );
               expect(await this.mock[fragment](elements, 2n ** i + 0n)).to.equal(elements[0]);
               expect(await this.mock[fragment](elements, 2n ** i + 1n)).to.equal(elements[1]);
             }