浏览代码

refactor(hermes): wrap hex::serde for deserializing

Reisen 2 年之前
父节点
当前提交
80fe023174
共有 6 个文件被更改,包括 97 次插入52 次删除
  1. 1 1
      hermes/Cargo.toml
  2. 6 4
      hermes/src/api/rest/get_vaa_ccip.rs
  3. 6 5
      hermes/src/api/types.rs
  4. 0 41
      hermes/src/macros.rs
  5. 1 1
      hermes/src/main.rs
  6. 83 0
      hermes/src/serde.rs

+ 1 - 1
hermes/Cargo.toml

@@ -17,7 +17,7 @@ dashmap            = { version = "5.4.0" }
 derive_more        = { version = "0.99.17" }
 derive_more        = { version = "0.99.17" }
 env_logger         = { version = "0.10.0" }
 env_logger         = { version = "0.10.0" }
 futures            = { version = "0.3.28" }
 futures            = { version = "0.3.28" }
-hex                = { version = "0.4.3" }
+hex                = { version = "0.4.3", features = ["serde"] }
 humantime          = { version = "2.1.0" }
 humantime          = { version = "2.1.0" }
 lazy_static        = { version = "1.4.0" }
 lazy_static        = { version = "1.4.0" }
 libc               = { version = "0.2.140" }
 libc               = { version = "0.2.140" }

+ 6 - 4
hermes/src/api/rest/get_vaa_ccip.rs

@@ -5,7 +5,6 @@ use {
             UnixTimestamp,
             UnixTimestamp,
         },
         },
         api::rest::RestError,
         api::rest::RestError,
-        impl_deserialize_for_hex_string_wrapper,
     },
     },
     anyhow::Result,
     anyhow::Result,
     axum::{
     axum::{
@@ -17,6 +16,10 @@ use {
         DerefMut,
         DerefMut,
     },
     },
     pyth_sdk::PriceIdentifier,
     pyth_sdk::PriceIdentifier,
+    serde::{
+        Deserialize,
+        Serialize,
+    },
     serde_qs::axum::QsQuery,
     serde_qs::axum::QsQuery,
     utoipa::{
     utoipa::{
         IntoParams,
         IntoParams,
@@ -24,9 +27,8 @@ use {
     },
     },
 };
 };
 
 
-#[derive(Debug, Clone, Deref, DerefMut, ToSchema)]
-pub struct GetVaaCcipInput([u8; 40]);
-impl_deserialize_for_hex_string_wrapper!(GetVaaCcipInput, 40);
+#[derive(Clone, Debug, Deref, DerefMut, Deserialize, Serialize, ToSchema)]
+pub struct GetVaaCcipInput(#[serde(with = "crate::serde::hex")] [u8; 40]);
 
 
 #[derive(Debug, serde::Deserialize, IntoParams)]
 #[derive(Debug, serde::Deserialize, IntoParams)]
 #[into_params(parameter_in=Query)]
 #[into_params(parameter_in=Query)]

+ 6 - 5
hermes/src/api/types.rs

@@ -6,7 +6,6 @@ use {
             UnixTimestamp,
             UnixTimestamp,
         },
         },
         doc_examples,
         doc_examples,
-        impl_deserialize_for_hex_string_wrapper,
     },
     },
     base64::{
     base64::{
         engine::general_purpose::STANDARD as base64_standard_engine,
         engine::general_purpose::STANDARD as base64_standard_engine,
@@ -21,6 +20,10 @@ use {
         DerefMut,
         DerefMut,
     },
     },
     pyth_sdk::PriceIdentifier,
     pyth_sdk::PriceIdentifier,
+    serde::{
+        Deserialize,
+        Serialize,
+    },
     utoipa::ToSchema,
     utoipa::ToSchema,
     wormhole_sdk::Chain,
     wormhole_sdk::Chain,
 };
 };
