浏览代码

fix: expose MerkleTree as a real interface that is accumulator friendly

Reisen 2 年之前
父节点
当前提交
0b7805f285
共有 1 个文件被更改,包括 115 次插入105 次删除
  1. 115 105
      pythnet/pythnet_sdk/src/accumulators/merkle.rs

+ 115 - 105
pythnet/pythnet_sdk/src/accumulators/merkle.rs

@@ -35,120 +35,80 @@ const LEAF_PREFIX: &[u8] = &[0];
 const NODE_PREFIX: &[u8] = &[1];
 const NULL_PREFIX: &[u8] = &[2];
 
-fn hash_leaf<H: Hasher>(leaf: &[u8]) -> H::Hash {
-    H::hashv(&[LEAF_PREFIX, leaf])
-}
-
-fn hash_node<H: Hasher>(l: &H::Hash, r: &H::Hash) -> H::Hash {
-    H::hashv(&[
-        NODE_PREFIX,
-        (if l <= r { l } else { r }).as_ref(),
-        (if l <= r { r } else { l }).as_ref(),
-    ])
-}
-
-fn hash_null<H: Hasher>() -> H::Hash {
-    H::hashv(&[NULL_PREFIX])
-}
-
+/// A MerklePath contains a list of hashes that form a proof for membership in a tree.
 #[derive(Clone, Default, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)]
 pub struct MerklePath<H: Hasher>(Vec<H::Hash>);
 
+/// A MerkleRoot contains the root hash of a MerkleTree.
 #[derive(Clone, Default, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)]
 pub struct MerkleRoot<H: Hasher>(H::Hash);
 
+/// A MerkleTree is a binary tree where each node is the hash of its children.
+#[derive(
+    Debug, Clone, PartialEq, Eq, BorshSerialize, BorshDeserialize, Serialize, Deserialize, Default,
+)]
+pub struct MerkleTree<H: Hasher = Keccak256> {
+    pub root:  MerkleRoot<H>,
+    #[serde(skip)]
+    pub nodes: Vec<H::Hash>,
+}
+
+/// Implements functionality for using standalone MerkleRoots.
 impl<H: Hasher> MerkleRoot<H> {
+    /// Construct a MerkleRoot from an existing Hash.
     pub fn new(root: H::Hash) -> Self {
         Self(root)
     }
 
+    /// Given a item and corresponding MerklePath, check that it is a valid membership proof.
     pub fn check(&self, proof: MerklePath<H>, item: &[u8]) -> bool {
-        let mut current: <H as Hasher>::Hash = hash_leaf::<H>(item);
+        let mut current: <H as Hasher>::Hash = MerkleTree::<H>::hash_leaf(item);
         for hash in proof.0 {
-            current = hash_node::<H>(&current, &hash);
+            current = MerkleTree::<H>::hash_node(&current, &hash);
         }
         current == self.0
     }
 }
 
+/// Implements functionality for working with MerklePath (proofs).
 impl<H: Hasher> MerklePath<H> {
+    /// Given a Vector of hashes representing a merkle proof, construct a MerklePath.
     pub fn new(path: Vec<H::Hash>) -> Self {
         Self(path)
     }
 }
 
-/// A MerkleAccumulator maintains a Merkle Tree.
-///
-/// The implementation is based on Solana's Merkle Tree implementation. This structure also stores
-/// the items that are in the tree due to the need to look-up the index of an item in the tree in
-/// order to create a proof.
-#[derive(
-    Debug, Clone, PartialEq, Eq, BorshSerialize, BorshDeserialize, Serialize, Deserialize, Default,
-)]
-pub struct MerkleAccumulator<H: Hasher = Keccak256> {
-    pub root:  MerkleRoot<H>,
-    #[serde(skip)]
-    pub nodes: Vec<H::Hash>,
-}
-
-// Layout:
-//
-// ```
-// 4 bytes:  magic number
-// 1 byte:   update type
-// 4 byte:   storage id
-// 32 bytes: root hash
-// ```
-//
-// TODO: This code does not belong to MerkleAccumulator, we should be using the wire data types in
-// calling code to wrap this value.
-impl<'a, H: Hasher + 'a> MerkleAccumulator<H> {
-    pub fn serialize(&self, slot: u64, ring_size: u32) -> Vec<u8> {
-        let mut serialized = vec![];
-        serialized.extend_from_slice(0x41555756u32.to_be_bytes().as_ref());
-        serialized.extend_from_slice(0u8.to_be_bytes().as_ref());
-        serialized.extend_from_slice(slot.to_be_bytes().as_ref());
-        serialized.extend_from_slice(ring_size.to_be_bytes().as_ref());
-        serialized.extend_from_slice(self.root.0.as_ref());
-        serialized
-    }
-}
-
-impl<'a, H: Hasher + 'a> Accumulator<'a> for MerkleAccumulator<H> {
+/// Presents an Accumulator friendly interface for MerkleTree.
+impl<'a, H: Hasher + 'a> Accumulator<'a> for MerkleTree<H> {
     type Proof = MerklePath<H>;
 
