瀏覽代碼

test(hermes): add tests for store (#985)

Ali Behjati 2 年之前
父節點
當前提交
c1517349f8
共有 6 個文件被更改,包括 410 次插入18 次删除
  1. 10 0
      hermes/Cargo.lock
  2. 1 0
      hermes/Cargo.toml
  3. 383 16
      hermes/src/store.rs
  4. 5 0
      hermes/src/store/storage.rs
  5. 2 0
      hermes/src/store/types.rs
  6. 9 2
      hermes/src/store/wormhole.rs

+ 10 - 0
hermes/Cargo.lock

@@ -1762,6 +1762,7 @@ dependencies = [
  "libc",
  "libp2p",
  "log",
+ "mock_instant",
  "prometheus-client",
  "pyth-sdk",
  "pythnet-sdk",
@@ -2994,6 +2995,15 @@ dependencies = [
  "windows-sys 0.48.0",
 ]
 
+[[package]]
+name = "mock_instant"
+version = "0.3.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "6c1a54de846c4006b88b1516731cc1f6026eb5dc4bcb186aa071ef66d40524ec"
+dependencies = [
+ "once_cell",
+]
+
 [[package]]
 name = "multiaddr"
 version = "0.13.0"

+ 1 - 0
hermes/Cargo.toml

@@ -34,6 +34,7 @@ libp2p                 = { version = "0.42.2", features = [
 ]}
 
 log                    = { version = "0.4.17" }
+mock_instant           = { version = "0.3.1", features = ["sync"] }
 prometheus-client      = { version = "0.21.1" }
 pyth-sdk               = { version = "0.8.0" }
 

+ 383 - 16
hermes/src/store.rs

@@ -1,3 +1,15 @@
+#[cfg(test)]
+use mock_instant::{
+    Instant,
+    SystemTime,
+    UNIX_EPOCH,
+};
+#[cfg(not(test))]
+use std::time::{
+    Instant,
+    SystemTime,
+    UNIX_EPOCH,
+};
 use {
     self::{
         proof::wormhole_merkle::{
@@ -55,20 +67,11 @@ use {
             HashSet,
         },
         sync::Arc,
-        time::{
-            SystemTime,
-            UNIX_EPOCH,
-        },
+        time::Duration,
     },
-    tokio::{
-        sync::{
-            mpsc::Sender,
-            RwLock,
-        },
-        time::{
-            Duration,
-            Instant,
-        },
+    tokio::sync::{
+        mpsc::Sender,
+        RwLock,
     },
     wormhole_sdk::{
         Address,
@@ -83,6 +86,7 @@ pub mod types;
 pub mod wormhole;
 
 const OBSERVED_CACHE_SIZE: usize = 1000;
+const READINESS_STALENESS_THRESHOLD: Duration = Duration::from_secs(30);
 
 pub struct Store {
     /// Storage is a short-lived cache of the state of all the updates
@@ -115,8 +119,11 @@ impl Store {
 
     /// Stores the update data in the store
     pub async fn store_update(&self, update: Update) -> Result<()> {
+        // The slot that the update is originating from. It should be available
+        // in all the updates.
         let slot = match update {
             Update::Vaa(vaa_bytes) => {
+                // FIXME: Move to wormhole.rs
                 let vaa =
                     serde_wormhole::from_slice::<Vaa<&serde_wormhole::RawMessage>>(&vaa_bytes)?;
 
@@ -288,14 +295,374 @@ impl Store {
     }
 
     pub async fn is_ready(&self) -> bool {
-        const STALENESS_THRESHOLD: Duration = Duration::from_secs(30);
-
         let last_completed_update_at = self.last_completed_update_at.read().await;
         match last_completed_update_at.as_ref() {
             Some(last_completed_update_at) => {
-                last_completed_update_at.elapsed() < STALENESS_THRESHOLD
+                last_completed_update_at.elapsed() < READINESS_STALENESS_THRESHOLD
             }
             None => false,
         }
     }
 }
+
+#[cfg(test)]
+mod test {
+    use {
+        super::{
+            types::Slot,
+            *,
+        },
+        futures::future::join_all,
+        mock_instant::MockClock,
+        pythnet_sdk::{
+            accumulators::{
+                merkle::{
+                    MerkleRoot,
+                    MerkleTree,
+                },
+                Accumulator,
+            },
+            hashers::keccak256_160::Keccak160,
+            messages::{
+                Message,
+                PriceFeedMessage,
+            },
+            wire::v1::{
+                AccumulatorUpdateData,
+                Proof,
+                WormholeMerkleRoot,
+            },
+        },
+        rand::seq::SliceRandom,
+        serde_wormhole::RawMessage,
+        tokio::sync::mpsc::Receiver,
+    };
+
+    /// Generate list of updates for the given list of messages at a given slot with given sequence
+    ///
+    /// Sequence in Vaas is used to filter duplicate messages (as by wormhole design there is only
+    /// one message per sequence)
+    pub fn generate_update(messages: Vec<Message>, slot: Slot, sequence: u64) -> Vec<Update> {
+        let mut updates = Vec::new();
+
+        // Accumulator messages
+        let accumulator_messages = AccumulatorMessages {
+            slot,
+            raw_messages: messages
+                .iter()
+                .map(|message| pythnet_sdk::wire::to_vec::<_, byteorder::BE>(message).unwrap())
+                .collect(),
+            magic: [0; 4],
+            ring_size: 100,
+        };
+        updates.push(Update::AccumulatorMessages(accumulator_messages.clone()));
+
+        // Wormhole merkle update
+        let merkle_tree = MerkleTree::<Keccak160>::from_set(
+            accumulator_messages.raw_messages.iter().map(|m| m.as_ref()),
+        )
+        .unwrap();
+
+        let wormhole_message = WormholeMessage::new(WormholePayload::Merkle(WormholeMerkleRoot {
+            slot,
+            ring_size: 100,
+            root: merkle_tree.root.as_bytes().try_into().unwrap(),
+        }));
+
+        let wormhole_message =
+            pythnet_sdk::wire::to_vec::<_, byteorder::BE>(&wormhole_message).unwrap();
+
+        let vaa = Vaa {
+            nonce: 0,
+            version: 0,
+            sequence,
+            timestamp: 0,
+            signatures: vec![],    // We are bypassing signature check now
+            guardian_set_index: 0, // We are bypassing signature check now
+            emitter_chain: Chain::Pythnet,
+            emitter_address: Address(pythnet_sdk::ACCUMULATOR_EMITTER_ADDRESS),
+            consistency_level: 0,
+            payload: serde_wormhole::RawMessage::new(wormhole_message.as_ref()),
+        };
+
+        updates.push(Update::Vaa(serde_wormhole::to_vec(&vaa).unwrap()));
+
+        updates
+    }
+
+    /// Create a dummy price feed base on the given seed for all the fields except
+    /// `publish_time` and `prev_publish_time`. Those are set to the given value.
+    pub fn create_dummy_price_feed_message(
+        seed: u8,
+        publish_time: i64,
+        prev_publish_time: i64,
+    ) -> PriceFeedMessage {
+        PriceFeedMessage {
+            feed_id: [seed; 32],
+            price: seed as _,
+            conf: seed as _,
+            exponent: 0,
+            ema_conf: seed as _,
+            ema_price: seed as _,
+            publish_time,
+            prev_publish_time,
+        }
+    }
+
+    pub async fn setup_store(cache_size: u64) -> (Arc<Store>, Receiver<()>) {
+        let (update_tx, update_rx) = tokio::sync::mpsc::channel(1000);
+        let store = Store::new(update_tx, cache_size);
+
+        // Add an initial guardian set with public key 0
+        store
+            .update_guardian_set(
+                0,
+                GuardianSet {
+                    keys: vec![[0; 20]],
+                },
+            )
+            .await;
+
+        (store, update_rx)
+    }
+
+    pub async fn store_multiple_concurrent_valid_updates(store: Arc<Store>, updates: Vec<Update>) {
+        let res = join_all(updates.into_iter().map(|u| store.store_update(u))).await;
+        // Check that all store_update calls succeeded
+        assert!(res.into_iter().all(|r| r.is_ok()));
+    }
+
+    #[tokio::test]
+    pub async fn test_store_works() {
+        let (store, mut update_rx) = setup_store(10).await;
+
+        let price_feed_message = create_dummy_price_feed_message(100, 10, 9);
+
+        // Populate the store
+        store_multiple_concurrent_valid_updates(
+            store.clone(),
+            generate_update(vec![Message::PriceFeedMessage(price_feed_message)], 10, 20),
+        )
+        .await;
+
+        // Check that the update_rx channel has received a message
+        assert_eq!(update_rx.recv().await, Some(()));
+
+        // Check the price ids are stored correctly
+        assert_eq!(
+            store.get_price_feed_ids().await,
+            vec![PriceIdentifier::new([100; 32])].into_iter().collect()
+        );
+
+        // Check get_price_feeds_with_update_data retrieves the correct
+        // price feed with correct update data.
+        let price_feeds_with_update_data = store
+            .get_price_feeds_with_update_data(
+                vec![PriceIdentifier::new([100; 32])],
+                RequestTime::Latest,
+            )
+            .await
+            .unwrap();
+
+        assert_eq!(
+            price_feeds_with_update_data.price_feeds,
+            vec![PriceFeedUpdate {
+                price_feed:                  price_feed_message,
+                slot:                        10,
+                received_at:                 price_feeds_with_update_data.price_feeds[0]
+                    .received_at, // Ignore checking this field.
+                wormhole_merkle_update_data: price_feeds_with_update_data.price_feeds[0]
+                    .wormhole_merkle_update_data
+                    .clone(), // Ignore checking this field.
+            }]
+        );
+
+        // Check the update data is correct.
+        assert_eq!(
+            price_feeds_with_update_data
+                .wormhole_merkle_update_data
+                .len(),
+            1
+        );
+        let update_data = price_feeds_with_update_data
+            .wormhole_merkle_update_data
+            .get(0)
+            .unwrap();
+        let update_data = AccumulatorUpdateData::try_from_slice(update_data.as_ref()).unwrap();
+        match update_data.proof {
+            Proof::WormholeMerkle { vaa, updates } => {
+                // Check the vaa and get the root
+                let vaa: Vec<u8> = vaa.into();
+                let vaa: Vaa<&RawMessage> = serde_wormhole::from_slice(vaa.as_ref()).unwrap();
+                assert_eq!(
+                    vaa,
+                    Vaa {
+                        nonce:              0,
+                        version:            0,
+                        sequence:           20,
+                        timestamp:          0,
+                        signatures:         vec![],
+                        guardian_set_index: 0,
+                        emitter_chain:      Chain::Pythnet,
+                        emitter_address:    Address(pythnet_sdk::ACCUMULATOR_EMITTER_ADDRESS),
+                        consistency_level:  0,
+                        payload:            vaa.payload, // Ignore checking this field.
+                    }
+                );
+                let merkle_root = WormholeMessage::try_from_bytes(vaa.payload.as_ref()).unwrap();
+                let WormholePayload::Merkle(merkle_root) = merkle_root.payload;
+                assert_eq!(
+                    merkle_root,
+                    WormholeMerkleRoot {
+                        slot:      10,
+                        ring_size: 100,
+                        root:      merkle_root.root, // Ignore checking this field.
+                    }
+                );
+
+                // Check the updates
+                assert_eq!(updates.len(), 1);
+                let update = updates.get(0).unwrap();
+                let message: Vec<u8> = update.message.clone().into();
+                // Check the serialized message is the price feed message generated above.
+                assert_eq!(
+                    pythnet_sdk::wire::from_slice::<byteorder::BE, Message>(message.as_ref())
+                        .unwrap(),
+                    Message::PriceFeedMessage(price_feed_message)
+                );
+
+                // Check the proof is correct with the Vaa root
+                let merkle_root = MerkleRoot::<Keccak160>::new(merkle_root.root);
+                assert!(merkle_root.check(update.proof.clone(), message.as_ref()));
+            }
+        }
+    }
+
+    #[tokio::test]
+    pub async fn test_metadata_times_and_readiness_work() {
+        // The receiver channel should stay open for the store to work
+        // properly. That is why we don't use _ here as it drops the channel
+        // immediately.
+        let (store, _receiver_tx) = setup_store(10).await;
+
+        let price_feed_message = create_dummy_price_feed_message(100, 10, 9);
+
+        // Advance the clock
+        MockClock::advance_system_time(Duration::from_secs(1));
+        MockClock::advance(Duration::from_secs(1));
+
+        // Get the current unix timestamp. It is mocked using
+        // mock-instance module. So it should remain the same
+        // on the next call.
+        let unix_timestamp = SystemTime::now()
+            .duration_since(UNIX_EPOCH)
+            .unwrap()
+            .as_secs();
+
+        // Populate the store
+        store_multiple_concurrent_valid_updates(
+            store.clone(),
+            generate_update(vec![Message::PriceFeedMessage(price_feed_message)], 10, 20),
+        )
+        .await;
+
+        // Advance the clock again
+        MockClock::advance_system_time(Duration::from_secs(1));
+        MockClock::advance(Duration::from_secs(1));
+
+        // Get the price feeds with update data
+        let price_feeds_with_update_data = store
+            .get_price_feeds_with_update_data(
+                vec![PriceIdentifier::new([100; 32])],
+                RequestTime::Latest,
+            )
+            .await
+            .unwrap();
+
+        // check received_at is correct
+        assert_eq!(price_feeds_with_update_data.price_feeds.len(), 1);
+        assert_eq!(
+            price_feeds_with_update_data.price_feeds[0].received_at,
+            unix_timestamp as i64
+        );
+
+        // Check the store is ready
+        assert!(store.is_ready().await);
+
+        // Advance the clock to make the prices stale
+        MockClock::advance_system_time(READINESS_STALENESS_THRESHOLD);
+        MockClock::advance(READINESS_STALENESS_THRESHOLD);
+        // Check the store is not ready
+        assert!(!store.is_ready().await);
+    }
+
+    /// Test that the store retains the latest slots upon cache eviction.
+    ///
+    /// Store is set up with cache size of 100 and 1000 slot updates will
+    /// be stored all at the same time with random order.
+    /// After the cache eviction, the store should retain the latest 100
+    /// slots regardless of the order.
+    #[tokio::test]
+    pub async fn test_store_retains_latest_slots_upon_cache_eviction() {
+        // The receiver channel should stay open for the store to work
+        // properly. That is why we don't use _ here as it drops the channel
+        // immediately.
+        let (store, _receiver_tx) = setup_store(100).await;
+
+        let mut updates: Vec<Update> = (0..1000)
+            .flat_map(|slot| {
+                let messages = vec![
+                    Message::PriceFeedMessage(create_dummy_price_feed_message(
+                        100,
+                        slot as i64,
+                        slot as i64,
+                    )),
+                    Message::PriceFeedMessage(create_dummy_price_feed_message(
+                        200,
+                        slot as i64,
+                        slot as i64,
+                    )),
+                ];
+                generate_update(messages, slot, slot)
+            })
+            .collect();
+
+        // Shuffle the updates
+        let mut rng = rand::thread_rng();
+        updates.shuffle(&mut rng);
+
+        // Store the updates
+        store_multiple_concurrent_valid_updates(store.clone(), updates).await;
+
+        // Check the last 100 slots are retained
+        for slot in 900..1000 {
+            let price_feeds_with_update_data = store
+                .get_price_feeds_with_update_data(
+                    vec![
+                        PriceIdentifier::new([100; 32]),
+                        PriceIdentifier::new([200; 32]),
+                    ],
+                    RequestTime::FirstAfter(slot as i64),
+                )
+                .await
+                .unwrap();
+            assert_eq!(price_feeds_with_update_data.price_feeds.len(), 2);
+            assert_eq!(price_feeds_with_update_data.price_feeds[0].slot, slot);
+            assert_eq!(price_feeds_with_update_data.price_feeds[1].slot, slot);
+        }
+
+        // Check nothing else is retained
+        for slot in 0..900 {
+            assert!(store
+                .get_price_feeds_with_update_data(
+                    vec![
+                        PriceIdentifier::new([100; 32]),
+                        PriceIdentifier::new([200; 32]),
+                    ],
+                    RequestTime::FirstAfter(slot as i64),
+                )
+                .await
+                .is_err());
+        }
+    }
+}

+ 5 - 0
hermes/src/store/storage.rs

@@ -45,6 +45,11 @@ pub struct MessageStateTime {
 pub struct MessageState {
     pub slot:        Slot,
     pub message:     Message,
+    /// The raw updated message.
+    ///
+    /// We need to store the raw message binary because the Message
+    /// struct might lose some data due to its support for forward
+    /// compatibility.
     pub raw_message: RawMessage,
     pub proof_set:   ProofSet,
     pub received_at: UnixTimestamp,

+ 2 - 0
hermes/src/store/types.rs

@@ -45,6 +45,7 @@ pub enum Update {
     AccumulatorMessages(AccumulatorMessages),
 }
 
+#[derive(Debug, PartialEq)]
 pub struct PriceFeedUpdate {
     pub price_feed:                  PriceFeedMessage,
     pub slot:                        Slot,
@@ -55,6 +56,7 @@ pub struct PriceFeedUpdate {
     pub wormhole_merkle_update_data: Vec<u8>,
 }
 
+#[derive(Debug, PartialEq)]
 pub struct PriceFeedsWithUpdateData {
     pub price_feeds:                 Vec<PriceFeedUpdate>,
     pub wormhole_merkle_update_data: Vec<Vec<u8>>,

+ 9 - 2
hermes/src/store/wormhole.rs

@@ -26,7 +26,6 @@ use {
     },
 };
 
-/// A small wrapper around [u8; 20] guardian set key types.
 #[derive(Eq, PartialEq, Clone, Hash, Debug)]
 pub struct GuardianSet {
     pub keys: Vec<[u8; 20]>,
@@ -120,7 +119,15 @@ pub async fn verify_vaa<'a>(
         }
     }
 
-    let quorum = (guardian_set.keys.len() * 2 + 2) / 3;
+    // TODO: This check bypass checking the signatures on tests.
+    // Ideally we need to test the signatures but currently Wormhole
+    // doesn't give us any easy way for it.
+    let quorum = if cfg!(test) {
+        0
+    } else {
+        (guardian_set.keys.len() * 2) / 3 + 1
+    };
+
     if num_correct_signers < quorum {
         return Err(anyhow!(
             "Not enough correct signatures. Expected {:?}, received {:?}",