@@ -33,11 +36,9 @@ use {
 /// * e62df6c8b4a85fe1a67db44dc12de5db330f7ac66b72dc658afedf0f4a415b43
 /// * e62df6c8b4a85fe1a67db44dc12de5db330f7ac66b72dc658afedf0f4a415b43
 ///
 ///
 /// See https://pyth.network/developers/price-feed-ids for a list of all price feed ids.
 /// See https://pyth.network/developers/price-feed-ids for a list of all price feed ids.
-#[derive(Debug, Clone, Deref, DerefMut, ToSchema)]
+#[derive(Clone, Debug, Deref, DerefMut, Deserialize, Serialize, ToSchema)]
 #[schema(value_type=String, example=doc_examples::price_feed_id_example)]
 #[schema(value_type=String, example=doc_examples::price_feed_id_example)]
-pub struct PriceIdInput([u8; 32]);
-// TODO: Use const generics instead of macro.
-impl_deserialize_for_hex_string_wrapper!(PriceIdInput, 32);
+pub struct PriceIdInput(#[serde(with = "crate::serde::hex")] [u8; 32]);
 
 
 impl From<PriceIdInput> for PriceIdentifier {
 impl From<PriceIdInput> for PriceIdentifier {
     fn from(id: PriceIdInput) -> Self {
     fn from(id: PriceIdInput) -> Self {

+ 0 - 41
hermes/src/macros.rs

@@ -1,41 +0,0 @@
-#[macro_export]
-/// A macro that generates Deserialize from string for a struct S that wraps [u8; N] where N is a
-/// compile-time constant. This macro deserializes a string with or without leading 0x and supports
-/// both lower case and upper case hex characters.
-macro_rules! impl_deserialize_for_hex_string_wrapper {
-    ($struct_name:ident, $array_size:expr) => {
-        impl<'de> serde::Deserialize<'de> for $struct_name {
-            fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
-            where
-                D: serde::Deserializer<'de>,
-            {
-                struct HexVisitor;
-
-                impl<'de> serde::de::Visitor<'de> for HexVisitor {
-                    type Value = [u8; $array_size];
-
-                    fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
-                        write!(formatter, "a hex string of length {}", $array_size * 2)
-                    }
-
-                    fn visit_str<E>(self, s: &str) -> Result<Self::Value, E>
-                    where
-                        E: serde::de::Error,
-                    {
-                        let s = s.trim_start_matches("0x");
-                        let bytes = hex::decode(s)
-                            .map_err(|_| E::invalid_value(serde::de::Unexpected::Str(s), &self))?;
-                        if bytes.len() != $array_size {
-                            return Err(E::invalid_length(bytes.len(), &self));
-                        }
-                        let mut array = [0_u8; $array_size];
-                        array.copy_from_slice(&bytes);
-                        Ok(array)
-                    }
-                }
-
-                deserializer.deserialize_str(HexVisitor).map($struct_name)
-            }
-        }
-    };
-}

+ 1 - 1
hermes/src/main.rs

@@ -20,8 +20,8 @@ mod aggregate;
 mod api;
 mod api;
 mod config;
 mod config;
 mod doc_examples;
 mod doc_examples;
-mod macros;
 mod network;
 mod network;
+mod serde;
 mod state;
 mod state;
 mod wormhole;
 mod wormhole;
 
 

+ 83 - 0
hermes/src/serde.rs

@@ -0,0 +1,83 @@
+pub mod hex {
+    use {
+        hex::FromHex,
+        serde::{
+            de::IntoDeserializer,
+            Deserialize,
+            Deserializer,
+            Serializer,
+        },
+    };
+
+    pub fn serialize<S, const N: usize>(b: &[u8; N], s: S) -> Result<S::Ok, S::Error>
+    where
+        S: Serializer,
+    {
+        s.serialize_str(hex::encode(b).as_str())
+    }
+
+    pub fn deserialize<'de, D, R>(d: D) -> Result<R, D::Error>
+    where
+        D: Deserializer<'de>,
+        R: FromHex,
+        <R as hex::FromHex>::Error: std::fmt::Display,
+    {
+        let s: String = Deserialize::deserialize(d)?;
+        let p = s.starts_with("0x") || s.starts_with("0X");
+        let s = if p { &s[2..] } else { &s[..] };
+        hex::serde::deserialize(s.into_deserializer())
+    }
+
+    #[cfg(test)]
+    mod tests {
+        use serde::Deserialize;
+
+        #[derive(Debug, Deserialize, PartialEq)]
+        struct H(#[serde(with = "super")] [u8; 32]);
+
+        #[test]
+        fn test_deserialize() {
+            let e = H([
+                0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef, 0x01, 0x23, 0x45, 0x67, 0x89, 0xab,
+                0xcd, 0xef, 0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef, 0x01, 0x23, 0x45, 0x67,
+                0x89, 0xab, 0xcd, 0xef,
+            ]);
+
+            let l = "\"0x0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef\"";
+            let u = "\"0x0123456789ABCDEF0123456789ABCDEF0123456789ABCDEF0123456789ABCDEF\"";
+            assert_eq!(serde_json::from_str::<H>(l).unwrap(), e);
+            assert_eq!(serde_json::from_str::<H>(u).unwrap(), e);
+
+            let l = "\"0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef\"";
+            let u = "\"0123456789ABCDEF0123456789ABCDEF0123456789ABCDEF0123456789ABCDEF\"";
+            assert_eq!(serde_json::from_str::<H>(l).unwrap(), e);
+            assert_eq!(serde_json::from_str::<H>(u).unwrap(), e);
+        }
+
+        #[test]
+        fn test_deserialize_invalid_length() {
+            let l = "\"0x0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcde\"";
+            let u = "\"0X0123456789ABCDEF0123456789ABCDEF0123456789ABCDEF0123456789ABCDE\"";
+            assert!(serde_json::from_str::<H>(l).is_err());
+            assert!(serde_json::from_str::<H>(u).is_err());
+
+            let l = "\"0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcde\"";
+            let u = "\"0123456789ABCDEF0123456789ABCDEF0123456789ABCDEF0123456789ABCDE\"";
+            assert!(serde_json::from_str::<H>(l).is_err());
+            assert!(serde_json::from_str::<H>(u).is_err());
+        }
+
+        #[test]
+        fn test_deserialize_invalid_hex() {
+            let l = "\"0x0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdeg\"";
+            let u = "\"0x0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdeg\"";
+            assert!(serde_json::from_str::<H>(l).is_err());
+            assert!(serde_json::from_str::<H>(u).is_err());
+
+            let l = "\"0x0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdeg\"";
+            let u = "\"0x0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdeg\"";
+            assert!(serde_json::from_str::<H>(l).is_err());
+            assert!(serde_json::from_str::<H>(u).is_err());
+        }
+    }
+}