Forráskód Böngészése

Drozdziak1/p2w client mapping crawl (#286)

* pyth2wormhole-client: Add a mapping crawling routine

* pyth2wormhole-client: Add mapping_addr for attestation config

* pyth2wormhole-client: cargo fmt
Stanisław Drozd 3 éve
szülő
commit
163fa44f24

+ 34 - 4
solana/pyth2wormhole/Cargo.lock

@@ -2079,8 +2079,8 @@ name = "p2w-sdk"
 version = "0.1.1"
 dependencies = [
  "hex",
- "pyth-sdk",
- "pyth-sdk-solana",
+ "pyth-sdk 0.5.0",
+ "pyth-sdk-solana 0.5.0",
  "serde",
  "solana-program",
  "solitaire",
@@ -2355,6 +2355,19 @@ dependencies = [
  "serde",
 ]
 
+[[package]]
+name = "pyth-sdk"
+version = "0.6.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "9a993cc2b76d9842ee92f00d3104a67d403e8a5a745d2474caf42361b3fc815a"
+dependencies = [
+ "borsh",
+ "borsh-derive",
+ "hex",
+ "schemars",
+ "serde",
+]
+
 [[package]]
 name = "pyth-sdk-solana"
 version = "0.5.0"
@@ -2366,7 +2379,24 @@ dependencies = [
  "bytemuck",
  "num-derive",
  "num-traits",
- "pyth-sdk",
+ "pyth-sdk 0.5.0",
+ "serde",
+ "solana-program",
+ "thiserror",
+]
+
+[[package]]
+name = "pyth-sdk-solana"
+version = "0.6.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "97f071fdeb5129de07a2cb70bb4f3a9e4be1e4cc6b85132bdea0967e601eb757"
+dependencies = [
+ "borsh",
+ "borsh-derive",
+ "bytemuck",
+ "num-derive",
+ "num-traits",
+ "pyth-sdk 0.6.1",
  "serde",
  "solana-program",
  "thiserror",
@@ -2399,7 +2429,7 @@ dependencies = [
  "log",
  "p2w-sdk",
  "pyth-client 0.5.1",
- "pyth-sdk-solana",
+ "pyth-sdk-solana 0.6.1",
  "pyth2wormhole",
  "serde",
  "serde_yaml",

+ 1 - 1
solana/pyth2wormhole/client/Cargo.toml

@@ -19,7 +19,7 @@ log = "0.4.14"
 wormhole-bridge-solana = {git = "https://github.com/wormhole-foundation/wormhole", tag = "v2.8.9"}
 pyth2wormhole = {path = "../program"}
 p2w-sdk = { path = "../../../third_party/pyth/p2w-sdk/rust", features=["solana"] }
-pyth-sdk-solana = "0.5.0"
+pyth-sdk-solana = "0.6.1"
 serde = "1"
 serde_yaml = "0.8"
 shellexpand = "2.1.0"

+ 26 - 0
solana/pyth2wormhole/client/src/attestation_cfg.rs

@@ -19,6 +19,12 @@ pub struct AttestationConfig {
     pub min_msg_reuse_interval_ms: u64,
     #[serde(default = "default_max_msg_accounts")]
     pub max_msg_accounts: u64,
+    /// Optionally, we take a mapping account to add remaining symbols from a Pyth deployments. These symbols are processed under attestation conditions for the `default` symbol group.
+    #[serde(
+        deserialize_with = "opt_pubkey_string_de",
+        serialize_with = "opt_pubkey_string_ser"
+    )]
+    pub mapping_addr: Option<Pubkey>,
     pub symbol_groups: Vec<SymbolGroup>,
 }
 
@@ -116,6 +122,25 @@ where
     Ok(pubkey)
 }
 
+fn opt_pubkey_string_ser<S>(k_opt: &Option<Pubkey>, ser: S) -> Result<S::Ok, S::Error>
+where
+    S: Serializer,
+{
+    let k_str_opt = k_opt.clone().map(|k| k.to_string());
+
+    Option::<String>::serialize(&k_str_opt, ser)
+}
+
+fn opt_pubkey_string_de<'de, D>(de: D) -> Result<Option<Pubkey>, D::Error>
+where
+    D: Deserializer<'de>,
+{
+    match Option::<String>::deserialize(de)? {
+        Some(k) => Ok(Some(Pubkey::from_str(&k).map_err(D::Error::custom)?)),
+        None => Ok(None),
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
@@ -163,6 +188,7 @@ mod tests {
         let cfg = AttestationConfig {
             min_msg_reuse_interval_ms: 1000,
             max_msg_accounts: 100_000,
+            mapping_addr: None,
             symbol_groups: vec![fastbois, slowbois],
         };
 

+ 83 - 0
solana/pyth2wormhole/client/src/lib.rs

@@ -7,6 +7,15 @@ use borsh::{
     BorshDeserialize,
     BorshSerialize,
 };
+use log::{
+    debug,
+    trace,
+};
+use pyth_sdk_solana::state::{
+    load_mapping_account,
+    load_price_account,
+    load_product_account,
+};
 use solana_client::nonblocking::rpc_client::RpcClient;
 use solana_program::{
     hash::Hash,
@@ -44,6 +53,11 @@ use bridge::{
     types::ConsistencyLevel,
 };
 
+use std::collections::{
+    HashMap,
+    HashSet,
+};
+
 use p2w_sdk::P2WEmitter;
 
 use pyth2wormhole::{
@@ -321,3 +335,72 @@ pub fn gen_attest_tx(
     );
     Ok(tx_signed)
 }
+
+/// Enumerates all products and their prices in a Pyth mapping.
+/// Returns map of: product address => [price addresses]
+pub async fn crawl_pyth_mapping(
+    rpc_client: &RpcClient,
+    first_mapping_addr: &Pubkey,
+) -> Result<HashMap<Pubkey, HashSet<Pubkey>>, ErrBox> {
+    let mut ret = HashMap::new();
+
+    let mut n_mappings = 1; // We assume the first one must be valid
+    let mut n_products = 0;
+    let mut n_prices = 0;
+
+    let mut mapping_addr = first_mapping_addr.clone();
+
+    // loop until the last non-zero MappingAccount.next account
+    loop {
+        let mapping_bytes = rpc_client.get_account_data(&mapping_addr).await?;
+
+        let mapping = load_mapping_account(&mapping_bytes)?;
+
+        // loop through all products in this mapping; filter out zeroed-out empty product slots
+        for prod_addr in mapping.products.iter().filter(|p| *p != &Pubkey::default()) {
+            let prod_bytes = rpc_client.get_account_data(prod_addr).await?;
+            let prod = load_product_account(&prod_bytes)?;
+
+            let mut price_addr = prod.px_acc.clone();
+
+            // loop until the last non-zero PriceAccount.next account
+            loop {
+                let price_bytes = rpc_client.get_account_data(&price_addr).await?;
+                let price = load_price_account(&price_bytes)?;
+
+                // Append to existing set or create a new map entry
+                ret.entry(prod_addr.clone())
+                    .or_insert(HashSet::new())
+                    .insert(price_addr);
+
+                n_prices += 1;
+
+                if price.next == Pubkey::default() {
+                    trace!("Product {}: processed {} prices", prod_addr, n_prices);
+                    break;
+                }
+                price_addr = price.next.clone();
+            }
+
+            n_products += 1;
+        }
+        trace!(
+            "Mapping {}: processed {} products",
+            mapping_addr,
+            n_products
+        );
+
+        // Traverse other mapping accounts if applicable
+        if mapping.next == Pubkey::default() {
+            break;
+        }
+        mapping_addr = mapping.next.clone();
+        n_mappings += 1;
+    }
+    debug!(
+        "Processed {} price(s) in {} product account(s), in {} mapping account(s)",
+        n_prices, n_products, n_mappings
+    );
+
+    Ok(ret)
+}

+ 18 - 9
solana/pyth2wormhole/client/src/main.rs

@@ -29,7 +29,7 @@ use log::{
 use solana_client::{
     client_error::ClientError,
     nonblocking::rpc_client::RpcClient,
-    rpc_config::RpcTransactionConfig
+    rpc_config::RpcTransactionConfig,
 };
 use solana_program::pubkey::Pubkey;
 use solana_sdk::{
@@ -176,6 +176,11 @@ async fn main() -> Result<(), ErrBox> {
             let attestation_cfg: AttestationConfig =
                 serde_yaml::from_reader(File::open(attestation_cfg)?)?;
 
+            if let Some(mapping_addr) = attestation_cfg.mapping_addr.as_ref() {
+                let additional_accounts = crawl_pyth_mapping(&rpc_client, mapping_addr).await?;
+                info!("Additional mapping accounts:\n{:#?}", additional_accounts);
+            }
+
             handle_attest(
                 cli.rpc_url,
                 Duration::from_millis(cli.rpc_interval_ms),
@@ -190,7 +195,7 @@ async fn main() -> Result<(), ErrBox> {
             )
             .await?;
         }
-        Action::GetEmitter => unreachable!{}
+        Action::GetEmitter => unreachable! {},
     }
 
     Ok(())
@@ -267,7 +272,10 @@ async fn handle_attest(
         rpc_interval,
     ));
 
-    let message_q_mtx = Arc::new(Mutex::new(P2WMessageQueue::new(Duration::from_millis(attestation_cfg.min_msg_reuse_interval_ms), attestation_cfg.max_msg_accounts as usize)));
+    let message_q_mtx = Arc::new(Mutex::new(P2WMessageQueue::new(
+        Duration::from_millis(attestation_cfg.min_msg_reuse_interval_ms),
+        attestation_cfg.max_msg_accounts as usize,
+    )));
 
     // Create attestation scheduling routines; see attestation_sched_job() for details
     let mut attestation_sched_futs = batches.into_iter().map(|(batch_no, batch)| {
@@ -297,7 +305,11 @@ async fn handle_attest(
     let errors: Vec<_> = results
         .iter()
         .enumerate()
-        .filter_map(|(idx, r)| r.as_ref().err().map(|e| format!("Error {}: {:#?}\n", idx + 1, e)))
+        .filter_map(|(idx, r)| {
+            r.as_ref()
+                .err()
+                .map(|e| format!("Error {}: {:#?}\n", idx + 1, e))
+        })
         .collect();
 
     if !errors.is_empty() {
@@ -417,13 +429,10 @@ async fn attestation_sched_job(
             let group_name4err_msg = batch.group_name.clone();
 
             // We never get to error reporting in daemon mode, attach a map_err
-            let job_with_err_msg = job.map_err(move |e|  {
+            let job_with_err_msg = job.map_err(move |e| {
                 warn!(
                     "Batch {}/{}, group {:?} ERR: {:#?}",
-                    batch_no4err_msg,
-                    batch_count4err_msg,
-                    group_name4err_msg,
-                    e
+                    batch_no4err_msg, batch_count4err_msg, group_name4err_msg, e
                 );
                 e
             });

+ 7 - 3
solana/pyth2wormhole/client/src/message.rs

@@ -37,7 +37,7 @@ impl P2WMessageQueue {
         Self {
             accounts: VecDeque::new(),
             grace_period,
-            max_accounts
+            max_accounts,
         }
     }
     /// Finds or creates an account with last_used at least grace_period in the past.
@@ -59,7 +59,11 @@ impl P2WMessageQueue {
 
                 // Make sure we're not going over the limit
                 if self.accounts.len() >= self.max_accounts {
-                    return Err(format!("Max message queue size of {} reached.", self.max_accounts).into());
+                    return Err(format!(
+                        "Max message queue size of {} reached.",
+                        self.max_accounts
+                    )
+                    .into());
                 }
 
                 debug!(
@@ -106,7 +110,7 @@ pub mod test {
 
         std::thread::sleep(Duration::from_millis(600));
 
-        // Account 0 should be in front, enough time passed 
+        // Account 0 should be in front, enough time passed
         let acc3 = q.get_account()?;
 
         assert_eq!(q.accounts.len(), 2);

+ 8 - 3
third_party/pyth/p2w_autoattest.py

@@ -169,20 +169,25 @@ if P2W_ATTESTATION_CFG is None:
 
     res = conn.getresponse()
 
-    pyth_accounts = None
+    publisher_state_map = {}
 
     if res.getheader("Content-Type") == "application/json":
-        pyth_accounts = json.load(res)
+        publisher_state_map = json.load(res)
     else:
         logging.error("Bad Content type")
         sys.exit(1)
 
+    pyth_accounts = publisher_state_map["symbols"]
+
     logging.info(
         f"Retrieved {len(pyth_accounts)} Pyth accounts from endpoint: {pyth_accounts}"
     )
 
-    cfg_yaml = """
+    mapping_addr = publisher_state_map["mapping_addr"]
+
+    cfg_yaml = f"""
 ---
+mapping_addr: {mapping_addr}
 symbol_groups:
   - group_name: fast_interval_only
     conditions:

+ 15 - 8
third_party/pyth/pyth_publisher.py

@@ -16,13 +16,13 @@ PYTH_TEST_SYMBOL_COUNT = int(os.environ.get("PYTH_TEST_SYMBOL_COUNT", "9"))
 
 class PythAccEndpoint(BaseHTTPRequestHandler):
     """
-    A dumb endpoint to respond with a JSON containing Pyth account addresses
+    A dumb endpoint to respond with a JSON containing Pyth symbol and mapping addresses
     """
 
     def do_GET(self):
         print(f"Got path {self.path}")
         sys.stdout.flush()
-        data = json.dumps(TEST_SYMBOLS).encode("utf-8")
+        data = json.dumps(HTTP_ENDPOINT_DATA).encode("utf-8")
         print(f"Sending:\n{data}")
 
         self.send_response(200)
@@ -32,8 +32,8 @@ class PythAccEndpoint(BaseHTTPRequestHandler):
         self.wfile.write(data)
         self.wfile.flush()
 
-
-TEST_SYMBOLS = []
+# Test publisher state that gets served via the HTTP endpoint. Note: the schema of this dict is extended here and there
+HTTP_ENDPOINT_DATA = {"symbols": [], "mapping_address": None}
 
 
 def publisher_random_update(price_pubkey):
@@ -92,13 +92,12 @@ def add_symbol(num: int):
         "price": price_pubkey
     }
 
-    TEST_SYMBOLS.append(sym)
+    HTTP_ENDPOINT_DATA["symbols"].append(sym)
 
     sys.stdout.flush()
 
     return num
 
-
 # Fund the publisher
 sol_run_or_die("airdrop", [
     str(SOL_AIRDROP_AMT),
@@ -107,7 +106,15 @@ sol_run_or_die("airdrop", [
 ])
 
 # Create a mapping
-pyth_admin_run_or_die("init_mapping")
+pyth_admin_run_or_die("init_mapping", capture_output=True)
+
+mapping_addr = sol_run_or_die("address", args=[
+    "--keypair", PYTH_MAPPING_KEYPAIR
+], capture_output=True).stdout.strip()
+
+HTTP_ENDPOINT_DATA["mapping_addr"] = mapping_addr
+
+print(f"New mapping at {mapping_addr}")
 
 print(f"Creating {PYTH_TEST_SYMBOL_COUNT} test Pyth symbols")
 
@@ -134,7 +141,7 @@ readiness_thread.start()
 http_service.start()
 
 while True:
-    for sym in TEST_SYMBOLS:
+    for sym in HTTP_ENDPOINT_DATA["symbols"]:
         publisher_random_update(sym["price"])
 
     time.sleep(PYTH_PUBLISHER_INTERVAL)

+ 3 - 0
third_party/pyth/pyth_utils.py

@@ -15,6 +15,9 @@ PYTH_PUBLISHER_KEYPAIR = os.environ.get(
     "PYTH_PUBLISHER_KEYPAIR", f"{PYTH_KEY_STORE}/publish_key_pair.json"
 )
 PYTH_PUBLISHER_INTERVAL = float(os.environ.get("PYTH_PUBLISHER_INTERVAL", "5"))
+PYTH_MAPPING_KEYPAIR = os.environ.get(
+    "PYTH_MAPPING_KEYPAIR", f"{PYTH_KEY_STORE}/mapping_key_pair.json"
+)
 
 # 0 setting disables airdropping
 SOL_AIRDROP_AMT = int(os.environ.get("SOL_AIRDROP_AMT", 0))