+    /// Construct a MerkleTree from an iterator of items.
     fn from_set(items: impl Iterator<Item = &'a [u8]>) -> Option<Self> {
         let items: Vec<&[u8]> = items.collect();
         Self::new(&items)
     }
 
+    /// Prove an item is in the tree by returning a MerklePath.
     fn prove(&'a self, item: &[u8]) -> Option<Self::Proof> {
-        let item = hash_leaf::<H>(item);
+        let item = MerkleTree::<H>::hash_leaf(item);
         let index = self.nodes.iter().position(|i| i == &item)?;
         Some(self.find_path(index))
     }
 
-    // NOTE: This `check` call is intended to be generic accross accumulator implementations, but
-    // for a merkle tree the proof does not use the `self` parameter as the proof is standalone
-    // and doesn't need the original nodes. Normally a merkle API would be something like:
-    //
-    // ```
-    // MerkleTree::check(proof)
-    // ```
-    //
-    // or even:
-    //
-    // ```
-    // proof.verify()
-    // ```
-    //
-    // But to stick to the Accumulator trait we do it via the trait method.
+    // NOTE: This `check` call is intended to fit the generic accumulator implementation, but for a
+    // merkle tree the proof does not usually need the `self` parameter as the proof is standalone
+    // and doesn't need the original nodes.
     fn check(&'a self, proof: Self::Proof, item: &[u8]) -> bool {
-        self.root.check(proof, item)
+        self.verify_path(proof, item)
     }
 }
 
