Browse Source

lang: Add `Account` utility type to get accounts from bytes (#3091)

acheron 1 year ago
parent
commit
117717468f

+ 1 - 0
CHANGELOG.md

@@ -16,6 +16,7 @@ The minor version will be incremented upon a breaking change and the patch versi
 - cli, idl: Pass `cargo` args to IDL generation when building program or IDL ([#3059](https://github.com/coral-xyz/anchor/pull/3059)).
 - cli: Add checks for incorrect usage of `idl-build` feature ([#3061](https://github.com/coral-xyz/anchor/pull/3061)).
 - lang: Export `Discriminator` trait from `prelude` ([#3075](https://github.com/coral-xyz/anchor/pull/3075)).
+- lang: Add `Account` utility type to get accounts from bytes ([#3091](https://github.com/coral-xyz/anchor/pull/3091)).
 
 ### Fixes
 

+ 56 - 2
lang/attribute/program/src/declare_program/mods/utils.rs

@@ -4,16 +4,72 @@ use quote::{format_ident, quote};
 use super::common::gen_discriminator;
 
 pub fn gen_utils_mod(idl: &Idl) -> proc_macro2::TokenStream {
+    let account = gen_account(idl);
     let event = gen_event(idl);
 
     quote! {
         /// Program utilities.
         pub mod utils {
+            use super::*;
+
+            #account
             #event
         }
     }
 }
 
+fn gen_account(idl: &Idl) -> proc_macro2::TokenStream {
+    let variants = idl
+        .accounts
+        .iter()
+        .map(|acc| format_ident!("{}", acc.name))
+        .map(|name| quote! { #name(#name) });
+    let match_arms = idl.accounts.iter().map(|acc| {
+        let disc = gen_discriminator(&acc.discriminator);
+        let name = format_ident!("{}", acc.name);
+        let account = quote! {
+            #name::try_from_slice(&value[8..])
+                .map(Self::#name)
+                .map_err(Into::into)
+        };
+        quote! { #disc => #account }
+    });
+
+    quote! {
+        /// An enum that includes all accounts of the declared program as a tuple variant.
+        ///
+        /// See [`Self::try_from_bytes`] to create an instance from bytes.
+        pub enum Account {
+            #(#variants,)*
+        }
+
+        impl Account {
+            /// Try to create an account based on the given bytes.
+            ///
+            /// This method returns an error if the discriminator of the given bytes don't match
+            /// with any of the existing accounts, or if the deserialization fails.
+            pub fn try_from_bytes(bytes: &[u8]) -> Result<Self> {
+                Self::try_from(bytes)
+            }
+        }
+
+        impl TryFrom<&[u8]> for Account {
+            type Error = anchor_lang::error::Error;
+
+            fn try_from(value: &[u8]) -> Result<Self> {
+                if value.len() < 8 {
+                    return Err(ProgramError::InvalidArgument.into());
+                }
+
+                match &value[..8] {
+                    #(#match_arms,)*
+                    _ => Err(ProgramError::InvalidArgument.into()),
+                }
+            }
+        }
+    }
+}
+
 fn gen_event(idl: &Idl) -> proc_macro2::TokenStream {
     let variants = idl
         .events
@@ -32,8 +88,6 @@ fn gen_event(idl: &Idl) -> proc_macro2::TokenStream {
     });
 
     quote! {
-        use super::*;
-
         /// An enum that includes all events of the declared program as a tuple variant.
         ///
         /// See [`Self::try_from_bytes`] to create an instance from bytes.

+ 24 - 0
tests/declare-program/programs/declare-program/src/lib.rs

@@ -52,6 +52,30 @@ pub mod declare_program {
         Ok(())
     }
 
+    pub fn account_utils(_ctx: Context<Utils>) -> Result<()> {
+        use external::utils::Account;
+
+        // Empty
+        if Account::try_from_bytes(&[]).is_ok() {
+            return Err(ProgramError::Custom(0).into());
+        }
+
+        const DISC: &[u8] = &external::accounts::MyAccount::DISCRIMINATOR;
+
+        // Correct discriminator but invalid data
+        if Account::try_from_bytes(DISC).is_ok() {
+            return Err(ProgramError::Custom(1).into());
+        };
+
+        // Correct discriminator and valid data
+        match Account::try_from_bytes(&[DISC, &[1, 0, 0, 0]].concat()) {
+            Ok(Account::MyAccount(my_account)) => require_eq!(my_account.field, 1),
+            Err(e) => return Err(e.into()),
+        }
+
+        Ok(())
+    }
+
     pub fn event_utils(_ctx: Context<Utils>) -> Result<()> {
         use external::utils::Event;
 

+ 4 - 0
tests/declare-program/tests/declare-program.ts

@@ -47,6 +47,10 @@ describe("declare-program", () => {
     assert.strictEqual(myAccount.field, value);
   });
 
+  it("Can use account utils", async () => {
+    await program.methods.accountUtils().rpc();
+  });
+
   it("Can use event utils", async () => {
     await program.methods.eventUtils().rpc();
   });