فهرست منبع

lang: Add associated init constraint (#318)

Armani Ferrante 4 سال پیش
والد
کامیت
68c601ab22

+ 6 - 9
.travis.yml

@@ -41,12 +41,6 @@ jobs:
         - cargo build
         - cargo fmt -- --check
         - cargo test
-    - <<: *defaults
-      name: Build and test TypeScript
-      script:
-        - cd ts
-        - yarn
-        - yarn build
     - <<: *examples
       name: Runs the examples 1
       script:
@@ -58,17 +52,20 @@ jobs:
         - pushd examples/multisig && anchor test && popd
         - pushd examples/interface && anchor test && popd
         - pushd examples/lockup && anchor test && popd
+    - <<: *examples
+      name: Runs the examples 2
+      script:
         - pushd examples/misc && anchor test && popd
         - pushd examples/events && anchor test && popd
         - pushd examples/cashiers-check && anchor test && popd
         - pushd examples/typescript && yarn && anchor test && popd
         - pushd examples/zero-copy && yarn && anchor test && popd
-    - <<: *examples
-      name: Runs the examples 2
-      script:
         - pushd examples/chat && yarn && anchor test && popd
         - pushd examples/ido-pool && yarn && anchor test && popd
         - pushd examples/swap/deps/serum-dex/dex && cargo build-bpf && cd ../../../ && anchor test && popd
+    - <<: *examples
+      name: Runs the examples 3
+      script:
         - pushd examples/pyth && yarn && anchor test && popd
         - pushd examples/tutorial/basic-0 && anchor test && popd
         - pushd examples/tutorial/basic-1 && anchor test && popd

+ 4 - 0
CHANGELOG.md

@@ -17,6 +17,10 @@ incremented for features.
 * cli: Add global options for override Anchor.toml values ([#313](https://github.com/project-serum/anchor/pull/313)).
 * spl: Add `SetAuthority` instruction ([#307](https://github.com/project-serum/anchor/pull/307/files)).
 
+## Breaking Changes
+
+* lang: `#[account(associated)]` now requires `init` to be provided to create an associated account. If not provided, then the address will be assumed to exist, and a constraint will be added to ensure its correctness ([#318](https://github.com/project-serum/anchor/pull/318)).
+
 ## [0.6.0] - 2021-05-23
 
 ## Features

+ 3 - 2
docs/src/tutorials/tutorial-6.md

@@ -77,8 +77,9 @@ Lastly, notice the two accounts at the bottom of account context.
     system_program: AccountInfo<'info>,
 ```
 
-Although a bit of an implementaion detail, these accounts are required so that Anchor
-can create your associated account. By convention, the names must be as given here.
+In the same way that `rent` is required when using `init` in the previous tutorials,
+`rent` and additionally the `system-program` must be provided when creating an associated
+account. By convention, the names must be as given here.
 
 For more details on how to use `#[account(associated)]`, see [docs.rs](https://docs.rs/anchor-lang/latest/anchor_lang/derive.Accounts.html).
 

+ 2 - 2
examples/chat/programs/chat/src/lib.rs

@@ -36,7 +36,7 @@ pub mod chat {
 
 #[derive(Accounts)]
 pub struct CreateUser<'info> {
-    #[account(associated = authority, space = "312")]
+    #[account(init, associated = authority, space = "312")]
     user: ProgramAccount<'info, User>,
     #[account(signer)]
     authority: AccountInfo<'info>,
@@ -53,7 +53,7 @@ pub struct CreateChatRoom<'info> {
 
 #[derive(Accounts)]
 pub struct SendMessage<'info> {
-    #[account(has_one = authority)]
+    #[account(associated = authority, has_one = authority)]
     user: ProgramAccount<'info, User>,
     #[account(signer)]
     authority: AccountInfo<'info>,

+ 21 - 3
examples/misc/programs/misc/src/lib.rs

@@ -45,7 +45,15 @@ pub mod misc {
         misc2::cpi::state::set_data(ctx, data)
     }
 
-    pub fn test_associated_account_creation(
+    pub fn test_init_associated_account(
+        ctx: Context<TestInitAssociatedAccount>,
+        data: u64,
+    ) -> ProgramResult {
+        ctx.accounts.my_account.data = data;
+        Ok(())
+    }
+
+    pub fn test_associated_account(
         ctx: Context<TestAssociatedAccount>,
         data: u64,
     ) -> ProgramResult {
@@ -119,8 +127,8 @@ pub struct TestStateCpi<'info> {
 // accounts are needed when creating the associated program address within
 // the program.
 #[derive(Accounts)]
-pub struct TestAssociatedAccount<'info> {
-    #[account(associated = authority, with = state, with = data)]
+pub struct TestInitAssociatedAccount<'info> {
+    #[account(init, associated = authority, with = state, with = data)]
     my_account: ProgramAccount<'info, TestData>,
     #[account(mut, signer)]
     authority: AccountInfo<'info>,
@@ -130,6 +138,16 @@ pub struct TestAssociatedAccount<'info> {
     system_program: AccountInfo<'info>,
 }
 
+#[derive(Accounts)]
+pub struct TestAssociatedAccount<'info> {
+    #[account(associated = authority, with = state, with = data)]
+    my_account: ProgramAccount<'info, TestData>,
+    #[account(mut, signer)]
+    authority: AccountInfo<'info>,
+    state: ProgramState<'info, MyState>,
+    data: ProgramAccount<'info, Data>,
+}
+
 #[derive(Accounts)]
 pub struct TestU16<'info> {
     #[account(init)]

+ 33 - 2
examples/misc/tests/misc.js

@@ -126,7 +126,7 @@ describe("misc", () => {
     assert.ok(stateAccount.auth.equals(program.provider.wallet.publicKey));
   });
 
-  it("Can create an associated program account", async () => {
+  it("Can init an associated program account", async () => {
     const state = await program.state.address();
 
     // Manual associated address calculation for test only. Clients should use
@@ -155,7 +155,7 @@ describe("misc", () => {
         return true;
       }
     );
-    await program.rpc.testAssociatedAccountCreation(new anchor.BN(1234), {
+    await program.rpc.testInitAssociatedAccount(new anchor.BN(1234), {
       accounts: {
         myAccount: associatedAccount,
         authority: program.provider.wallet.publicKey,
@@ -174,6 +174,37 @@ describe("misc", () => {
     assert.ok(account.data.toNumber() === 1234);
   });
 
+  it("Can use an associated program account", async () => {
+    const state = await program.state.address();
+    const [
+      associatedAccount,
+      nonce,
+    ] = await anchor.web3.PublicKey.findProgramAddress(
+      [
+        Buffer.from([97, 110, 99, 104, 111, 114]), // b"anchor".
+        program.provider.wallet.publicKey.toBuffer(),
+        state.toBuffer(),
+        data.publicKey.toBuffer(),
+      ],
+      program.programId
+    );
+    await program.rpc.testAssociatedAccount(new anchor.BN(5), {
+      accounts: {
+        myAccount: associatedAccount,
+        authority: program.provider.wallet.publicKey,
+        state,
+        data: data.publicKey,
+      },
+    });
+    // Try out the generated associated method.
+    const account = await program.account.testData.associated(
+      program.provider.wallet.publicKey,
+      state,
+      data.publicKey
+    );
+    assert.ok(account.data.toNumber() === 5);
+  });
+
   it("Can retrieve events when simulating a transaction", async () => {
     const resp = await program.simulate.testSimulate(44);
     const expectedRaw = [

+ 1 - 1
examples/tutorial/basic-5/programs/basic-5/src/lib.rs

@@ -34,7 +34,7 @@ pub struct CreateMint<'info> {
 
 #[derive(Accounts)]
 pub struct CreateToken<'info> {
-    #[account(associated = authority, with = mint)]
+    #[account(init, associated = authority, with = mint)]
     token: ProgramAccount<'info, Token>,
     #[account(mut, signer)]
     authority: AccountInfo<'info>,

+ 3 - 2
examples/zero-copy/programs/zero-copy/src/lib.rs

@@ -132,7 +132,7 @@ pub struct UpdateFooSecond<'info> {
 
 #[derive(Accounts)]
 pub struct CreateBar<'info> {
-    #[account(associated = authority, with = foo)]
+    #[account(init, associated = authority, with = foo)]
     bar: Loader<'info, Bar>,
     #[account(signer)]
     authority: AccountInfo<'info>,
@@ -143,10 +143,11 @@ pub struct CreateBar<'info> {
 
 #[derive(Accounts)]
 pub struct UpdateBar<'info> {
-    #[account(mut, has_one = authority)]
+    #[account(mut, associated = authority, with = foo, has_one = authority)]
     bar: Loader<'info, Bar>,
     #[account(signer)]
     authority: AccountInfo<'info>,
+    foo: Loader<'info, Foo>,
 }
 
 #[derive(Accounts)]

+ 18 - 15
examples/zero-copy/tests/zero-copy.js

@@ -1,4 +1,6 @@
 const anchor = require("@project-serum/anchor");
+const PublicKey = anchor.web3.PublicKey;
+const BN = anchor.BN;
 const assert = require("assert");
 
 describe("zero-copy", () => {
@@ -19,15 +21,15 @@ describe("zero-copy", () => {
     assert.ok(state.authority.equals(program.provider.wallet.publicKey));
     assert.ok(state.events.length === 250);
     state.events.forEach((event, idx) => {
-      assert.ok(event.from.equals(new anchor.web3.PublicKey()));
+      assert.ok(event.from.equals(new PublicKey()));
       assert.ok(event.data.toNumber() === 0);
     });
   });
 
   it("Updates zero copy state", async () => {
     let event = {
-      from: new anchor.web3.PublicKey(),
-      data: new anchor.BN(1234),
+      from: new PublicKey(),
+      data: new BN(1234),
     };
     await program.state.rpc.setEvent(5, event, {
       accounts: {
@@ -42,7 +44,7 @@ describe("zero-copy", () => {
         assert.ok(event.from.equals(event.from));
         assert.ok(event.data.eq(event.data));
       } else {
-        assert.ok(event.from.equals(new anchor.web3.PublicKey()));
+        assert.ok(event.from.equals(new PublicKey()));
         assert.ok(event.data.toNumber() === 0);
       }
     });
@@ -72,7 +74,7 @@ describe("zero-copy", () => {
   });
 
   it("Updates a zero copy account field", async () => {
-    await program.rpc.updateFoo(new anchor.BN(1234), {
+    await program.rpc.updateFoo(new BN(1234), {
       accounts: {
         foo: foo.publicKey,
         authority: program.provider.wallet.publicKey,
@@ -94,7 +96,7 @@ describe("zero-copy", () => {
   });
 
   it("Updates a a second zero copy account field", async () => {
-    await program.rpc.updateFooSecond(new anchor.BN(55), {
+    await program.rpc.updateFooSecond(new BN(55), {
       accounts: {
         foo: foo.publicKey,
         secondAuthority: program.provider.wallet.publicKey,
@@ -138,13 +140,14 @@ describe("zero-copy", () => {
   });
 
   it("Updates an associated zero copy account", async () => {
-    await program.rpc.updateBar(new anchor.BN(99), {
+    await program.rpc.updateBar(new BN(99), {
       accounts: {
         bar: await program.account.bar.associatedAddress(
           program.provider.wallet.publicKey,
           foo.publicKey
         ),
         authority: program.provider.wallet.publicKey,
+        foo: foo.publicKey,
       },
     });
     const bar = await program.account.bar.associated(
@@ -172,14 +175,14 @@ describe("zero-copy", () => {
     const account = await program.account.eventQ(eventQ.publicKey);
     assert.ok(account.events.length === 25000);
     account.events.forEach((event) => {
-      assert.ok(event.from.equals(new anchor.web3.PublicKey()));
+      assert.ok(event.from.equals(new PublicKey()));
       assert.ok(event.data.toNumber() === 0);
     });
   });
 
   it("Updates a large event queue", async () => {
     // Set index 0.
-    await program.rpc.updateLargeAccount(0, new anchor.BN(48), {
+    await program.rpc.updateLargeAccount(0, new BN(48), {
       accounts: {
         eventQ: eventQ.publicKey,
         from: program.provider.wallet.publicKey,
@@ -193,13 +196,13 @@ describe("zero-copy", () => {
         assert.ok(event.from.equals(program.provider.wallet.publicKey));
         assert.ok(event.data.toNumber() === 48);
       } else {
-        assert.ok(event.from.equals(new anchor.web3.PublicKey()));
+        assert.ok(event.from.equals(new PublicKey()));
         assert.ok(event.data.toNumber() === 0);
       }
     });
 
     // Set index 11111.
-    await program.rpc.updateLargeAccount(11111, new anchor.BN(1234), {
+    await program.rpc.updateLargeAccount(11111, new BN(1234), {
       accounts: {
         eventQ: eventQ.publicKey,
         from: program.provider.wallet.publicKey,
@@ -216,13 +219,13 @@ describe("zero-copy", () => {
         assert.ok(event.from.equals(program.provider.wallet.publicKey));
         assert.ok(event.data.toNumber() === 1234);
       } else {
-        assert.ok(event.from.equals(new anchor.web3.PublicKey()));
+        assert.ok(event.from.equals(new PublicKey()));
         assert.ok(event.data.toNumber() === 0);
       }
     });
 
     // Set last index.
-    await program.rpc.updateLargeAccount(24999, new anchor.BN(99), {
+    await program.rpc.updateLargeAccount(24999, new BN(99), {
       accounts: {
         eventQ: eventQ.publicKey,
         from: program.provider.wallet.publicKey,
@@ -242,7 +245,7 @@ describe("zero-copy", () => {
         assert.ok(event.from.equals(program.provider.wallet.publicKey));
         assert.ok(event.data.toNumber() === 99);
       } else {
-        assert.ok(event.from.equals(new anchor.web3.PublicKey()));
+        assert.ok(event.from.equals(new PublicKey()));
         assert.ok(event.data.toNumber() === 0);
       }
     });
@@ -252,7 +255,7 @@ describe("zero-copy", () => {
     // Fail to set non existing index.
     await assert.rejects(
       async () => {
-        await program.rpc.updateLargeAccount(25000, new anchor.BN(1), {
+        await program.rpc.updateLargeAccount(25000, new BN(1), {
           accounts: {
             eventQ: eventQ.publicKey,
             from: program.provider.wallet.publicKey,

+ 1 - 1
lang/derive/accounts/src/lib.rs

@@ -49,7 +49,7 @@ use syn::parse_macro_input;
 /// | `#[account(executable)]` | On `AccountInfo` structs | Checks the given account is an executable program. |
 /// | `#[account(state = <target>)]` | On `CpiState` structs | Checks the given state is the canonical state account for the target program. |
 /// | `#[account(owner = <target>)]` | On `CpiState`, `CpiAccount`, and `AccountInfo` | Checks the account owner matches the target. |
-/// | `#[account(associated = <target>, with? = <target>, payer? = <target>, space? = "<literal>")]` | On `ProgramAccount` | Creates an associated program account at a program derived address. `associated` is the SOL address to create the account for. `with` is an optional association, for example, a `Mint` account in the SPL token program. `payer` is an optional account to pay for the account creation, defaulting to the `associated` target if none is given. `space` is an optional literal specifying how large the account is, defaulting to the account's serialized `Default::default` size (+ 8 for the account discriminator) if none is given. When creating an associated account, a `rent` `Sysvar` and `system_program` `AccountInfo` must be present in the `Accounts` struct. |
+/// | `#[account(associated = <target>, with? = <target>, payer? = <target>, space? = "<literal>")]` | On `ProgramAccount` | Whe `init` is provided, creates an associated program account at a program derived address. `associated` is the SOL address to create the account for. `with` is an optional association, for example, a `Mint` account in the SPL token program. `payer` is an optional account to pay for the account creation, defaulting to the `associated` target if none is given. `space` is an optional literal specifying how large the account is, defaulting to the account's serialized `Default::default` size (+ 8 for the account discriminator) if none is given. When creating an associated account, a `rent` `Sysvar` and `system_program` `AccountInfo` must be present in the `Accounts` struct. When `init` is not provided, then ensures the given associated account has the expected address, defined by the program and the given seeds. |
 // TODO: How do we make the markdown render correctly without putting everything
 //       on absurdly long lines?
 #[proc_macro_derive(Accounts, attributes(account))]

+ 67 - 30
lang/syn/src/codegen/accounts.rs

@@ -8,8 +8,11 @@ use quote::quote;
 
 pub fn generate(accs: AccountsStruct) -> proc_macro2::TokenStream {
     // All fields without an `#[account(associated)]` attribute.
-    let non_associated_fields: Vec<&AccountField> =
-        accs.fields.iter().filter(|af| !is_associated(af)).collect();
+    let non_associated_fields: Vec<&AccountField> = accs
+        .fields
+        .iter()
+        .filter(|af| !is_associated_init(af))
+        .collect();
 
     // Deserialization for each field
     let deser_fields: Vec<proc_macro2::TokenStream> = accs
@@ -30,7 +33,7 @@ pub fn generate(accs: AccountsStruct) -> proc_macro2::TokenStream {
                     // Associated fields are *first* deserialized into
                     // AccountInfos, and then later deserialized into
                     // ProgramAccounts in the "constraint check" phase.
-                    if is_associated(af) {
+                    if is_associated_init(af) {
                         let name = &f.ident;
                         quote!{
                             let #name = &accounts[0];
@@ -63,7 +66,7 @@ pub fn generate(accs: AccountsStruct) -> proc_macro2::TokenStream {
         .iter()
         .filter_map(|af| match af {
             AccountField::AccountsStruct(_s) => None,
-            AccountField::Field(f) => match is_associated(af) {
+            AccountField::Field(f) => match is_associated_init(af) {
                 false => None,
                 true => Some(f),
             },
@@ -359,15 +362,15 @@ pub fn generate(accs: AccountsStruct) -> proc_macro2::TokenStream {
     }
 }
 
-// Returns true if the given AccountField has an associated constraint.
-fn is_associated(af: &AccountField) -> bool {
+// Returns true if the given AccountField has an associated init constraint.
+fn is_associated_init(af: &AccountField) -> bool {
     match af {
         AccountField::AccountsStruct(_s) => false,
         AccountField::Field(f) => f
             .constraints
             .iter()
             .filter(|c| match c {
-                Constraint::Associated(_c) => true,
+                Constraint::Associated(c) => c.is_init,
                 _ => false,
             })
             .next()
@@ -525,6 +528,16 @@ pub fn generate_constraint_state(f: &Field, c: &ConstraintState) -> proc_macro2:
 pub fn generate_constraint_associated(
     f: &Field,
     c: &ConstraintAssociated,
+) -> proc_macro2::TokenStream {
+    if c.is_init {
+        generate_constraint_associated_init(f, c)
+    } else {
+        generate_constraint_associated_seeds(f, c)
+    }
+}
+pub fn generate_constraint_associated_init(
+    f: &Field,
+    c: &ConstraintAssociated,
 ) -> proc_macro2::TokenStream {
     let associated_target = c.associated_target.clone();
     let field = &f.ident;
@@ -564,24 +577,8 @@ pub fn generate_constraint_associated(
         },
     };
 
-    let seeds_no_nonce = match f.associated_seeds.len() {
-        0 => quote! {
-            [
-                &b"anchor"[..],
-                #associated_target.to_account_info().key.as_ref(),
-            ]
-        },
-        _ => {
-            let seeds = to_seeds_tts(&f.associated_seeds);
-            quote! {
-                [
-                    &b"anchor"[..],
-                    #associated_target.to_account_info().key.as_ref(),
-                    #seeds
-                ]
-            }
-        }
-    };
+    let associated_pubkey_and_nonce = generate_associated_pubkey(f, c);
+
     let seeds_with_nonce = match f.associated_seeds.len() {
         0 => quote! {
             [
@@ -624,11 +621,9 @@ pub fn generate_constraint_associated(
             #space
             #payer
 
-            let (associated_field, nonce) = Pubkey::find_program_address(
-                &#seeds_no_nonce,
-                program_id,
-            );
-            if &associated_field != #field.key {
+            #associated_pubkey_and_nonce
+
+            if &__associated_field != #field.key {
                 return Err(ProgramError::Custom(45)); // todo: proper error.
             }
             let lamports = rent.minimum_balance(space);
@@ -666,6 +661,48 @@ pub fn generate_constraint_associated(
     }
 }
 
+pub fn generate_constraint_associated_seeds(
+    f: &Field,
+    c: &ConstraintAssociated,
+) -> proc_macro2::TokenStream {
+    let generated_associated_pubkey_and_nonce = generate_associated_pubkey(f, c);
+    let name = &f.ident;
+    quote! {
+        #generated_associated_pubkey_and_nonce
+        if #name.to_account_info().key != &__associated_field {
+            // TODO: proper error.
+            return Err(anchor_lang::solana_program::program_error::ProgramError::Custom(45));
+        }
+    }
+}
+pub fn generate_associated_pubkey(f: &Field, c: &ConstraintAssociated) -> proc_macro2::TokenStream {
+    let associated_target = c.associated_target.clone();
+    let seeds_no_nonce = match f.associated_seeds.len() {
+        0 => quote! {
+            [
+                &b"anchor"[..],
+                #associated_target.to_account_info().key.as_ref(),
+            ]
+        },
+        _ => {
+            let seeds = to_seeds_tts(&f.associated_seeds);
+            quote! {
+                [
+                    &b"anchor"[..],
+                    #associated_target.to_account_info().key.as_ref(),
+                    #seeds
+                ]
+            }
+        }
+    };
+    quote! {
+        let (__associated_field, nonce) = Pubkey::find_program_address(
+            &#seeds_no_nonce,
+            program_id,
+        );
+    }
+}
+
 // Returns the inner part of the seeds slice as a token stream.
 fn to_seeds_tts(seeds: &[syn::Ident]) -> proc_macro2::TokenStream {
     assert!(seeds.len() > 0);

+ 1 - 0
lang/syn/src/lib.rs

@@ -350,6 +350,7 @@ pub struct ConstraintState {
 #[derive(Debug)]
 pub struct ConstraintAssociated {
     pub associated_target: proc_macro2::Ident,
+    pub is_init: bool,
 }
 
 #[derive(Debug)]

+ 10 - 0
lang/syn/src/parser/accounts.rs

@@ -329,6 +329,7 @@ fn parse_constraints(
                     };
                     constraints.push(Constraint::Associated(ConstraintAssociated {
                         associated_target,
+                        is_init,
                     }));
                 }
                 "with" => {
@@ -395,6 +396,15 @@ fn parse_constraints(
         }
     }
 
+    // If init, then tag the associated constraint as being part of init.
+    if is_init {
+        for c in &mut constraints {
+            if let Constraint::Associated(ConstraintAssociated { is_init, .. }) = c {
+                *is_init = true;
+            }
+        }
+    }
+
     // If `associated` is given, remove `init` since it's redundant.
     if is_associated {
         is_init = false;