Browse Source

feat: constants declared in impl blocks in seeds (#2128)

Sammy Harris 3 years ago
parent
commit
290b2aa43e
4 changed files with 92 additions and 14 deletions
  1. 1 0
      CHANGELOG.md
  2. 1 1
      lang/syn/src/idl/mod.rs
  3. 53 13
      lang/syn/src/idl/pda.rs
  4. 37 0
      lang/syn/src/parser/context.rs

+ 1 - 0
CHANGELOG.md

@@ -15,6 +15,7 @@ The minor version will be incremented upon a breaking change and the patch versi
 * client: Add `transaction` functions to RequestBuilder ([#1958](https://github.com/coral-xyz/anchor/pull/1958)).
 * spl: Add `create_metadata_accounts_v3` and `set_collection_size` wrappers ([#2119](https://github.com/coral-xyz/anchor/pull/2119))
 * spl: Add `MetadataAccount` account deserialization. ([#2014](https://github.com/coral-xyz/anchor/pull/2014)).
+* lang: Add parsing for consts from impl blocks for IDL PDA seeds generation ([#2128](https://github.com/coral-xyz/anchor/pull/2014))
 
 ### Fixes
 

+ 1 - 1
lang/syn/src/idl/mod.rs

@@ -234,7 +234,7 @@ impl std::str::FromStr for IdlType {
             "u128" => IdlType::U128,
             "i128" => IdlType::I128,
             "Vec<u8>" => IdlType::Bytes,
-            "String" | "&str" => IdlType::String,
+            "String" | "&str" | "&'staticstr" => IdlType::String,
             "Pubkey" => IdlType::PublicKey,
             _ => match s.to_string().strip_prefix("Option<") {
                 None => match s.to_string().strip_prefix("Vec<") {

+ 53 - 13
lang/syn/src/idl/pda.rs

@@ -56,6 +56,8 @@ struct PdaParser<'a> {
     ix_args: HashMap<String, String>,
     // Constants available in the crate.
     const_names: Vec<String>,
+    // Constants declared in impl blocks available in the crate
+    impl_const_names: Vec<String>,
     // All field names of the accounts in the accounts context.
     account_field_names: Vec<String>,
 }
@@ -65,6 +67,12 @@ impl<'a> PdaParser<'a> {
         // All the available sources of seeds.
         let ix_args = accounts.instruction_args().unwrap_or_default();
         let const_names: Vec<String> = ctx.consts().map(|c| c.ident.to_string()).collect();
+
+        let impl_const_names: Vec<String> = ctx
+            .impl_consts()
+            .map(|(ident, item)| format!("{} :: {}", ident, item.ident))
+            .collect();
+
         let account_field_names = accounts.field_names();
 
         Self {
@@ -72,6 +80,7 @@ impl<'a> PdaParser<'a> {
             accounts,
             ix_args,
             const_names,
+            impl_const_names,
             account_field_names,
         }
     }
@@ -83,7 +92,6 @@ impl<'a> PdaParser<'a> {
             .iter()
             .map(|s| self.parse_seed(s))
             .collect::<Option<Vec<_>>>()?;
-
         // Parse the program id from the constraints.
         let program_id = seeds_grp
             .program_seed
@@ -104,6 +112,8 @@ impl<'a> PdaParser<'a> {
                     self.parse_instruction(&seed_path)
                 } else if self.is_const(&seed_path) {
                     self.parse_const(&seed_path)
+                } else if self.is_impl_const(&seed_path) {
+                    self.parse_impl_const(&seed_path)
                 } else if self.is_account(&seed_path) {
                     self.parse_account(&seed_path)
                 } else if self.is_str_literal(&seed_path) {
@@ -150,18 +160,30 @@ impl<'a> PdaParser<'a> {
             .find(|c| c.ident == seed_path.name())
             .unwrap();
         let idl_ty = IdlType::from_str(&parser::tts_to_string(&const_item.ty)).ok()?;
-        let mut idl_ty_value = parser::tts_to_string(&const_item.expr);
-
-        if let IdlType::Array(_ty, _size) = &idl_ty {
-            // Convert str literal to array.
-            if idl_ty_value.contains("b\"") {
-                let components: Vec<&str> = idl_ty_value.split('b').collect();
-                assert!(components.len() == 2);
-                let mut str_lit = components[1].to_string();
-                str_lit.retain(|c| c != '"');
-                idl_ty_value = format!("{:?}", str_lit.as_bytes());
-            }
-        }
+
+        let idl_ty_value = parser::tts_to_string(&const_item.expr);
+        let idl_ty_value = str_lit_to_array(&idl_ty, &idl_ty_value);
+
+        Some(IdlSeed::Const(IdlSeedConst {
+            ty: idl_ty,
+            value: serde_json::from_str(&idl_ty_value).unwrap(),
+        }))
+    }
+
+    fn parse_impl_const(&self, seed_path: &SeedPath) -> Option<IdlSeed> {
+        // Pull in the constant value directly into the IDL.
+        assert!(seed_path.components().is_empty());
+        let static_item = self
+            .ctx
+            .impl_consts()
+            .find(|(ident, item)| format!("{} :: {}", ident, item.ident) == seed_path.name())
+            .unwrap()
+            .1;
+
+        let idl_ty = IdlType::from_str(&parser::tts_to_string(&static_item.ty)).ok()?;
+
+        let idl_ty_value = parser::tts_to_string(&static_item.expr);
+        let idl_ty_value = str_lit_to_array(&idl_ty, &idl_ty_value);
 
         Some(IdlSeed::Const(IdlSeedConst {
             ty: idl_ty,
@@ -236,6 +258,10 @@ impl<'a> PdaParser<'a> {
         self.const_names.contains(&seed_path.name())
     }
 
+    fn is_impl_const(&self, seed_path: &SeedPath) -> bool {
+        self.impl_const_names.contains(&seed_path.name())
+    }
+
     fn is_account(&self, seed_path: &SeedPath) -> bool {
         self.account_field_names.contains(&seed_path.name())
     }
@@ -327,3 +353,17 @@ fn parse_field_path(ctx: &CrateContext, strct: &syn::ItemStruct, path: &mut &[St
 
     parse_field_path(ctx, strct, path)
 }
+
+fn str_lit_to_array(idl_ty: &IdlType, idl_ty_value: &String) -> String {
+    if let IdlType::Array(_ty, _size) = &idl_ty {
+        // Convert str literal to array.
+        if idl_ty_value.contains("b\"") {
+            let components: Vec<&str> = idl_ty_value.split('b').collect();
+            assert_eq!(components.len(), 2);
+            let mut str_lit = components[1].to_string();
+            str_lit.retain(|c| c != '"');
+            return format!("{:?}", str_lit.as_bytes());
+        }
+    }
+    idl_ty_value.to_string()
+}

+ 37 - 0
lang/syn/src/parser/context.rs

@@ -2,6 +2,7 @@ use anyhow::anyhow;
 use std::collections::BTreeMap;
 use std::path::{Path, PathBuf};
 use syn::parse::{Error as ParseError, Result as ParseResult};
+use syn::{Ident, ImplItem, ImplItemConst, Type, TypePath};
 
 /// Crate parse context
 ///
@@ -15,6 +16,10 @@ impl CrateContext {
         self.modules.iter().flat_map(|(_, ctx)| ctx.consts())
     }
 
+    pub fn impl_consts(&self) -> impl Iterator<Item = (&Ident, &syn::ImplItemConst)> {
+        self.modules.iter().flat_map(|(_, ctx)| ctx.impl_consts())
+    }
+
     pub fn structs(&self) -> impl Iterator<Item = &syn::ItemStruct> {
         self.modules.iter().flat_map(|(_, ctx)| ctx.structs())
     }
@@ -244,4 +249,36 @@ impl ParsedModule {
             _ => None,
         })
     }
+
+    fn impl_consts(&self) -> impl Iterator<Item = (&Ident, &ImplItemConst)> {
+        self.items
+            .iter()
+            .filter_map(|i| match i {
+                syn::Item::Impl(syn::ItemImpl {
+                    self_ty: ty, items, ..
+                }) => {
+                    if let Type::Path(TypePath {
+                        qself: None,
+                        path: p,
+                    }) = ty.as_ref()
+                    {
+                        if let Some(ident) = p.get_ident() {
+                            let mut to_return = Vec::new();
+                            items.iter().for_each(|item| {
+                                if let ImplItem::Const(item) = item {
+                                    to_return.push((ident, item));
+                                }
+                            });
+                            Some(to_return)
+                        } else {
+                            None
+                        }
+                    } else {
+                        None
+                    }
+                }
+                _ => None,
+            })
+            .flatten()
+    }
 }