Browse Source

Add new `clz(bytes)` and `clz(uint256)` functions (#5725)

Co-authored-by: Hadrien Croubois <hadrien.croubois@gmail.com>
Ernesto García 2 months ago
parent
commit
d66a0ce63f

+ 5 - 0
.changeset/fast-beans-pull.md

@@ -0,0 +1,5 @@
+---
+'openzeppelin-solidity': minor
+---
+
+`Bytes`: Add a `clz` function to count the leading zero bits in a `bytes` buffer.

+ 5 - 0
.changeset/whole-cats-find.md

@@ -0,0 +1,5 @@
+---
+'openzeppelin-solidity': minor
+---
+
+`Math`: Add a `clz` function to count the leading zero bits in a `uint256` value.

+ 15 - 1
contracts/utils/Bytes.sol

@@ -128,7 +128,7 @@ library Bytes {
         return buffer;
     }
 
-    /*
+    /**
      * @dev Returns true if the two byte buffers are equal.
      */
     function equal(bytes memory a, bytes memory b) internal pure returns (bool) {
@@ -187,6 +187,20 @@ library Bytes {
         return (value >> 8) | (value << 8);
     }
 
+    /**
+     * @dev Counts the number of leading zero bits in a bytes array. Returns `8 * buffer.length`
+     * if the buffer is all zeros; the result never exceeds that value.
+     */
+    function clz(bytes memory buffer) internal pure returns (uint256) {
+        for (uint256 i = 0; i < buffer.length; i += 32) {
+            bytes32 chunk = _unsafeReadBytesOffset(buffer, i); // unchecked 32-byte read; may extend past the buffer end
+            if (chunk != bytes32(0)) {
+                return Math.min(8 * i + Math.clz(uint256(chunk)), 8 * buffer.length);
+            }
+        }
+        return 8 * buffer.length;
+    }
+
     /**
      * @dev Reads a bytes32 from a bytes array without bounds checking.
      *

+ 7 - 0
contracts/utils/math/Math.sol

@@ -746,4 +746,11 @@ library Math {
     function unsignedRoundsUp(Rounding rounding) internal pure returns (bool) {
         return uint8(rounding) % 2 == 1;
     }
+
+    /**
+     * @dev Counts the number of leading zero bits in a uint256. Returns 256 if `x` is zero.
+     */
+    function clz(uint256 x) internal pure returns (uint256) {
+        return ternary(x == 0, 256, 255 - log2(x));
+    }
 }

+ 28 - 0
test/utils/Bytes.t.sol

@@ -196,6 +196,34 @@ contract BytesTest is Test {
         assertEq(Bytes.reverseBytes2(_dirtyBytes2(Bytes.reverseBytes2(value))), value);
     }
 
+    // CLZ (Count Leading Zeros)
+    function testClz(bytes memory buffer) public pure {
+        uint256 result = Bytes.clz(buffer);
+
+        // byte index and in-byte bit offset (counted from the MSB) of the first non zero bit
+        uint256 index = result / 8;
+        uint256 offset = result % 8;
+
+        // Byte index of the first non zero bit should never exceed buffer length
+        assertLe(index, buffer.length);
+
+        // All bytes before index position must be zero
+        for (uint256 i = 0; i < index; ++i) {
+            assertEq(buffer[i], 0);
+        }
+
+        // If index < buffer.length, byte at index position must be non-zero
+        if (index < buffer.length) {
+            // bit at position offset must be non zero
+            bytes1 singleBitMask = bytes1(0x80) >> offset;
+            assertEq(buffer[index] & singleBitMask, singleBitMask);
+
+            // all bits before offset (the `offset` most significant bits of the byte) must be zero
+            bytes1 multiBitsMask = bytes1(0xff) << (8 - offset);
+            assertEq(buffer[index] & multiBitsMask, 0);
+        }
+    }
+
     // Helpers
     function _dirtyBytes16(bytes16 value) private pure returns (bytes16 dirty) {
         assembly ("memory-safe") {

+ 87 - 0
test/utils/Bytes.test.js

@@ -112,6 +112,93 @@ describe('Bytes', function () {
     });
   });
 
+  describe('clz bytes', function () {
+    it('empty buffer', async function () {
+      await expect(this.mock.$clz('0x')).to.eventually.equal(0);
+    });
+
+    it('single zero byte', async function () {
+      await expect(this.mock.$clz('0x00')).to.eventually.equal(8);
+    });
+
+    it('single non-zero byte', async function () {
+      await expect(this.mock.$clz('0x01')).to.eventually.equal(7);
+      await expect(this.mock.$clz('0xff')).to.eventually.equal(0);
+    });
+
+    it('multiple leading zeros', async function () {
+      await expect(this.mock.$clz('0x0000000001')).to.eventually.equal(39);
+      await expect(
+        this.mock.$clz('0x0000000000000000000000000000000000000000000000000000000000000001'),
+      ).to.eventually.equal(255);
+    });
+
+    it('all zeros of various lengths', async function () {
+      await expect(this.mock.$clz('0x00000000')).to.eventually.equal(32);
+      await expect(
+        this.mock.$clz('0x0000000000000000000000000000000000000000000000000000000000000000'),
+      ).to.eventually.equal(256);
+
+      // Complete chunks. NOTE(review): from here on most buffers are not all zeros, despite the test title
+      await expect(this.mock.$clz('0x' + '00'.repeat(32) + '01')).to.eventually.equal(263); // 32*8+7
+      await expect(this.mock.$clz('0x' + '00'.repeat(64) + '01')).to.eventually.equal(519); // 64*8+7
+
+      // Partial last chunk
+      await expect(this.mock.$clz('0x' + '00'.repeat(33) + '01')).to.eventually.equal(271); // 33*8+7
+      await expect(this.mock.$clz('0x' + '00'.repeat(34) + '01')).to.eventually.equal(279); // 34*8+7
+      await expect(this.mock.$clz('0x' + '00'.repeat(40) + '01' + '00'.repeat(9))).to.eventually.equal(327); // 40*8+7
+      await expect(this.mock.$clz('0x' + '00'.repeat(50))).to.eventually.equal(400); // 50*8
+
+      // First byte of each chunk non-zero
+      await expect(this.mock.$clz('0x80' + '00'.repeat(31))).to.eventually.equal(0);
+      await expect(this.mock.$clz('0x01' + '00'.repeat(31))).to.eventually.equal(7);
+      await expect(this.mock.$clz('0x' + '00'.repeat(32) + '80' + '00'.repeat(31))).to.eventually.equal(256); // 32*8
+      await expect(this.mock.$clz('0x' + '00'.repeat(32) + '01' + '00'.repeat(31))).to.eventually.equal(263); // 32*8+7
+
+      // Last byte of each chunk non-zero
+      await expect(this.mock.$clz('0x' + '00'.repeat(31) + '01')).to.eventually.equal(255); // 31*8+7
+      await expect(this.mock.$clz('0x' + '00'.repeat(63) + '01')).to.eventually.equal(511); // 63*8+7
+
+      // Non-zero byte at multiples of 16: mid-chunk (16, 48) and first byte of a chunk (32, 64)
+      await expect(this.mock.$clz('0x' + '00'.repeat(16) + '01' + '00'.repeat(15))).to.eventually.equal(135); // 16*8+7
+      await expect(this.mock.$clz('0x' + '00'.repeat(32) + '01' + '00'.repeat(31))).to.eventually.equal(263); // 32*8+7
+      await expect(this.mock.$clz('0x' + '00'.repeat(48) + '01' + '00'.repeat(47))).to.eventually.equal(391); // 48*8+7
+      await expect(this.mock.$clz('0x' + '00'.repeat(64) + '01' + '00'.repeat(63))).to.eventually.equal(519); // 64*8+7
+    });
+  });
+
+  describe('equal', function () {
+    it('identical buffers', async function () {
+      await expect(this.mock.$equal(lorem, lorem)).to.eventually.be.true;
+    });
+
+    it('same content', async function () {
+      const copy = new Uint8Array(lorem); // distinct Uint8Array object built from lorem
+      await expect(this.mock.$equal(lorem, copy)).to.eventually.be.true;
+    });
+
+    it('different content', async function () {
+      const different = ethers.toUtf8Bytes('Different content');
+      await expect(this.mock.$equal(lorem, different)).to.eventually.be.false;
+    });
+
+    it('different lengths', async function () {
+      const shorter = lorem.slice(0, 10); // at most the first 10 entries of lorem
+      await expect(this.mock.$equal(lorem, shorter)).to.eventually.be.false;
+    });
+
+    it('empty buffers', async function () {
+      const empty1 = new Uint8Array(0);
+      const empty2 = new Uint8Array(0);
+      await expect(this.mock.$equal(empty1, empty2)).to.eventually.be.true;
+    });
+
+    it('one empty one not', async function () {
+      const empty = new Uint8Array(0);
+      await expect(this.mock.$equal(lorem, empty)).to.eventually.be.false;
+    });
+  });
+
   describe('reverseBits', function () {
     describe('reverseBytes32', function () {
       it('reverses bytes correctly', async function () {

+ 19 - 0
test/utils/math/Math.t.sol

@@ -308,6 +308,25 @@ contract MathTest is Test {
         }
     }
 
+    function testSymbolicCountLeadingZeroes(uint256 x) public pure {
+        uint256 result = Math.clz(x);
+
+        if (x == 0) {
+            assertEq(result, 256);
+        } else {
+            // result in [0, 255]
+            assertLe(result, 255);
+
+            // bit at position (255 - result), counted from the LSB, must be set
+            uint256 singleBitMask = uint256(1) << (255 - result);
+            assertEq(x & singleBitMask, singleBitMask);
+
+            // the `result` most significant bits (everything above that position) must be zero
+            uint256 multiBitsMask = type(uint256).max << (256 - result);
+            assertEq(x & multiBitsMask, 0);
+        }
+    }
+
     // Helpers
     function _asRounding(uint8 r) private pure returns (Math.Rounding) {
         vm.assume(r < uint8(type(Math.Rounding).max));

+ 33 - 0
test/utils/math/Math.test.js

@@ -710,4 +710,37 @@ describe('Math', function () {
       });
     });
   });
+
+  describe('clz', function () {
+    it('zero value', async function () {
+      await expect(this.mock.$clz(0)).to.eventually.equal(256);
+    });
+
+    it('small values', async function () {
+      await expect(this.mock.$clz(1)).to.eventually.equal(255); // 256 - 1 (bit length of the value)
+      await expect(this.mock.$clz(255)).to.eventually.equal(248); // 256 - 8
+    });
+
+    it('larger values', async function () {
+      await expect(this.mock.$clz(256)).to.eventually.equal(247); // 256 - 9
+      await expect(this.mock.$clz(0xff00)).to.eventually.equal(240); // 256 - 16
+      await expect(this.mock.$clz(0x10000)).to.eventually.equal(239); // 256 - 17
+    });
+
+    it('max value', async function () {
+      await expect(this.mock.$clz(ethers.MaxUint256)).to.eventually.equal(0);
+    });
+
+    it('specific patterns', async function () {
+      await expect(
+        this.mock.$clz('0x0000000000000000000000000000000000000000000000000000000000000100'),
+      ).to.eventually.equal(247); // 256 - 9
+      await expect(
+        this.mock.$clz('0x0000000000000000000000000000000000000000000000000000000000010000'),
+      ).to.eventually.equal(239); // 256 - 17
+      await expect(
+        this.mock.$clz('0x0000000000000000000000000000000000000000000000000000000001000000'),
+      ).to.eventually.equal(231); // 256 - 25
+    });
+  });
 });