Browse Source

lang: Add init_with_needed keyword (#906)

Armani Ferrante 4 years ago
parent
commit
95bb9b3183

+ 1 - 0
CHANGELOG.md

@@ -18,6 +18,7 @@ incremented for features.
 * ts: `Program<T>` can now be typed with an IDL type ([#795](https://github.com/project-serum/anchor/pull/795)).
 * lang: Add `mint::freeze_authority` keyword for mint initialization within `#[derive(Accounts)]` ([#835](https://github.com/project-serum/anchor/pull/835)).
 * lang: Add `AccountLoader` type for `zero_copy` accounts with support for CPI ([#792](https://github.com/project-serum/anchor/pull/792)).
+* lang: Add `#[account(init_if_needed)]` keyword for allowing one to invoke the same instruction even if the account was created already ([#906](https://github.com/project-serum/anchor/pull/906)).
 * lang: Add custom errors support for raw constraints ([#905](https://github.com/project-serum/anchor/pull/905)).
 
 ### Breaking

+ 64 - 48
lang/syn/src/codegen/accounts/constraints.rs

@@ -316,7 +316,7 @@ fn generate_constraint_init_group(f: &Field, c: &ConstraintInitGroup) -> proc_ma
             }
         }
     };
-    generate_init(f, seeds_with_nonce, payer, &c.space, &c.kind)
+    generate_init(f, c.if_needed, seeds_with_nonce, payer, &c.space, &c.kind)
 }
 
 fn generate_constraint_seeds(f: &Field, c: &ConstraintSeedsGroup) -> proc_macro2::TokenStream {
@@ -397,8 +397,10 @@ fn generate_constraint_associated_token(
     }
 }
 
+// `if_needed` is set if account allocation and initialization is optional.
 pub fn generate_init(
     f: &Field,
+    if_needed: bool,
     seeds_with_nonce: proc_macro2::TokenStream,
     payer: proc_macro2::TokenStream,
     space: &Option<Expr>,
@@ -407,6 +409,11 @@ pub fn generate_init(
     let field = &f.ident;
     let ty_decl = f.ty_decl();
     let from_account_info = f.from_account_info_unchecked(Some(kind));
+    let if_needed = if if_needed {
+        quote! {true}
+    } else {
+        quote! {false}
+    };
     match kind {
         InitKind::Token { owner, mint } => {
             let create_account = generate_create_account(
@@ -417,22 +424,25 @@ pub fn generate_init(
             );
             quote! {
                 let #field: #ty_decl = {
-                    // Define payer variable.
-                    #payer
-
-                    // Create the account with the system program.
-                    #create_account
-
-                    // Initialize the token account.
-                    let cpi_program = token_program.to_account_info();
-                    let accounts = anchor_spl::token::InitializeAccount {
-                        account: #field.to_account_info(),
-                        mint: #mint.to_account_info(),
-                        authority: #owner.to_account_info(),
-                        rent: rent.to_account_info(),
-                    };
-                    let cpi_ctx = CpiContext::new(cpi_program, accounts);
-                    anchor_spl::token::initialize_account(cpi_ctx)?;
+                    if !#if_needed || #field.to_account_info().owner == &anchor_lang::solana_program::system_program::ID {
+                        // Define payer variable.
+                        #payer
+
+                        // Create the account with the system program.
+                        #create_account
+
+                        // Initialize the token account.
+                        let cpi_program = token_program.to_account_info();
+                        let accounts = anchor_spl::token::InitializeAccount {
+                            account: #field.to_account_info(),
+                            mint: #mint.to_account_info(),
+                            authority: #owner.to_account_info(),
+                            rent: rent.to_account_info(),
+                        };
+                        let cpi_ctx = CpiContext::new(cpi_program, accounts);
+                        anchor_spl::token::initialize_account(cpi_ctx)?;
+                    }
+
                     let pa: #ty_decl = #from_account_info;
                     pa
                 };
@@ -441,20 +451,22 @@ pub fn generate_init(
         InitKind::AssociatedToken { owner, mint } => {
             quote! {
                 let #field: #ty_decl = {
-                    #payer
-
-                    let cpi_program = associated_token_program.to_account_info();
-                    let cpi_accounts = anchor_spl::associated_token::Create {
-                        payer: payer.to_account_info(),
-                        associated_token: #field.to_account_info(),
-                        authority: #owner.to_account_info(),
-                        mint: #mint.to_account_info(),
-                        system_program: system_program.to_account_info(),
-                        token_program: token_program.to_account_info(),
-                        rent: rent.to_account_info(),
-                    };
-                    let cpi_ctx = CpiContext::new(cpi_program, cpi_accounts);
-                    anchor_spl::associated_token::create(cpi_ctx)?;
+                    if !#if_needed || #field.to_account_info().owner == &anchor_lang::solana_program::system_program::ID {
+                        #payer
+
+                        let cpi_program = associated_token_program.to_account_info();
+                        let cpi_accounts = anchor_spl::associated_token::Create {
+                            payer: payer.to_account_info(),
+                            associated_token: #field.to_account_info(),
+                            authority: #owner.to_account_info(),
+                            mint: #mint.to_account_info(),
+                            system_program: system_program.to_account_info(),
+                            token_program: token_program.to_account_info(),
+                            rent: rent.to_account_info(),
+                        };
+                        let cpi_ctx = CpiContext::new(cpi_program, cpi_accounts);
+                        anchor_spl::associated_token::create(cpi_ctx)?;
+                    }
                     let pa: #ty_decl = #from_account_info;
                     pa
                 };
@@ -477,20 +489,22 @@ pub fn generate_init(
             };
             quote! {
                 let #field: #ty_decl = {
-                    // Define payer variable.
-                    #payer
-
-                    // Create the account with the system program.
-                    #create_account
-
-                    // Initialize the mint account.
-                    let cpi_program = token_program.to_account_info();
-                    let accounts = anchor_spl::token::InitializeMint {
-                        mint: #field.to_account_info(),
-                        rent: rent.to_account_info(),
-                    };
-                    let cpi_ctx = CpiContext::new(cpi_program, accounts);
-                    anchor_spl::token::initialize_mint(cpi_ctx, #decimals, &#owner.to_account_info().key, #freeze_authority)?;
+                    if !#if_needed || #field.to_account_info().owner == &anchor_lang::solana_program::system_program::ID {
+                        // Define payer variable.
+                        #payer
+
+                        // Create the account with the system program.
+                        #create_account
+
+                        // Initialize the mint account.
+                        let cpi_program = token_program.to_account_info();
+                        let accounts = anchor_spl::token::InitializeMint {
+                            mint: #field.to_account_info(),
+                            rent: rent.to_account_info(),
+                        };
+                        let cpi_ctx = CpiContext::new(cpi_program, accounts);
+                        anchor_spl::token::initialize_mint(cpi_ctx, #decimals, &#owner.to_account_info().key, #freeze_authority)?;
+                    }
                     let pa: #ty_decl = #from_account_info;
                     pa
                 };
@@ -535,9 +549,11 @@ pub fn generate_init(
                 generate_create_account(field, quote! {space}, owner, seeds_with_nonce);
             quote! {
                 let #field = {
-                    #space
-                    #payer
-                    #create_account
+                    if !#if_needed || #field.to_account_info().owner == &anchor_lang::solana_program::system_program::ID {
+                        #space
+                        #payer
+                        #create_account
+                    }
                     let pa: #ty_decl = #from_account_info;
                     pa
                 };

+ 7 - 1
lang/syn/src/lib.rs

@@ -594,7 +594,12 @@ impl Parse for ConstraintToken {
 }
 
 #[derive(Debug, Clone)]
-pub struct ConstraintInit {}
+pub struct ConstraintInit {
+    pub if_needed: bool,
+}
+
+#[derive(Debug, Clone)]
+pub struct ConstraintInitIfNeeded {}
 
 #[derive(Debug, Clone)]
 pub struct ConstraintZeroed {}
@@ -639,6 +644,7 @@ pub enum ConstraintRentExempt {
 
 #[derive(Debug, Clone)]
 pub struct ConstraintInitGroup {
+    pub if_needed: bool,
     pub seeds: Option<ConstraintSeedsGroup>,
     pub payer: Option<Expr>,
     pub space: Option<Expr>,

+ 10 - 2
lang/syn/src/parser/accounts/constraints.rs

@@ -60,7 +60,14 @@ pub fn parse_token(stream: ParseStream) -> ParseResult<ConstraintToken> {
     let kw = ident.to_string();
 
     let c = match kw.as_str() {
-        "init" => ConstraintToken::Init(Context::new(ident.span(), ConstraintInit {})),
+        "init" => ConstraintToken::Init(Context::new(
+            ident.span(),
+            ConstraintInit { if_needed: false },
+        )),
+        "init_if_needed" => ConstraintToken::Init(Context::new(
+            ident.span(),
+            ConstraintInit { if_needed: true },
+        )),
         "zero" => ConstraintToken::Zeroed(Context::new(ident.span(), ConstraintZeroed {})),
         "mut" => ConstraintToken::Mut(Context::new(ident.span(), ConstraintMut {})),
         "signer" => ConstraintToken::Signer(Context::new(ident.span(), ConstraintSigner {})),
@@ -518,7 +525,8 @@ impl<'ty> ConstraintGroupBuilder<'ty> {
             _ => None,
         };
         Ok(ConstraintGroup {
-            init: init.as_ref().map(|_| Ok(ConstraintInitGroup {
+            init: init.as_ref().map(|i| Ok(ConstraintInitGroup {
+            if_needed: i.if_needed,
                 seeds: seeds.clone(),
                 payer: into_inner!(payer.clone()).map(|a| a.target),
                 space: space.clone().map(|s| s.space.clone()),

+ 8 - 0
tests/misc/programs/misc/src/context.rs

@@ -244,3 +244,11 @@ pub struct TestEmptySeedsConstraint<'info> {
     #[account(seeds = [], bump)]
     pub pda: AccountInfo<'info>,
 }
+
+#[derive(Accounts)]
+pub struct TestInitIfNeeded<'info> {
+    #[account(init_if_needed, payer = payer)]
+    pub data: Account<'info, DataU16>,
+    pub payer: Signer<'info>,
+    pub system_program: Program<'info, System>,
+}

+ 5 - 0
tests/misc/programs/misc/src/lib.rs

@@ -184,4 +184,9 @@ pub mod misc {
     pub fn test_empty_seeds_constraint(ctx: Context<TestEmptySeedsConstraint>) -> ProgramResult {
         Ok(())
     }
+
+    pub fn test_init_if_needed(ctx: Context<TestInitIfNeeded>, data: u16) -> ProgramResult {
+        ctx.accounts.data.data = data;
+        Ok(())
+    }
 }

+ 57 - 22
tests/misc/tests/misc.js

@@ -652,8 +652,8 @@ describe("misc", () => {
       accounts: {
         token: associatedToken,
         mint: mint.publicKey,
-        wallet: program.provider.wallet.publicKey
-      }
+        wallet: program.provider.wallet.publicKey,
+      },
     });
 
     await assert.rejects(
@@ -662,8 +662,8 @@ describe("misc", () => {
           accounts: {
             token: associatedToken,
             mint: mint.publicKey,
-            wallet: anchor.web3.Keypair.generate().publicKey
-          }
+            wallet: anchor.web3.Keypair.generate().publicKey,
+          },
         });
       },
       (err) => {
@@ -735,12 +735,11 @@ describe("misc", () => {
     ]);
     // Call for multiple kinds of .all.
     const allAccounts = await program.account.dataWithFilter.all();
-    const allAccountsFilteredByBuffer =
-      await program.account.dataWithFilter.all(
-        program.provider.wallet.publicKey.toBuffer()
-      );
-    const allAccountsFilteredByProgramFilters1 =
-      await program.account.dataWithFilter.all([
+    const allAccountsFilteredByBuffer = await program.account.dataWithFilter.all(
+      program.provider.wallet.publicKey.toBuffer()
+    );
+    const allAccountsFilteredByProgramFilters1 = await program.account.dataWithFilter.all(
+      [
         {
           memcmp: {
             offset: 8,
@@ -748,9 +747,10 @@ describe("misc", () => {
           },
         },
         { memcmp: { offset: 40, bytes: filterable1.toBase58() } },
-      ]);
-    const allAccountsFilteredByProgramFilters2 =
-      await program.account.dataWithFilter.all([
+      ]
+    );
+    const allAccountsFilteredByProgramFilters2 = await program.account.dataWithFilter.all(
+      [
         {
           memcmp: {
             offset: 8,
@@ -758,7 +758,8 @@ describe("misc", () => {
           },
         },
         { memcmp: { offset: 40, bytes: filterable2.toBase58() } },
-      ]);
+      ]
+    );
     // Without filters there should be 4 accounts.
     assert.equal(allAccounts.length, 4);
     // Filtering by main wallet there should be 3 accounts.
@@ -772,27 +773,33 @@ describe("misc", () => {
   });
 
   it("Can use pdas with empty seeds", async () => {
-    const [pda, bump] = await PublicKey.findProgramAddress([], program.programId);
+    const [pda, bump] = await PublicKey.findProgramAddress(
+      [],
+      program.programId
+    );
 
     await program.rpc.testInitWithEmptySeeds({
       accounts: {
         pda: pda,
         authority: program.provider.wallet.publicKey,
-        systemProgram: anchor.web3.SystemProgram.programId
-      }
+        systemProgram: anchor.web3.SystemProgram.programId,
+      },
     });
     await program.rpc.testEmptySeedsConstraint({
       accounts: {
-        pda: pda
-      }
+        pda: pda,
+      },
     });
 
-    const [pda2, bump2] = await PublicKey.findProgramAddress(["non-empty"], program.programId);
+    const [pda2, bump2] = await PublicKey.findProgramAddress(
+      ["non-empty"],
+      program.programId
+    );
     await assert.rejects(
       program.rpc.testEmptySeedsConstraint({
         accounts: {
-          pda: pda2
-        }
+          pda: pda2,
+        },
       }),
       (err) => {
         assert.equal(err.code, 146);
@@ -800,4 +807,32 @@ describe("misc", () => {
       }
     );
   });
+
+  const ifNeededAcc = anchor.web3.Keypair.generate();
+
+  it("Can init if needed a new account", async () => {
+    await program.rpc.testInitIfNeeded(1, {
+      accounts: {
+        data: ifNeededAcc.publicKey,
+        systemProgram: anchor.web3.SystemProgram.programId,
+        payer: program.provider.wallet.publicKey,
+      },
+      signers: [ifNeededAcc],
+    });
+    const account = await program.account.dataU16.fetch(ifNeededAcc.publicKey);
+    assert.ok(account.data, 1);
+  });
+
+  it("Can init if needed a previously created account", async () => {
+    await program.rpc.testInitIfNeeded(3, {
+      accounts: {
+        data: ifNeededAcc.publicKey,
+        systemProgram: anchor.web3.SystemProgram.programId,
+        payer: program.provider.wallet.publicKey,
+      },
+      signers: [ifNeededAcc],
+    });
+    const account = await program.account.dataU16.fetch(ifNeededAcc.publicKey);
+    assert.ok(account.data, 3);
+  });
 });