瀏覽代碼

feat(near): parse string identifiers

Reisen 2 年之前
父節點
當前提交
0ec98ffb78

+ 8 - 9
target_chains/near/receiver/src/ext.rs

@@ -4,16 +4,9 @@
 use {
     crate::{
         error::Error,
-        state::{
-            Price,
-            PriceIdentifier,
-            Source,
-        },
-    },
-    near_sdk::{
-        ext_contract,
-        json_types::U128,
+        state::{Price, PriceIdentifier, Source},
     },
+    near_sdk::{ext_contract, json_types::U128},
 };
 
 /// Defines the external contract API we care about for interacting with Wormhole. Note that
@@ -28,10 +21,16 @@ pub trait Wormhole {
 /// An external definition of the Pyth interface.
 #[ext_contract(ext_pyth)]
 pub trait Pyth {
+    // See the implementation for details. The `data` parameter can be found by using a Hermes
+    // price feed endpoint, and should be fed in as base64.
     fn update_price_feeds(&mut self, data: String) -> Result<(), Error>;
     fn get_update_fee_estimate(&self, vaa: String) -> U128;
     fn get_sources(&self) -> Vec<Source>;
     fn get_stale_threshold(&self) -> u64;
+
+    // See implementations for details, PriceIdentifier can be passed either as a 64 character
+    // hex price ID which can be found on the Pyth homepage, or can be a 32 element byte array
+    // representing the same thing.
     fn price_feed_exists(&self, price_identifier: PriceIdentifier) -> bool;
     fn get_price(&self, price_identifier: PriceIdentifier) -> Option<Price>;
     fn get_price_unsafe(&self, price_identifier: PriceIdentifier) -> Option<Price>;

+ 2 - 1
target_chains/near/receiver/src/lib.rs

@@ -151,7 +151,8 @@ impl Pyth {
     /// Instruction for processing VAA's relayed via Wormhole.
     ///
     /// Note that VAA verification requires calling Wormhole so processing of the VAA itself is
-    /// done in a callback handler, see `process_vaa_callback`.
+    /// done in a callback handler, see `process_vaa_callback`. The `data` parameter can be
+    /// retrieved from Hermes using the price feed APIs.
     #[payable]
     #[handle_result]
     pub fn update_price_feeds(&mut self, data: String) -> Result<(), Error> {

+ 54 - 1
target_chains/near/receiver/src/state.rs

@@ -21,11 +21,64 @@ pub type WormholeSignature = [u8; 65];
 /// Type alias for Wormhole's cross-chain 32-byte address.
 pub type WormholeAddress = [u8; 32];
 
-#[derive(BorshDeserialize, BorshSerialize, Deserialize, Serialize)]
+#[derive(BorshDeserialize, BorshSerialize, Serialize)]
 #[serde(crate = "near_sdk::serde")]
 #[repr(transparent)]
 pub struct PriceIdentifier(pub [u8; 32]);
 
+impl<'de> near_sdk::serde::Deserialize<'de> for PriceIdentifier {
+    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
+    where
+        D: near_sdk::serde::Deserializer<'de>,
+    {
+        /// A visitor that deserializes a hex string into a 32 byte array.
+        struct IdentifierVisitor;
+
+        impl<'de> near_sdk::serde::de::Visitor<'de> for IdentifierVisitor {
+            /// Target type for either a hex string or a 32 byte array.
+            type Value = [u8; 32];
+
+            fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
+                formatter.write_str("a 32 byte array or a hex string")
+            }
+
+            // When given a string, attempt a standard hex decode.
+            fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
+            where
+                E: near_sdk::serde::de::Error,
+            {
+                if value.len() != 64 {
+                    return Err(E::custom(format!(
+                        "expected a 64 character hex string, got {}",
+                        value.len()
+                    )));
+                }
+                let mut bytes = [0u8; 32];
+                hex::decode_to_slice(value, &mut bytes).map_err(E::custom)?;
+                Ok(bytes)
+            }
+
+            // When given a classic array, attempt to read the bytes directly.
+            fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
+            where
+                A: near_sdk::serde::de::SeqAccess<'de>,
+            {
+                let mut bytes = [0u8; 32];
+                for i in 0..32 {
+                    bytes[i] = seq
+                        .next_element()?
+                        .ok_or_else(|| near_sdk::serde::de::Error::invalid_length(i, &self))?;
+                }
+                Ok(bytes)
+            }
+        }
+
+        deserializer
+            .deserialize_any(IdentifierVisitor)
+            .map(PriceIdentifier)
+    }
+}
+
 /// A price with a degree of uncertainty, represented as a price +- a confidence interval.
 ///
 /// The confidence interval roughly corresponds to the standard error of a normal distribution.

+ 54 - 15
target_chains/near/receiver/tests/workspaces.rs

@@ -400,7 +400,7 @@ async fn test_stale_threshold() {
         &contract
             .view("get_update_fee_estimate")
             .args_json(&json!({
-                "vaa": vaa,
+                "data": vaa,
             }))
             .await
             .unwrap()
@@ -410,7 +410,7 @@ async fn test_stale_threshold() {
 
     // Submit price. As there are no prices this should succeed despite being old.
     assert!(contract
-        .call("update_price_feed")
+        .call("update_price_feeds")
         .gas(300_000_000_000_000)
         .deposit(update_fee.into())
         .args_json(&json!({
@@ -482,7 +482,7 @@ async fn test_stale_threshold() {
 
     // The update handler should now succeed even if price is old, but simply not update the price.
     assert!(contract
-        .call("update_price_feed")
+        .call("update_price_feeds")
         .gas(300_000_000_000_000)
         .deposit(update_fee.into())
         .args_json(&json!({
@@ -616,7 +616,7 @@ async fn test_contract_fees() {
         &contract
             .view("get_update_fee_estimate")
             .args_json(&json!({
-                "vaa": vaa,
+                "data": vaa,
             }))
             .await
             .unwrap()
@@ -648,7 +648,7 @@ async fn test_contract_fees() {
                 &contract
                     .view("get_update_fee_estimate")
                     .args_json(&json!({
-                        "vaa": vaa,
+                        "data": vaa,
                     }))
                     .await
                     .unwrap()
@@ -699,7 +699,7 @@ async fn test_contract_fees() {
     };
 
     assert!(contract
-        .call("update_price_feed")
+        .call("update_price_feeds")
         .gas(300_000_000_000_000)
         .deposit(update_fee.into())
         .args_json(&json!({
@@ -990,7 +990,6 @@ async fn test_accumulator_updates() {
     fn create_accumulator_message_from_updates(
         price_updates: Vec<MerklePriceUpdate>,
         tree: MerkleTree<Keccak160>,
-        corrupt_wormhole_message: bool,
         emitter_address: [u8; 32],
         emitter_chain: u16,
     ) -> Vec<u8> {
@@ -1026,11 +1025,7 @@ async fn test_accumulator_updates() {
         to_vec::<_, BigEndian>(&accumulator_update_data).unwrap()
     }
 
-    fn create_accumulator_message(
-        all_feeds: &[Message],
-        updates: &[Message],
-        corrupt_wormhole_message: bool,
-    ) -> Vec<u8> {
+    fn create_accumulator_message(all_feeds: &[Message], updates: &[Message]) -> Vec<u8> {
         let all_feeds_bytes: Vec<_> = all_feeds
             .iter()
             .map(|f| to_vec::<_, BigEndian>(f).unwrap())
@@ -1050,7 +1045,6 @@ async fn test_accumulator_updates() {
         create_accumulator_message_from_updates(
             price_updates,
             tree,
-            corrupt_wormhole_message,
             [1; 32],
             wormhole::Chain::Any.into(),
         )
@@ -1109,12 +1103,12 @@ async fn test_accumulator_updates() {
     // Create a couple of test feeds.
     let feed_1 = create_dummy_price_feed_message(100);
     let feed_2 = create_dummy_price_feed_message(200);
-    let message = create_accumulator_message(&[feed_1, feed_2], &[feed_1], false);
+    let message = create_accumulator_message(&[feed_1, feed_2], &[feed_1]);
     let message = hex::encode(message);
 
     // Call the usual UpdatePriceFeed function.
     assert!(contract
-        .call("update_price_feed")
+        .call("update_price_feeds")
         .gas(300_000_000_000_000)
         .deposit(300_000_000_000_000_000_000_000)
         .args_json(&json!({
@@ -1127,4 +1121,49 @@ async fn test_accumulator_updates() {
         .unwrap()
         .failures()
         .is_empty());
+
+    // Check the price feed actually updated. Check both types of serialized PriceIdentifier.
+    let mut identifier = [0; 32];
+    identifier[0] = 100;
+
+    assert_eq!(
+        Some(Price {
+            price:     100,
+            conf:      100,
+            expo:      100,
+            timestamp: 100,
+        }),
+        serde_json::from_slice::<Option<Price>>(
+            &contract
+                .view("get_price_unsafe")
+                .args_json(&json!({ "price_identifier": PriceIdentifier(identifier) }))
+                .await
+                .unwrap()
+                .result
+        )
+        .unwrap(),
+    );
+
+    // String Identifier should also work.
+    assert_eq!(
+        Some(Price {
+            price: 100,
+            conf: 100,
+            expo: 100,
+            timestamp: 100,
+        }),
+        serde_json::from_slice::<Option<Price>>(
+            &contract
+                .view("get_price_unsafe")
+                .args_json(serde_json::from_str::<serde_json::Value>(
+                    r#"{
+                        "price_identifier": "6400000000000000000000000000000000000000000000000000000000000000"
+                    }"#
+                ).unwrap())
+                .await
+                .unwrap()
+                .result
+        )
+        .unwrap(),
+    );
 }