-impl<H: Hasher> MerkleAccumulator<H> {
+/// Implement a MerkleTree-specific interface for interacting with trees.
+impl<H: Hasher> MerkleTree<H> {
+    /// Construct a new MerkleTree from a list of byte slices.
+    ///
+    /// This list does not have to be a set which means the tree may contain duplicate items. It is
+    /// up to the caller to enforce a strict set-like object if that is desired.
     pub fn new(items: &[&[u8]]) -> Option<Self> {
         if items.is_empty() {
             return None;
@@ -160,9 +120,9 @@ impl<H: Hasher> MerkleAccumulator<H> {
         // Filling the leaf hashes
         for i in 0..(1 << depth) {
             if i < items.len() {
-                tree[(1 << depth) + i] = hash_leaf::<H>(items[i]);
+                tree[(1 << depth) + i] = MerkleTree::<H>::hash_leaf(items[i]);
             } else {
-                tree[(1 << depth) + i] = hash_null::<H>();
+                tree[(1 << depth) + i] = MerkleTree::<H>::hash_null();
             }
         }
 
@@ -172,7 +132,7 @@ impl<H: Hasher> MerkleAccumulator<H> {
             let level_num_nodes = 1 << level;
             for i in 0..level_num_nodes {
                 let id = (1 << level) + i;
-                tree[id] = hash_node::<H>(&tree[id * 2], &tree[id * 2 + 1]);
+                tree[id] = MerkleTree::<H>::hash_node(&tree[id * 2], &tree[id * 2 + 1]);
             }
         }
 
@@ -182,7 +142,8 @@ impl<H: Hasher> MerkleAccumulator<H> {
         })
     }
 
-    fn find_path(&self, mut index: usize) -> MerklePath<H> {
+    /// Produces a Proof of membership for an index in the tree.
+    pub fn find_path(&self, mut index: usize) -> MerklePath<H> {
         let mut path = Vec::new();
         while index > 1 {
             path.push(self.nodes[index ^ 1]);
@@ -190,6 +151,53 @@ impl<H: Hasher> MerkleAccumulator<H> {
         }
         MerklePath::new(path)
     }
+
+    /// Check if a given MerklePath is a valid proof for a corresponding item.
+    pub fn verify_path(&self, proof: MerklePath<H>, item: &[u8]) -> bool {
+        self.root.check(proof, item)
+    }
+
+    #[inline]
+    pub fn hash_leaf(leaf: &[u8]) -> H::Hash {
+        H::hashv(&[LEAF_PREFIX, leaf])
+    }
+
+    #[inline]
+    pub fn hash_node(l: &H::Hash, r: &H::Hash) -> H::Hash {
+        H::hashv(&[
+            NODE_PREFIX,
+            (if l <= r { l } else { r }).as_ref(),
+            (if l <= r { r } else { l }).as_ref(),
+        ])
+    }
+
+    #[inline]
+    pub fn hash_null() -> H::Hash {
+        H::hashv(&[NULL_PREFIX])
+    }
+
+    /// Serialize a MerkleTree into a Vec<u8>.
+    ///
+    ///Layout:
+    ///
+    /// ```rust,ignore
+    /// 4 bytes:  magic number
+    /// 1 byte:   update type
+    /// 4 byte:   storage id
+    /// 32 bytes: root hash
+    /// ```
+    ///
+    /// TODO: This code does not belong to MerkleTree, we should be using the wire data types in
+    /// calling code to wrap this value.
+    pub fn serialize(&self, slot: u64, ring_size: u32) -> Vec<u8> {
+        let mut serialized = vec![];
+        serialized.extend_from_slice(0x41555756u32.to_be_bytes().as_ref());
+        serialized.extend_from_slice(0u8.to_be_bytes().as_ref());
+        serialized.extend_from_slice(slot.to_be_bytes().as_ref());
+        serialized.extend_from_slice(ring_size.to_be_bytes().as_ref());
+        serialized.extend_from_slice(self.root.0.as_ref());
+        serialized
+    }
 }
 
 #[cfg(test)]
@@ -231,12 +239,12 @@ mod test {
     }
 
     #[derive(Debug)]
-    struct MerkleAccumulatorDataWrapper {
-        pub accumulator: MerkleAccumulator,
+    struct MerkleTreeDataWrapper {
+        pub accumulator: MerkleTree,
         pub data:        BTreeSet<Vec<u8>>,
     }
 
-    impl Arbitrary for MerkleAccumulatorDataWrapper {
+    impl Arbitrary for MerkleTreeDataWrapper {
         type Parameters = usize;
 
         fn arbitrary_with(size: Self::Parameters) -> Self::Strategy {
@@ -248,9 +256,8 @@ mod test {
             .prop_map(|v| {
                 let data: BTreeSet<Vec<u8>> = v.into_iter().collect();
                 let accumulator =
-                    MerkleAccumulator::<Keccak256>::from_set(data.iter().map(|i| i.as_ref()))
-                        .unwrap();
-                MerkleAccumulatorDataWrapper { accumulator, data }
+                    MerkleTree::<Keccak256>::from_set(data.iter().map(|i| i.as_ref())).unwrap();
+                MerkleTreeDataWrapper { accumulator, data }
             })
             .boxed()
         }
@@ -303,14 +310,14 @@ mod test {
         set.insert(&item_b);
         set.insert(&item_c);
 
-        let accumulator = MerkleAccumulator::<Keccak256>::from_set(set.into_iter()).unwrap();
+        let accumulator = MerkleTree::<Keccak256>::from_set(set.into_iter()).unwrap();
         let proof = accumulator.prove(&item_a).unwrap();
 
-        assert!(accumulator.check(proof, &item_a));
+        assert!(accumulator.verify_path(proof, &item_a));
         let proof = accumulator.prove(&item_a).unwrap();
         assert_eq!(size_of::<<Keccak256 as Hasher>::Hash>(), 32);
 
-        assert!(!accumulator.check(proof, &item_d));
+        assert!(!accumulator.verify_path(proof, &item_d));
     }
 
     #[test]
@@ -327,11 +334,11 @@ mod test {
         set.insert(&item_b);
 
         // Attempt to prove empty proofs that are not in the accumulator.
-        let accumulator = MerkleAccumulator::<Keccak256>::from_set(set.into_iter()).unwrap();
+        let accumulator = MerkleTree::<Keccak256>::from_set(set.into_iter()).unwrap();
         let proof = MerklePath::<Keccak256>::default();
-        assert!(!accumulator.check(proof, &item_a));
+        assert!(!accumulator.verify_path(proof, &item_a));
         let proof = MerklePath::<Keccak256>(vec![Default::default()]);
-        assert!(!accumulator.check(proof, &item_a));
+        assert!(!accumulator.verify_path(proof, &item_a));
     }
 
     #[test]
@@ -349,7 +356,7 @@ mod test {
         set.insert(&item_d);
 
         // Accumulate
-        let accumulator = MerkleAccumulator::<Keccak256>::from_set(set.into_iter()).unwrap();
+        let accumulator = MerkleTree::<Keccak256>::from_set(set.into_iter()).unwrap();
 
         // For each hash in the resulting proofs, corrupt one hash and confirm that the proof
         // cannot pass check.
@@ -358,7 +365,7 @@ mod test {
             for (i, _) in proof.0.iter().enumerate() {
                 let mut corrupted_proof = proof.clone();
                 corrupted_proof.0[i] = Default::default();
-                assert!(!accumulator.check(corrupted_proof, item));
+                assert!(!accumulator.verify_path(corrupted_proof, item));
             }
         }
     }
@@ -381,9 +388,9 @@ mod test {
         set.insert(&item_d);
 
         // Accumulate into a 2 level tree.
-        let accumulator = MerkleAccumulator::<Keccak256>::from_set(set.into_iter()).unwrap();
+        let accumulator = MerkleTree::<Keccak256>::from_set(set.into_iter()).unwrap();
         let proof = accumulator.prove(&item_a).unwrap();
-        assert!(accumulator.check(proof, &item_a));
+        assert!(accumulator.verify_path(proof, &item_a));
 
         // We now have a 2 level tree with 4 nodes:
         //
@@ -410,7 +417,7 @@ mod test {
         // implementation did not use a different hash for nodes and leaves then it is possible to
         // falsely prove `A` was in the original tree by tricking the implementation into performing
         // H(a || b) at the leaf.
-        let faulty_accumulator = MerkleAccumulator::<Keccak256> {
+        let faulty_accumulator = MerkleTree::<Keccak256> {
             root:  accumulator.root,
             nodes: vec![
                 accumulator.nodes[0],
@@ -422,30 +429,33 @@ mod test {
 
         // `a || b` is the concatenation of a and b, which when hashed without pre-image fixes in
         // place generates A as a leaf rather than a pair node.
-        let fake_leaf_A = &[
-            hash_leaf::<Keccak256>(&item_b),
-            hash_leaf::<Keccak256>(&item_a),
+        let fake_leaf = &[
+            MerkleTree::<Keccak256>::hash_leaf(&item_b),
+            MerkleTree::<Keccak256>::hash_leaf(&item_a),
         ]
         .concat();
 
         // Confirm our combined hash existed as a node pair in the original tree.
-        assert_eq!(hash_leaf::<Keccak256>(fake_leaf_A), accumulator.nodes[2]);
+        assert_eq!(
+            MerkleTree::<Keccak256>::hash_leaf(fake_leaf),
+            accumulator.nodes[2]
+        );
 
         // Now we can try and prove leaf membership in the faulty accumulator. NOTE: this should
         // fail but to confirm that the test is actually correct you can remove the PREFIXES from
         // the hash functions and this test will erroneously pass.
-        let proof = faulty_accumulator.prove(fake_leaf_A).unwrap();
-        assert!(faulty_accumulator.check(proof, fake_leaf_A));
+        let proof = faulty_accumulator.prove(fake_leaf).unwrap();
+        assert!(faulty_accumulator.verify_path(proof, fake_leaf));
     }
 
     proptest! {
         // Use proptest to generate arbitrary Merkle trees as part of our fuzzing strategy. This
         // will help us identify any edge cases or unexpected behavior in the implementation.
         #[test]
-        fn test_merkle_tree(v in any::<MerkleAccumulatorDataWrapper>()) {
+        fn test_merkle_tree(v in any::<MerkleTreeDataWrapper>()) {
             for d in v.data {
                 let proof = v.accumulator.prove(&d).unwrap();
-                assert!(v.accumulator.check(proof, &d));
+                assert!(v.accumulator.verify_path(proof, &d));
             }
         }
 
@@ -453,7 +463,7 @@ mod test {
         // passes which should not.
         #[test]
         fn test_fake_merkle_proofs(
-            v in any::<MerkleAccumulatorDataWrapper>(),
+            v in any::<MerkleTreeDataWrapper>(),
             p in any::<MerklePath<Keccak256>>(),
         ) {
             // Reject 1-sized trees as they will always pass due to root being the only elements
@@ -463,7 +473,7 @@ mod test {
             }
 
             for d in v.data {
-                assert!(!v.accumulator.check(p.clone(), &d));
+                assert!(!v.accumulator.verify_path(p.clone(), &d));
             }
         }
     }