Przeglądaj źródła

Clean dirty addresses and booleans (#5195)

Co-authored-by: Hadrien Croubois <hadrien.croubois@gmail.com>
Signed-off-by: Hadrien Croubois <hadrien.croubois@gmail.com>
cairo 1 rok temu
rodzic
commit
b1d61079d6

+ 2 - 2
contracts/utils/SlotDerivation.sol

@@ -70,7 +70,7 @@ library SlotDerivation {
      */
     function deriveMapping(bytes32 slot, address key) internal pure returns (bytes32 result) {
         assembly ("memory-safe") {
-            mstore(0x00, key)
+            mstore(0x00, and(key, shr(96, not(0))))
             mstore(0x20, slot)
             result := keccak256(0x00, 0x40)
         }
@@ -81,7 +81,7 @@ library SlotDerivation {
      */
     function deriveMapping(bytes32 slot, bool key) internal pure returns (bytes32 result) {
         assembly ("memory-safe") {
-            mstore(0x00, key)
+            mstore(0x00, iszero(iszero(key)))
             mstore(0x20, slot)
             result := keccak256(0x00, 0x40)
         }

+ 5 - 0
scripts/generate/helpers/sanitize.js

@@ -0,0 +1,5 @@
+module.exports = {
+  address: expr => `and(${expr}, shr(96, not(0)))`,
+  bool: expr => `iszero(iszero(${expr}))`,
+  bytes: (expr, size) => `and(${expr}, shl(${256 - 8 * size}, not(0)))`,
+};

+ 5 - 4
scripts/generate/templates/Packing.js

@@ -1,4 +1,5 @@
 const format = require('../format-lines');
+const sanitize = require('../helpers/sanitize');
 const { product } = require('../../helpers');
 const { SIZES } = require('./Packing.opts');
 
@@ -44,8 +45,8 @@ function pack_${left}_${right}(bytes${left} left, bytes${right} right) internal
   left + right
 } result) {
     assembly ("memory-safe") {
-        left := and(left, shl(${256 - 8 * left}, not(0)))
-        right := and(right, shl(${256 - 8 * right}, not(0)))
+        left := ${sanitize.bytes('left', left)}
+        right := ${sanitize.bytes('right', right)}
         result := or(left, shr(${8 * left}, right))
     }
 }
@@ -55,7 +56,7 @@ const extract = (outer, inner) => `\
 function extract_${outer}_${inner}(bytes${outer} self, uint8 offset) internal pure returns (bytes${inner} result) {
     if (offset > ${outer - inner}) revert OutOfRangeAccess();
     assembly ("memory-safe") {
-        result := and(shl(mul(8, offset), self), shl(${256 - 8 * inner}, not(0)))
+        result := ${sanitize.bytes('shl(mul(8, offset), self)', inner)}
     }
 }
 `;
@@ -64,7 +65,7 @@ const replace = (outer, inner) => `\
 function replace_${outer}_${inner}(bytes${outer} self, bytes${inner} value, uint8 offset) internal pure returns (bytes${outer} result) {
     bytes${inner} oldValue = extract_${outer}_${inner}(self, offset);
     assembly ("memory-safe") {
-        value := and(value, shl(${256 - 8 * inner}, not(0)))
+        value := ${sanitize.bytes('value', inner)}
         result := xor(self, shr(mul(8, offset), xor(oldValue, value)))
     }
 }

+ 2 - 0
scripts/generate/templates/Slot.opts.js

@@ -10,4 +10,6 @@ const TYPES = [
   { type: 'bytes', isValueType: false },
 ].map(type => Object.assign(type, { name: type.name ?? capitalize(type.type) }));
 
+Object.assign(TYPES, Object.fromEntries(TYPES.map(entry => [entry.type, entry])));
+
 module.exports = { TYPES };

+ 2 - 1
scripts/generate/templates/SlotDerivation.js

@@ -1,4 +1,5 @@
 const format = require('../format-lines');
+const sanitize = require('../helpers/sanitize');
 const { TYPES } = require('./Slot.opts');
 
 const header = `\
@@ -77,7 +78,7 @@ const mapping = ({ type }) => `\
  */
 function deriveMapping(bytes32 slot, ${type} key) internal pure returns (bytes32 result) {
     assembly ("memory-safe") {
-        mstore(0x00, key)
+        mstore(0x00, ${(sanitize[type] ?? (x => x))('key')})
         mstore(0x20, slot)
         result := keccak256(0x00, 0x40)
     }

+ 14 - 0
scripts/generate/templates/SlotDerivation.t.js

@@ -61,6 +61,18 @@ function testSymbolicDeriveMapping${name}(${type} key) public {
 }
 `;
 
+const mappingDirty = ({ type, name }) => `\
+function testSymbolicDeriveMapping${name}Dirty(bytes32 dirtyKey) public {
+    ${type} key;
+    assembly {
+        key := dirtyKey
+    }
+
+    // run the "normal" test using a potentially dirty value
+    testSymbolicDeriveMapping${name}(key);
+}
+`;
+
 const boundedMapping = ({ type, name }) => `\
 mapping(${type} => bytes) private _${type}Mapping;
 
@@ -107,6 +119,8 @@ module.exports = format(
           })),
         ),
       ).map(type => (type.isValueType ? mapping(type) : boundedMapping(type))),
+      mappingDirty(TYPES.bool),
+      mappingDirty(TYPES.address),
     ),
   ).trimEnd(),
   '}',

+ 20 - 0
test/utils/SlotDerivation.t.sol

@@ -225,4 +225,24 @@ contract SlotDerivationTest is Test, SymTest {
 
         assertEq(baseSlot.deriveMapping(key), derivedSlot);
     }
+
+    function testSymbolicDeriveMappingBooleanDirty(bytes32 dirtyKey) public {
+        bool key;
+        assembly {
+            key := dirtyKey
+        }
+
+        // run the "normal" test using a potentially dirty value
+        testSymbolicDeriveMappingBoolean(key);
+    }
+
+    function testSymbolicDeriveMappingAddressDirty(bytes32 dirtyKey) public {
+        address key;
+        assembly {
+            key := dirtyKey
+        }
+
+        // run the "normal" test using a potentially dirty value
+        testSymbolicDeriveMappingAddress(key);
+    }
 }