Browse Source

feat(near): parse string identifiers

Reisen 2 years ago
parent
commit
2b118c1f9c

+ 5 - 0
target_chains/near/receiver/src/ext.rs

@@ -28,10 +28,15 @@ pub trait Wormhole {
 /// An external definition of the Pyth interface.
 /// An external definition of the Pyth interface.
 #[ext_contract(ext_pyth)]
 #[ext_contract(ext_pyth)]
 pub trait Pyth {
 pub trait Pyth {
+    // See the implementation for details. The `data` parameter can be found by using a Hermes
+    // price feed endpoint, and should be fed in as base64.
     fn update_price_feeds(&mut self, data: String) -> Result<(), Error>;
     fn update_price_feeds(&mut self, data: String) -> Result<(), Error>;
     fn get_update_fee_estimate(&self, vaa: String) -> U128;
     fn get_update_fee_estimate(&self, vaa: String) -> U128;
     fn get_sources(&self) -> Vec<Source>;
     fn get_sources(&self) -> Vec<Source>;
     fn get_stale_threshold(&self) -> u64;
     fn get_stale_threshold(&self) -> u64;
+
+    // See implementations for details, PriceIdentifier can be passed either as a 64 character
+    // hex price ID which can be found on the Pyth homepage.
     fn price_feed_exists(&self, price_identifier: PriceIdentifier) -> bool;
     fn price_feed_exists(&self, price_identifier: PriceIdentifier) -> bool;
     fn get_price(&self, price_identifier: PriceIdentifier) -> Option<Price>;
     fn get_price(&self, price_identifier: PriceIdentifier) -> Option<Price>;
     fn get_price_unsafe(&self, price_identifier: PriceIdentifier) -> Option<Price>;
     fn get_price_unsafe(&self, price_identifier: PriceIdentifier) -> Option<Price>;

+ 2 - 1
target_chains/near/receiver/src/lib.rs

@@ -151,7 +151,8 @@ impl Pyth {
     /// Instruction for processing VAA's relayed via Wormhole.
     /// Instruction for processing VAA's relayed via Wormhole.
     ///
     ///
     /// Note that VAA verification requires calling Wormhole so processing of the VAA itself is
     /// Note that VAA verification requires calling Wormhole so processing of the VAA itself is
-    /// done in a callback handler, see `process_vaa_callback`.
+    /// done in a callback handler, see `process_vaa_callback`. The `data` parameter can be
+    /// retrieved from Hermes using the price feed APIs.
     #[payable]
     #[payable]
     #[handle_result]
     #[handle_result]
     pub fn update_price_feeds(&mut self, data: String) -> Result<(), Error> {
     pub fn update_price_feeds(&mut self, data: String) -> Result<(), Error> {

+ 49 - 2
target_chains/near/receiver/src/state.rs

@@ -21,11 +21,58 @@ pub type WormholeSignature = [u8; 65];
 /// Type alias for Wormhole's cross-chain 32-byte address.
 /// Type alias for Wormhole's cross-chain 32-byte address.
 pub type WormholeAddress = [u8; 32];
 pub type WormholeAddress = [u8; 32];
 
 
-#[derive(BorshDeserialize, BorshSerialize, Deserialize, Serialize)]
-#[serde(crate = "near_sdk::serde")]
+#[derive(BorshDeserialize, BorshSerialize)]
 #[repr(transparent)]
 #[repr(transparent)]
 pub struct PriceIdentifier(pub [u8; 32]);
 pub struct PriceIdentifier(pub [u8; 32]);
 
 
+impl<'de> near_sdk::serde::Deserialize<'de> for PriceIdentifier {
+    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
+    where
+        D: near_sdk::serde::Deserializer<'de>,
+    {
+        /// A visitor that deserializes a hex string into a 32 byte array.
+        struct IdentifierVisitor;
+
+        impl<'de> near_sdk::serde::de::Visitor<'de> for IdentifierVisitor {
+            /// Target type for either a hex string or a 32 byte array.
+            type Value = [u8; 32];
+
+            fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
+                formatter.write_str("a hex string")
+            }
+
+            // When given a string, attempt a standard hex decode.
+            fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
+            where
+                E: near_sdk::serde::de::Error,
+            {
+                if value.len() != 64 {
+                    return Err(E::custom(format!(
+                        "expected a 64 character hex string, got {}",
+                        value.len()
+                    )));
+                }
+                let mut bytes = [0u8; 32];
+                hex::decode_to_slice(value, &mut bytes).map_err(E::custom)?;
+                Ok(bytes)
+            }
+        }
+
+        deserializer
+            .deserialize_any(IdentifierVisitor)
+            .map(PriceIdentifier)
+    }
+}
+
+impl near_sdk::serde::Serialize for PriceIdentifier {
+    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
+    where
+        S: near_sdk::serde::Serializer,
+    {
+        serializer.serialize_str(&hex::encode(&self.0))
+    }
+}
+
 /// A price with a degree of uncertainty, represented as a price +- a confidence interval.
 /// A price with a degree of uncertainty, represented as a price +- a confidence interval.
 ///
 ///
 /// The confidence interval roughly corresponds to the standard error of a normal distribution.
 /// The confidence interval roughly corresponds to the standard error of a normal distribution.

+ 31 - 15
target_chains/near/receiver/tests/workspaces.rs

@@ -400,7 +400,7 @@ async fn test_stale_threshold() {
         &contract
         &contract
             .view("get_update_fee_estimate")
             .view("get_update_fee_estimate")
             .args_json(&json!({
             .args_json(&json!({
-                "vaa": vaa,
+                "data": vaa,
             }))
             }))
             .await
             .await
             .unwrap()
             .unwrap()
@@ -410,7 +410,7 @@ async fn test_stale_threshold() {
 
 
     // Submit price. As there are no prices this should succeed despite being old.
     // Submit price. As there are no prices this should succeed despite being old.
     assert!(contract
     assert!(contract
-        .call("update_price_feed")
+        .call("update_price_feeds")
         .gas(300_000_000_000_000)
         .gas(300_000_000_000_000)
         .deposit(update_fee.into())
         .deposit(update_fee.into())
         .args_json(&json!({
         .args_json(&json!({
@@ -482,7 +482,7 @@ async fn test_stale_threshold() {
 
 
     // The update handler should now succeed even if price is old, but simply not update the price.
     // The update handler should now succeed even if price is old, but simply not update the price.
     assert!(contract
     assert!(contract
-        .call("update_price_feed")
+        .call("update_price_feeds")
         .gas(300_000_000_000_000)
         .gas(300_000_000_000_000)
         .deposit(update_fee.into())
         .deposit(update_fee.into())
         .args_json(&json!({
         .args_json(&json!({
@@ -616,7 +616,7 @@ async fn test_contract_fees() {
         &contract
         &contract
             .view("get_update_fee_estimate")
             .view("get_update_fee_estimate")
             .args_json(&json!({
             .args_json(&json!({
-                "vaa": vaa,
+                "data": vaa,
             }))
             }))
             .await
             .await
             .unwrap()
             .unwrap()
@@ -648,7 +648,7 @@ async fn test_contract_fees() {
                 &contract
                 &contract
                     .view("get_update_fee_estimate")
                     .view("get_update_fee_estimate")
                     .args_json(&json!({
                     .args_json(&json!({
-                        "vaa": vaa,
+                        "data": vaa,
                     }))
                     }))
                     .await
                     .await
                     .unwrap()
                     .unwrap()
@@ -699,7 +699,7 @@ async fn test_contract_fees() {
     };
     };
 
 
     assert!(contract
     assert!(contract
-        .call("update_price_feed")
+        .call("update_price_feeds")
         .gas(300_000_000_000_000)
         .gas(300_000_000_000_000)
         .deposit(update_fee.into())
         .deposit(update_fee.into())
         .args_json(&json!({
         .args_json(&json!({
@@ -990,7 +990,6 @@ async fn test_accumulator_updates() {
     fn create_accumulator_message_from_updates(
     fn create_accumulator_message_from_updates(
         price_updates: Vec<MerklePriceUpdate>,
         price_updates: Vec<MerklePriceUpdate>,
         tree: MerkleTree<Keccak160>,
         tree: MerkleTree<Keccak160>,
-        corrupt_wormhole_message: bool,
         emitter_address: [u8; 32],
         emitter_address: [u8; 32],
         emitter_chain: u16,
         emitter_chain: u16,
     ) -> Vec<u8> {
     ) -> Vec<u8> {
@@ -1026,11 +1025,7 @@ async fn test_accumulator_updates() {
         to_vec::<_, BigEndian>(&accumulator_update_data).unwrap()
         to_vec::<_, BigEndian>(&accumulator_update_data).unwrap()
     }
     }
 
 
-    fn create_accumulator_message(
-        all_feeds: &[Message],
-        updates: &[Message],
-        corrupt_wormhole_message: bool,
-    ) -> Vec<u8> {
+    fn create_accumulator_message(all_feeds: &[Message], updates: &[Message]) -> Vec<u8> {
         let all_feeds_bytes: Vec<_> = all_feeds
         let all_feeds_bytes: Vec<_> = all_feeds
             .iter()
             .iter()
             .map(|f| to_vec::<_, BigEndian>(f).unwrap())
             .map(|f| to_vec::<_, BigEndian>(f).unwrap())
@@ -1050,7 +1045,6 @@ async fn test_accumulator_updates() {
         create_accumulator_message_from_updates(
         create_accumulator_message_from_updates(
             price_updates,
             price_updates,
             tree,
             tree,
-            corrupt_wormhole_message,
             [1; 32],
             [1; 32],
             wormhole::Chain::Any.into(),
             wormhole::Chain::Any.into(),
         )
         )
@@ -1109,12 +1103,12 @@ async fn test_accumulator_updates() {
     // Create a couple of test feeds.
     // Create a couple of test feeds.
     let feed_1 = create_dummy_price_feed_message(100);
     let feed_1 = create_dummy_price_feed_message(100);
     let feed_2 = create_dummy_price_feed_message(200);
     let feed_2 = create_dummy_price_feed_message(200);
-    let message = create_accumulator_message(&[feed_1, feed_2], &[feed_1], false);
+    let message = create_accumulator_message(&[feed_1, feed_2], &[feed_1]);
     let message = hex::encode(message);
     let message = hex::encode(message);
 
 
     // Call the usual UpdatePriceFeed function.
     // Call the usual UpdatePriceFeed function.
     assert!(contract
     assert!(contract
-        .call("update_price_feed")
+        .call("update_price_feeds")
         .gas(300_000_000_000_000)
         .gas(300_000_000_000_000)
         .deposit(300_000_000_000_000_000_000_000)
         .deposit(300_000_000_000_000_000_000_000)
         .args_json(&json!({
         .args_json(&json!({
@@ -1127,4 +1121,26 @@ async fn test_accumulator_updates() {
         .unwrap()
         .unwrap()
         .failures()
         .failures()
         .is_empty());
         .is_empty());
+
+    // Check the price feed actually updated. Check both types of serialized PriceIdentifier.
+    let mut identifier = [0; 32];
+    identifier[0] = 100;
+
+    assert_eq!(
+        Some(Price {
+            price:     100,
+            conf:      100,
+            expo:      100,
+            timestamp: 100,
+        }),
+        serde_json::from_slice::<Option<Price>>(
+            &contract
+                .view("get_price_unsafe")
+                .args_json(&json!({ "price_identifier": PriceIdentifier(identifier) }))
+                .await
+                .unwrap()
+                .result
+        )
+        .unwrap(),
+    );
 }
 }