Procházet zdrojové kódy

fix(hermes): improve TWAP reliability - non-optional price selection and consistent time windows (#2521)

* refactor(twap): update get_twaps_with_update_data to use window_seconds and add LatestTimeEarliestSlot request time

* refactor(twap): change window_seconds type from i64 to u64 in get_twaps_with_update_data function

* bump version
Tejas Badadare před 8 měsíci
rodič
revize
337d9c189f

+ 1 - 1
apps/hermes/server/Cargo.lock

@@ -1868,7 +1868,7 @@ checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea"
 
 [[package]]
 name = "hermes"
-version = "0.8.3"
+version = "0.8.4"
 dependencies = [
  "anyhow",
  "async-trait",

+ 1 - 1
apps/hermes/server/Cargo.toml

@@ -1,6 +1,6 @@
 [package]
 name        = "hermes"
-version     = "0.8.3"
+version     = "0.8.4"
 description = "Hermes is an agent that provides Verified Prices from the Pythnet Pyth Oracle."
 edition     = "2021"
 

+ 1 - 1
apps/hermes/server/src/api/rest.rs

@@ -201,7 +201,7 @@ mod tests {
         async fn get_twaps_with_update_data(
             &self,
             _price_ids: &[PriceIdentifier],
-            _start_time: RequestTime,
+            _window_seconds: u64,
             _end_time: RequestTime,
         ) -> Result<TwapsWithUpdateData> {
             unimplemented!("Not needed for this test")

+ 3 - 11
apps/hermes/server/src/api/rest/v2/latest_twaps.rs

@@ -13,7 +13,7 @@ use {
         Json,
     },
     base64::{engine::general_purpose::STANDARD as base64_standard_engine, Engine as _},
-    pyth_sdk::{DurationInSeconds, PriceIdentifier, UnixTimestamp},
+    pyth_sdk::{DurationInSeconds, PriceIdentifier},
     serde::Deserialize,
     serde_qs::axum::QsQuery,
     utoipa::IntoParams,
@@ -105,20 +105,12 @@ where
     let price_ids: Vec<PriceIdentifier> =
         validate_price_ids(&state, &price_id_inputs, params.ignore_invalid_price_ids).await?;
 
-    // Collect start and end bounds for the TWAP window
-    let window_seconds = path_params.window_seconds as i64;
-    let current_time = std::time::SystemTime::now()
-        .duration_since(std::time::UNIX_EPOCH)
-        .unwrap()
-        .as_secs() as UnixTimestamp;
-    let start_time = current_time - window_seconds;
-
     // Calculate the average
     let twaps_with_update_data = Aggregates::get_twaps_with_update_data(
         &*state.state,
         &price_ids,
-        RequestTime::FirstAfter(start_time),
-        RequestTime::Latest,
+        path_params.window_seconds,
+        RequestTime::LatestTimeEarliestSlot,
     )
     .await
     .map_err(|e| {

+ 146 - 18
apps/hermes/server/src/state/aggregate.rs

@@ -1,3 +1,4 @@
+use log::warn;
 #[cfg(test)]
 use mock_instant::{SystemTime, UNIX_EPOCH};
 use pythnet_sdk::messages::TwapMessage;
@@ -60,6 +61,7 @@ pub type UnixTimestamp = i64;
 #[derive(Clone, PartialEq, Eq, Debug)]
 pub enum RequestTime {
     Latest,
+    LatestTimeEarliestSlot,
     FirstAfter(UnixTimestamp),
     AtSlot(Slot),
 }
@@ -242,7 +244,7 @@ where
     async fn get_twaps_with_update_data(
         &self,
         price_ids: &[PriceIdentifier],
-        start_time: RequestTime,
+        window_seconds: u64,
         end_time: RequestTime,
     ) -> Result<TwapsWithUpdateData>;
 }
@@ -410,16 +412,11 @@ where
     async fn get_twaps_with_update_data(
         &self,
         price_ids: &[PriceIdentifier],
-        start_time: RequestTime,
+        window_seconds: u64,
         end_time: RequestTime,
     ) -> Result<TwapsWithUpdateData> {
-        match get_verified_twaps_with_update_data(
-            self,
-            price_ids,
-            start_time.clone(),
-            end_time.clone(),
-        )
-        .await
+        match get_verified_twaps_with_update_data(self, price_ids, window_seconds, end_time.clone())
+            .await
         {
             Ok(twaps_with_update_data) => Ok(twaps_with_update_data),
             Err(e) => {
@@ -637,33 +634,68 @@ where
 async fn get_verified_twaps_with_update_data<S>(
     state: &S,
     price_ids: &[PriceIdentifier],
-    start_time: RequestTime,
+    window_seconds: u64,
     end_time: RequestTime,
 ) -> Result<TwapsWithUpdateData>
 where
     S: Cache,
 {
-    // Get all start messages for all price IDs
-    let start_messages = state
+    // Get all end messages for all price IDs
+    let end_messages = state
         .fetch_message_states(
             price_ids.iter().map(|id| id.to_bytes()).collect(),
-            start_time.clone(),
+            end_time.clone(),
             MessageStateFilter::Only(MessageType::TwapMessage),
         )
         .await?;
 
-    // Get all end messages for all price IDs
-    let end_messages = state
+    // Calculate start_time based on the publish time of the end messages
+    // to guarantee that the start and end messages are window_seconds apart
+    let start_timestamp = if end_messages.is_empty() {
+        // If there are no end messages, we can't calculate a TWAP
+        tracing::warn!(
+            price_ids = ?price_ids,
+            time = ?end_time,
+            "Could not find TWAP messages"
+        );
+        return Err(anyhow!(
+            "Update data not found for the specified timestamps"
+        ));
+    } else {
+        // Use the publish time from the first end message
+        end_messages[0].message.publish_time() - window_seconds as i64
+    };
+    let start_time = RequestTime::FirstAfter(start_timestamp);
+
+    // Get all start messages for all price IDs
+    let start_messages = state
         .fetch_message_states(
             price_ids.iter().map(|id| id.to_bytes()).collect(),
-            end_time.clone(),
+            start_time.clone(),
             MessageStateFilter::Only(MessageType::TwapMessage),
         )
         .await?;
 
+    if start_messages.is_empty() {
+        tracing::warn!(
+            price_ids = ?price_ids,
+            time = ?start_time,
+            "Could not find TWAP messages"
+        );
+        return Err(anyhow!(
+            "Update data not found for the specified timestamps"
+        ));
+    }
+
     // Verify we have matching start and end messages.
     // The cache should throw an error earlier, but checking just in case.
     if start_messages.len() != end_messages.len() {
+        tracing::warn!(
+            price_ids = ?price_ids,
+            start_message_length = ?price_ids,
+            end_message_length = ?start_time,
+            "Start and end messages length mismatch"
+        );
         return Err(anyhow!(
             "Update data not found for the specified timestamps"
         ));
@@ -695,6 +727,11 @@ where
                     });
                 }
                 Err(e) => {
+                    tracing::error!(
+                        feed_id = ?start_twap.feed_id,
+                        error = %e,
+                        "Failed to calculate TWAP for price feed"
+                    );
                     return Err(anyhow!(
                         "Failed to calculate TWAP for price feed {:?}: {}",
                         start_twap.feed_id,
@@ -1295,7 +1332,7 @@ mod test {
                 PriceIdentifier::new(feed_id_1),
                 PriceIdentifier::new(feed_id_2),
             ],
-            RequestTime::FirstAfter(100), // Start time
+            100,                          // window seconds
             RequestTime::FirstAfter(200), // End time
         )
         .await
@@ -1329,6 +1366,97 @@ mod test {
         // update_data should have 2 elements, one for the start block and one for the end block.
         assert_eq!(result.update_data.len(), 2);
     }
+
+    #[tokio::test]
+    /// Tests that the TWAP calculation correctly selects TWAP messages that are the first ones
+    /// for their timestamp (non-optional prices). This is important because if a message such that
+    /// `publish_time == prev_publish_time` is chosen, the TWAP calculation will fail due to the optionality check.
+    async fn test_get_verified_twaps_with_update_data_uses_non_optional_prices() {
+        let (state, _update_rx) = setup_state(10).await;
+        let feed_id = [1u8; 32];
+
+        // Store start TWAP message
+        store_multiple_concurrent_valid_updates(
+            state.clone(),
+            generate_update(
+                vec![create_basic_twap_message(
+                    feed_id, 100,  // cumulative_price
+                    0,    // num_down_slots
+                    100,  // publish_time
+                    99,   // prev_publish_time
+                    1000, // publish_slot
+                )],
+                1000,
+                20,
+            ),
+        )
+        .await;
+
+        // Store end TWAP messages
+
+        // This first message has the latest publish_time and earliest slot,
+        // so it should be chosen as the end_message to calculate TWAP with.
+        store_multiple_concurrent_valid_updates(
+            state.clone(),
+            generate_update(
+                vec![create_basic_twap_message(
+                    feed_id, 300,  // cumulative_price
+                    50,   // num_down_slots
+                    200,  // publish_time
+                    180,  // prev_publish_time
+                    1100, // publish_slot
+                )],
+                1100,
+                21,
+            ),
+        )
+        .await;
+
+        // This second message has the same publish_time as the previous one and a later slot.
+        // It will fail the optionality check since publish_time == prev_publish_time.
+        // Thus, it should not be chosen to calculate TWAP with.
+        store_multiple_concurrent_valid_updates(
+            state.clone(),
+            generate_update(
+                vec![create_basic_twap_message(
+                    feed_id, 900,  // cumulative_price
+                    50,   // num_down_slots
+                    200,  // publish_time
+                    200,  // prev_publish_time
+                    1101, // publish_slot
+                )],
+                1101,
+                22,
+            ),
+        )
+        .await;
+
+        // Get TWAPs over timestamp window 100 -> 200
+        let result = get_verified_twaps_with_update_data(
+            &*state,
+            &[PriceIdentifier::new(feed_id)],
+            100,                                 // window seconds
+            RequestTime::LatestTimeEarliestSlot, // End time
+        )
+        .await
+        .unwrap();
+
+        // Verify that the first end message was chosen to calculate the TWAP
+        // and that the calculation is accurate
+        assert_eq!(result.twaps.len(), 1);
+        let twap_1 = result
+            .twaps
+            .iter()
+            .find(|t| t.id == PriceIdentifier::new(feed_id))
+            .unwrap();
+        assert_eq!(twap_1.twap.price, 2); // (300-100)/(1100-1000) = 2
+        assert_eq!(twap_1.down_slots_ratio, Decimal::from_f64(0.5).unwrap()); // (50-0)/(1100-1000) = 0.5
+        assert_eq!(twap_1.start_timestamp, 100);
+        assert_eq!(twap_1.end_timestamp, 200);
+
+        // update_data should have 2 elements, one for the start block and one for the end block.
+        assert_eq!(result.update_data.len(), 2);
+    }
     #[tokio::test]
 
     async fn test_get_verified_twaps_with_missing_messages_throws_error() {
@@ -1385,7 +1513,7 @@ mod test {
                 PriceIdentifier::new(feed_id_1),
                 PriceIdentifier::new(feed_id_2),
             ],
-            RequestTime::FirstAfter(100),
+            100,
             RequestTime::FirstAfter(200),
         )
         .await;

+ 89 - 0
apps/hermes/server/src/state/cache.rs

@@ -285,6 +285,28 @@ async fn retrieve_message_state(
         Some(key_cache) => {
             match request_time {
                 RequestTime::Latest => key_cache.last_key_value().map(|(_, v)| v).cloned(),
+                RequestTime::LatestTimeEarliestSlot => {
+                    // Get the latest publish time from the last entry
+                    let last_entry = key_cache.last_key_value()?;
+                    let latest_publish_time = last_entry.0.publish_time;
+                    let mut latest_entry_with_earliest_slot = last_entry;
+
+                    // Walk backwards through the sorted entries rather than use `range` since we will only
+                    // have a couple entries that have the same publish_time.
+                    // We have acquired the RwLock via read() above, so we should be safe to reenter the cache here.
+                    for (k, v) in key_cache.iter().rev() {
+                        if k.publish_time < latest_publish_time {
+                            // We've found an entry with an earlier publish time
+                            break;
+                        }
+
+                        // Update our tracked entry (the reverse iteration will find entries
+                        // with higher slots first, so we'll end up with the lowest slot)
+                        latest_entry_with_earliest_slot = (k, v);
+                    }
+
+                    Some(latest_entry_with_earliest_slot.1.clone())
+                }
                 RequestTime::FirstAfter(time) => {
                     // If the requested time is before the first element in the vector, we are
                     // not sure that the first element is the closest one.
@@ -590,6 +612,73 @@ mod test {
         );
     }
 
+    #[tokio::test]
+    pub async fn test_latest_time_earliest_slot_request_works() {
+        // Initialize state with a cache size of 3 per key.
+        let (state, _) = setup_state(3).await;
+
+        // Create and store a message state with feed id [1....] and publish time 10 at slot 7.
+        create_and_store_dummy_price_feed_message_state(&*state, [1; 32], 10, 7).await;
+
+        // Create and store a message state with feed id [1....] and publish time 10 at slot 10.
+        create_and_store_dummy_price_feed_message_state(&*state, [1; 32], 10, 10).await;
+
+        // Create and store a message state with feed id [1....] and publish time 10 at slot 5.
+        let earliest_slot_message_state =
+            create_and_store_dummy_price_feed_message_state(&*state, [1; 32], 10, 5).await;
+
+        // Create and store a message state with feed id [1....] and publish time 8 at slot 3.
+        create_and_store_dummy_price_feed_message_state(&*state, [1; 32], 8, 3).await;
+
+        // The LatestTimeEarliestSlot should return the message with publish time 10 at slot 5
+        assert_eq!(
+            state
+                .fetch_message_states(
+                    vec![[1; 32]],
+                    RequestTime::LatestTimeEarliestSlot,
+                    MessageStateFilter::Only(MessageType::PriceFeedMessage),
+                )
+                .await
+                .unwrap(),
+            vec![earliest_slot_message_state]
+        );
+
+        // Create and store a message state with feed id [1....] and publish time 15 at slot 20.
+        let newer_time_message_state =
+            create_and_store_dummy_price_feed_message_state(&*state, [1; 32], 15, 20).await;
+
+        // The LatestTimeEarliestSlot should now return the message with publish time 15
+        assert_eq!(
+            state
+                .fetch_message_states(
+                    vec![[1; 32]],
+                    RequestTime::LatestTimeEarliestSlot,
+                    MessageStateFilter::Only(MessageType::PriceFeedMessage),
+                )
+                .await
+                .unwrap(),
+            vec![newer_time_message_state]
+        );
+
+        // Store two messages with even later publish time but different slots
+        create_and_store_dummy_price_feed_message_state(&*state, [1; 32], 20, 35).await;
+        let latest_time_earliest_slot_message =
+            create_and_store_dummy_price_feed_message_state(&*state, [1; 32], 20, 30).await;
+
+        // The LatestTimeEarliestSlot should return the message with publish time 20 at slot 30
+        assert_eq!(
+            state
+                .fetch_message_states(
+                    vec![[1; 32]],
+                    RequestTime::LatestTimeEarliestSlot,
+                    MessageStateFilter::Only(MessageType::PriceFeedMessage),
+                )
+                .await
+                .unwrap(),
+            vec![latest_time_earliest_slot_message]
+        );
+    }
+
     #[tokio::test]
     pub async fn test_store_and_retrieve_first_after_message_state_fails_for_past_time() {
         // Initialize state with a cache size of 2 per key.