Parcourir la source

lang: Add constant attribute (#956)

Tom Linton il y a 3 ans
Parent
commit
6bc59350de

+ 10 - 0
Cargo.lock

@@ -79,6 +79,15 @@ dependencies = [
  "syn 1.0.81",
 ]
 
+[[package]]
+name = "anchor-attribute-constant"
+version = "0.18.0"
+dependencies = [
+ "anchor-syn",
+ "proc-macro2 1.0.29",
+ "syn 1.0.75",
+]
+
 [[package]]
 name = "anchor-attribute-error"
 version = "0.18.2"
@@ -196,6 +205,7 @@ version = "0.18.2"
 dependencies = [
  "anchor-attribute-access-control",
  "anchor-attribute-account",
+ "anchor-attribute-constant",
  "anchor-attribute-error",
  "anchor-attribute-event",
  "anchor-attribute-interface",

+ 2 - 0
lang/Cargo.toml

@@ -13,6 +13,7 @@ default = []
 anchor-debug = [
     "anchor-attribute-access-control/anchor-debug",
     "anchor-attribute-account/anchor-debug",
+    "anchor-attribute-constant/anchor-debug",
     "anchor-attribute-error/anchor-debug",
     "anchor-attribute-event/anchor-debug",
     "anchor-attribute-interface/anchor-debug",
@@ -25,6 +26,7 @@ anchor-debug = [
 [dependencies]
 anchor-attribute-access-control = { path = "./attribute/access-control", version = "0.18.2" }
 anchor-attribute-account = { path = "./attribute/account", version = "0.18.2" }
+anchor-attribute-constant = { path = "./attribute/constant", version = "0.18.2" }
 anchor-attribute-error = { path = "./attribute/error", version = "0.18.2" }
 anchor-attribute-program = { path = "./attribute/program", version = "0.18.2" }
 anchor-attribute-state = { path = "./attribute/state", version = "0.18.2" }

+ 19 - 0
lang/attribute/constant/Cargo.toml

@@ -0,0 +1,19 @@
+[package]
+name = "anchor-attribute-constant"
+version = "0.18.2"
+authors = ["Serum Foundation <foundation@projectserum.com>"]
+repository = "https://github.com/project-serum/anchor"
+license = "Apache-2.0"
+description = "Anchor attribute macro for creating constant types"
+edition = "2018"
+
+[lib]
+proc-macro = true
+
+[features]
+anchor-debug = ["anchor-syn/anchor-debug"]
+
+[dependencies]
+proc-macro2 = "1.0"
+syn = { version = "1.0.60", features = ["full"] }
+anchor-syn = { path = "../../syn", version = "0.18.2" }

+ 11 - 0
lang/attribute/constant/src/lib.rs

@@ -0,0 +1,11 @@
+extern crate proc_macro;
+
+/// A marker attribute used to mark const values that should be included in the
+/// generated IDL but functionally does nothing.
+#[proc_macro_attribute]
+pub fn constant(
+    _attr: proc_macro::TokenStream,
+    input: proc_macro::TokenStream,
+) -> proc_macro::TokenStream {
+    input
+}

+ 5 - 4
lang/src/lib.rs

@@ -83,6 +83,7 @@ pub use crate::sysvar::Sysvar;
 pub use crate::unchecked_account::UncheckedAccount;
 pub use anchor_attribute_access_control::access_control;
 pub use anchor_attribute_account::{account, declare_id, zero_copy};
+pub use anchor_attribute_constant::constant;
 pub use anchor_attribute_error::error;
 pub use anchor_attribute_event::{emit, event};
 pub use anchor_attribute_interface::interface;
@@ -250,10 +251,10 @@ impl Key for Pubkey {
 /// All programs should include it via `anchor_lang::prelude::*;`.
 pub mod prelude {
     pub use super::{
-        access_control, account, declare_id, emit, error, event, interface, program, require,
-        state, zero_copy, Account, AccountDeserialize, AccountLoader, AccountSerialize, Accounts,
-        AccountsExit, AnchorDeserialize, AnchorSerialize, Context, CpiContext, Id, Key, Owner,
-        Program, Signer, System, SystemAccount, Sysvar, ToAccountInfo, ToAccountInfos,
+        access_control, account, constant, declare_id, emit, error, event, interface, program,
+        require, state, zero_copy, Account, AccountDeserialize, AccountLoader, AccountSerialize,
+        Accounts, AccountsExit, AnchorDeserialize, AnchorSerialize, Context, CpiContext, Id, Key,
+        Owner, Program, Signer, System, SystemAccount, Sysvar, ToAccountInfo, ToAccountInfos,
         ToAccountMetas, UncheckedAccount,
     };
 

+ 23 - 0
lang/syn/src/idl/file.rs

@@ -223,6 +223,15 @@ pub fn parse(filename: impl AsRef<Path>, version: String) -> Result<Option<Idl>>
         }
     }
 
+    let constants = parse_consts(&ctx)
+        .iter()
+        .map(|c: &&syn::ItemConst| IdlConst {
+            name: c.ident.to_string(),
+            ty: c.ty.to_token_stream().to_string().parse().unwrap(),
+            value: c.expr.to_token_stream().to_string().parse().unwrap(),
+        })
+        .collect::<Vec<IdlConst>>();
+
     Ok(Some(Idl {
         version,
         name: p.name.to_string(),
@@ -237,6 +246,7 @@ pub fn parse(filename: impl AsRef<Path>, version: String) -> Result<Option<Idl>>
         },
         errors: error_codes,
         metadata: None,
+        constants,
     }))
 }
 
@@ -344,6 +354,19 @@ fn parse_account_derives(ctx: &CrateContext) -> HashMap<String, AccountsStruct>
         .collect()
 }
 
+fn parse_consts(ctx: &CrateContext) -> Vec<&syn::ItemConst> {
+    ctx.consts()
+        .filter(|item_strct| {
+            for attr in &item_strct.attrs {
+                if attr.path.segments.last().unwrap().ident == "constant" {
+                    return true;
+                }
+            }
+            false
+        })
+        .collect()
+}
+
 // Parse all user defined types in the file.
 fn parse_ty_defs(ctx: &CrateContext) -> Result<Vec<IdlTypeDefinition>> {
     ctx.structs()

+ 9 - 0
lang/syn/src/idl/mod.rs

@@ -7,6 +7,8 @@ pub mod file;
 pub struct Idl {
     pub version: String,
     pub name: String,
+    #[serde(skip_serializing_if = "Vec::is_empty", default)]
+    pub constants: Vec<IdlConst>,
     pub instructions: Vec<IdlInstruction>,
     #[serde(skip_serializing_if = "Option::is_none", default)]
     pub state: Option<IdlState>,
@@ -22,6 +24,13 @@ pub struct Idl {
     pub metadata: Option<JsonValue>,
 }
 
+#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
+pub struct IdlConst {
+    pub name: String,
+    pub ty: IdlType,
+    pub value: String,
+}
+
 #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
 pub struct IdlState {
     #[serde(rename = "struct")]

+ 11 - 0
lang/syn/src/parser/context.rs

@@ -11,6 +11,10 @@ pub struct CrateContext {
 }
 
 impl CrateContext {
+    pub fn consts(&self) -> impl Iterator<Item = &syn::ItemConst> {
+        self.modules.iter().flat_map(|(_, ctx)| ctx.consts())
+    }
+
     pub fn structs(&self) -> impl Iterator<Item = &syn::ItemStruct> {
         self.modules.iter().flat_map(|(_, ctx)| ctx.structs())
     }
@@ -183,4 +187,11 @@ impl ParsedModule {
             _ => None,
         })
     }
+
+    fn consts(&self) -> impl Iterator<Item = &syn::ItemConst> {
+        self.items.iter().filter_map(|i| match i {
+            syn::Item::Const(item) => Some(item),
+            _ => None,
+        })
+    }
 }

+ 6 - 0
tests/misc/programs/misc/src/lib.rs

@@ -12,6 +12,12 @@ mod event;
 
 declare_id!("Fg6PaFpoGXkYsidMpWTK6W2BeZ7FEfcYkg476zPFsLnS");
 
+#[constant]
+pub const BASE: u128 = 1_000_000;
+#[constant]
+pub const DECIMALS: u8 = 6;
+pub const NO_IDL: u16 = 55;
+
 #[program]
 pub mod misc {
     use super::*;

+ 23 - 0
tests/misc/tests/misc.js

@@ -836,6 +836,29 @@ describe("misc", () => {
     assert.ok(account.data, 3);
   });
 
+  it("Should include BASE const in IDL", async () => {
+    assert(
+      miscIdl.constants.find(
+        (c) => c.name === "BASE" && c.ty === "u128" && c.value === "1_000_000"
+      ) !== undefined
+    );
+  });
+
+  it("Should include DECIMALS const in IDL", async () => {
+    assert(
+      miscIdl.constants.find(
+        (c) => c.name === "DECIMALS" && c.ty === "u8" && c.value === "6"
+      ) !== undefined
+    );
+  });
+
+  it("Should not include NO_IDL const in IDL", async () => {
+    assert.equal(
+      miscIdl.constants.find((c) => c.name === "NO_IDL"),
+      undefined
+    );
+  });
+  
   it("Can use multidimensional array", async () => {
     const array2d = new Array(10).fill(new Array(10).fill(99));
     const data = anchor.web3.Keypair.generate();

+ 7 - 0
ts/src/idl.ts

@@ -11,6 +11,13 @@ export type Idl = {
   types?: IdlTypeDef[];
   events?: IdlEvent[];
   errors?: IdlErrorCode[];
+  constants?: IdlConstant[];
+};
+
+export type IdlConstant = {
+  name: string;
+  type: IdlType;
+  value: string;
 };
 
 export type IdlEvent = {