Browse Source

lang, ts: Fallback functions (#457)

Armani Ferrante 4 years ago
parent
commit
915e6dd398

+ 5 - 0
CHANGELOG.md

@@ -11,6 +11,11 @@ incremented for features.
 
 ## [Unreleased]
 
+### Features
+
+* lang: Add fallback functions ([#457](https://github.com/project-serum/anchor/pull/457)).
+* lang: Add feature flag for using the old state account discriminator. This is a temporary flag for those with programs built prior to v0.7.0 but want to use the latest Anchor version. Expect this to be removed in a future version ([#446](https://github.com/project-serum/anchor/pull/446)).
+
 ### Breaking Changes
 
 * cli: Remove `.spec` suffix on TypeScript tests files ([#441](https://github.com/project-serum/anchor/pull/441)).

+ 8 - 0
examples/misc/programs/misc/src/lib.rs

@@ -128,6 +128,14 @@ pub mod misc {
     pub fn test_token_seeds_init(_ctx: Context<TestTokenSeedsInit>, _nonce: u8) -> ProgramResult {
         Ok(())
     }
+
+    pub fn default<'info>(
+        _program_id: &Pubkey,
+        _accounts: &[AccountInfo<'info>],
+        _data: &[u8],
+    ) -> ProgramResult {
+        Err(ProgramError::Custom(1234))
+    }
 }
 
 #[derive(Accounts)]

+ 38 - 22
examples/misc/tests/misc.js

@@ -141,17 +141,19 @@ describe("misc", () => {
 
     // Manual associated address calculation for test only. Clients should use
     // the generated methods.
-    const [associatedAccount, nonce] =
-      await anchor.web3.PublicKey.findProgramAddress(
-        [
-          anchor.utils.bytes.utf8.encode("anchor"),
-          program.provider.wallet.publicKey.toBuffer(),
-          state.toBuffer(),
-          data.publicKey.toBuffer(),
-          anchor.utils.bytes.utf8.encode("my-seed"),
-        ],
-        program.programId
-      );
+    const [
+      associatedAccount,
+      nonce,
+    ] = await anchor.web3.PublicKey.findProgramAddress(
+      [
+        anchor.utils.bytes.utf8.encode("anchor"),
+        program.provider.wallet.publicKey.toBuffer(),
+        state.toBuffer(),
+        data.publicKey.toBuffer(),
+        anchor.utils.bytes.utf8.encode("my-seed"),
+      ],
+      program.programId
+    );
     await assert.rejects(
       async () => {
         await program.account.testData.fetch(associatedAccount);
@@ -186,17 +188,19 @@ describe("misc", () => {
 
   it("Can use an associated program account", async () => {
     const state = await program.state.address();
-    const [associatedAccount, nonce] =
-      await anchor.web3.PublicKey.findProgramAddress(
-        [
-          anchor.utils.bytes.utf8.encode("anchor"),
-          program.provider.wallet.publicKey.toBuffer(),
-          state.toBuffer(),
-          data.publicKey.toBuffer(),
-          anchor.utils.bytes.utf8.encode("my-seed"),
-        ],
-        program.programId
-      );
+    const [
+      associatedAccount,
+      nonce,
+    ] = await anchor.web3.PublicKey.findProgramAddress(
+      [
+        anchor.utils.bytes.utf8.encode("anchor"),
+        program.provider.wallet.publicKey.toBuffer(),
+        state.toBuffer(),
+        data.publicKey.toBuffer(),
+        anchor.utils.bytes.utf8.encode("my-seed"),
+      ],
+      program.programId
+    );
     await program.rpc.testAssociatedAccount(new anchor.BN(5), {
       accounts: {
         myAccount: associatedAccount,
@@ -435,4 +439,16 @@ describe("misc", () => {
     assert.ok(account.owner.equals(program.provider.wallet.publicKey));
     assert.ok(account.mint.equals(mint.publicKey));
   });
+
+  it("Can execute a fallback function", async () => {
+    await assert.rejects(
+      async () => {
+        await anchor.utils.rpc.invoke(program.programId);
+      },
+      (err) => {
+        assert.ok(err.toString().includes("custom program error: 0x4d2"));
+        return true;
+      }
+    );
+  });
 });

+ 15 - 3
lang/syn/src/codegen/program/dispatch.rs

@@ -113,7 +113,9 @@ pub fn generate(program: &Program) -> proc_macro2::TokenStream {
             }
         })
         .collect();
-
+    let fallback_fn = gen_fallback(program).unwrap_or(quote! {
+        Err(anchor_lang::__private::ErrorCode::InstructionFallbackNotFound.into())
+    });
     quote! {
         /// Performs method dispatch.
         ///
@@ -152,10 +154,20 @@ pub fn generate(program: &Program) -> proc_macro2::TokenStream {
                 #(#trait_dispatch_arms)*
                 #(#global_dispatch_arms)*
                 _ => {
-                    msg!("Fallback functions are not supported. If you have a use case, please file an issue.");
-                    Err(anchor_lang::__private::ErrorCode::InstructionFallbackNotFound.into())
+                    #fallback_fn
                 }
             }
         }
     }
 }
+
+pub fn gen_fallback(program: &Program) -> Option<proc_macro2::TokenStream> {
+    program.fallback_fn.as_ref().map(|fallback_fn| {
+        let program_name = &program.name;
+        let method = &fallback_fn.raw_method;
+        let fn_name = &method.sig.ident;
+        quote! {
+            #program_name::#fn_name(program_id, accounts, ix_data)
+        }
+    })
+}

+ 6 - 2
lang/syn/src/codegen/program/entry.rs

@@ -1,7 +1,11 @@
+use crate::program_codegen::dispatch;
 use crate::Program;
 use quote::quote;
 
-pub fn generate(_program: &Program) -> proc_macro2::TokenStream {
+pub fn generate(program: &Program) -> proc_macro2::TokenStream {
+    let fallback_maybe = dispatch::gen_fallback(program).unwrap_or(quote! {
+        Err(anchor_lang::__private::ErrorCode::InstructionMissing.into());
+    });
     quote! {
         #[cfg(not(feature = "no-entrypoint"))]
         anchor_lang::solana_program::entrypoint!(entry);
@@ -52,7 +56,7 @@ pub fn generate(_program: &Program) -> proc_macro2::TokenStream {
                 msg!("anchor-debug is active");
             }
             if ix_data.len() < 8 {
-                return Err(anchor_lang::__private::ErrorCode::InstructionMissing.into());
+                return #fallback_maybe
             }
 
             // Split the instruction data into the first 8 byte method

+ 6 - 0
lang/syn/src/lib.rs

@@ -30,6 +30,7 @@ pub struct Program {
     pub ixs: Vec<Ix>,
     pub name: Ident,
     pub program_mod: ItemMod,
+    pub fallback_fn: Option<FallbackFn>,
 }
 
 impl Parse for Program {
@@ -92,6 +93,11 @@ pub struct IxArg {
     pub raw_arg: PatType,
 }
 
+#[derive(Debug)]
+pub struct FallbackFn {
+    raw_method: ItemFn,
+}
+
 #[derive(Debug)]
 pub struct AccountsStruct {
     // Name of the accounts struct.

+ 38 - 5
lang/syn/src/parser/program/instructions.rs

@@ -1,20 +1,24 @@
 use crate::parser::program::ctx_accounts_ident;
-use crate::{Ix, IxArg};
+use crate::{FallbackFn, Ix, IxArg};
 use syn::parse::{Error as ParseError, Result as ParseResult};
 use syn::spanned::Spanned;
 
 // Parse all non-state ix handlers from the program mod definition.
-pub fn parse(program_mod: &syn::ItemMod) -> ParseResult<Vec<Ix>> {
+pub fn parse(program_mod: &syn::ItemMod) -> ParseResult<(Vec<Ix>, Option<FallbackFn>)> {
     let mod_content = &program_mod
         .content
         .as_ref()
         .ok_or_else(|| ParseError::new(program_mod.span(), "program content not provided"))?
         .1;
 
-    mod_content
+    let ixs = mod_content
         .iter()
         .filter_map(|item| match item {
-            syn::Item::Fn(item_fn) => Some(item_fn),
+            syn::Item::Fn(item_fn) => {
+                let (ctx, _) = parse_args(item_fn).ok()?;
+                ctx_accounts_ident(&ctx.raw_arg).ok()?;
+                Some(item_fn)
+            }
             _ => None,
         })
         .map(|method: &syn::ItemFn| {
@@ -27,7 +31,36 @@ pub fn parse(program_mod: &syn::ItemMod) -> ParseResult<Vec<Ix>> {
                 anchor_ident,
             })
         })
-        .collect::<ParseResult<Vec<Ix>>>()
+        .collect::<ParseResult<Vec<Ix>>>()?;
+
+    let fallback_fn = {
+        let fallback_fns = mod_content
+            .iter()
+            .filter_map(|item| match item {
+                syn::Item::Fn(item_fn) => {
+                    let (ctx, _args) = parse_args(item_fn).ok()?;
+                    if ctx_accounts_ident(&ctx.raw_arg).is_ok() {
+                        return None;
+                    }
+                    Some(item_fn)
+                }
+                _ => None,
+            })
+            .collect::<Vec<_>>();
+        if fallback_fns.len() > 1 {
+            return Err(ParseError::new(
+                fallback_fns[0].span(),
+                "More than one fallback function found",
+            ));
+        }
+        fallback_fns
+            .first()
+            .map(|method: &&syn::ItemFn| FallbackFn {
+                raw_method: (*method).clone(),
+            })
+    };
+
+    Ok((ixs, fallback_fn))
 }
 
 pub fn parse_args(method: &syn::ItemFn) -> ParseResult<(IxArg, Vec<IxArg>)> {

+ 2 - 2
lang/syn/src/parser/program/mod.rs

@@ -7,13 +7,13 @@ mod state;
 
 pub fn parse(program_mod: syn::ItemMod) -> ParseResult<Program> {
     let state = state::parse(&program_mod)?;
-    let ixs = instructions::parse(&program_mod)?;
-
+    let (ixs, fallback_fn) = instructions::parse(&program_mod)?;
     Ok(Program {
         state,
         ixs,
         name: program_mod.ident.clone(),
         program_mod,
+        fallback_fn,
     })
 }
 

+ 38 - 1
ts/src/utils/rpc.ts

@@ -1,5 +1,42 @@
 import assert from "assert";
-import { PublicKey, AccountInfo, Connection } from "@solana/web3.js";
+import {
+  AccountInfo,
+  AccountMeta,
+  Connection,
+  PublicKey,
+  TransactionSignature,
+  Transaction,
+  TransactionInstruction,
+} from "@solana/web3.js";
+import { Address, translateAddress } from "../program/common";
+import Provider, { getProvider } from "../provider";
+
+/**
+ * Sends a transaction to a program with the given accounts and instruction
+ * data.
+ */
+export async function invoke(
+  programId: Address,
+  accounts?: Array<AccountMeta>,
+  data?: Buffer,
+  provider?: Provider
+): Promise<TransactionSignature> {
+  programId = translateAddress(programId);
+  if (!provider) {
+    provider = getProvider();
+  }
+
+  const tx = new Transaction();
+  tx.add(
+    new TransactionInstruction({
+      programId,
+      keys: accounts ?? [],
+      data,
+    })
+  );
+
+  return await provider.send(tx);
+}
 
 export async function getMultipleAccounts(
   connection: Connection,