Browse Source

lang: add seeds::program constraint for PDAs (#1197)

Alan O'Donnell 3 years ago
parent
commit
e04f144e12

+ 4 - 0
CHANGELOG.md

@@ -11,6 +11,10 @@ incremented for features.
 
 ## [Unreleased]
 
+### Features
+
+* lang: Add `seeds::program` constraint for specifying which program_id to use when deriving PDAs.([#1197](https://github.com/project-serum/anchor/pull/1197))
+
 ### Breaking
 
 * lang: rename `loader_account` module to `account_loader` module ([#1279](https://github.com/project-serum/anchor/pull/1279))

+ 4 - 4
examples/tutorial/yarn.lock

@@ -30,10 +30,10 @@
     "@ethersproject/logger" "^5.5.0"
     hash.js "1.1.7"
 
-"@project-serum/anchor@^0.19.0":
-  version "0.19.0"
-  resolved "https://registry.yarnpkg.com/@project-serum/anchor/-/anchor-0.19.0.tgz#79f1fbe7c3134860ccbfe458a0e09daf79644885"
-  integrity sha512-cs0LBmJOrL9eJ8MRNqitnzbpCT5QEzVdJmiIjfNV5YaGn1K9vISR7DtISj3Bdl3KBdLqii4CTw1mpHdi8iXUCg==
+"@project-serum/anchor@^0.20.0":
+  version "0.20.0"
+  resolved "https://registry.yarnpkg.com/@project-serum/anchor/-/anchor-0.20.0.tgz#547f5c0ff7e66809fa7118b2e3abd8087b5ec519"
+  integrity sha512-p1KOiqGBIbNsopMrSVoPwgxR1iPffsdjMNCOysahTPL9whX2CLX9HQCdopHjYaGl7+SdHRuXml6Wahk/wUmC8g==
   dependencies:
     "@project-serum/borsh" "^0.2.2"
     "@solana/web3.js" "^1.17.0"

+ 30 - 7
lang/derive/accounts/src/lib.rs

@@ -159,7 +159,9 @@ use syn::parse_macro_input;
 ///                         you can pass it in as instruction data and set the bump value like shown in the example,
 ///                         using the <code>instruction_data</code> attribute.
 ///                         Anchor will then check that the bump returned by <code>find_program_address</code> equals
-///                         the bump in the instruction data.
+///                         the bump in the instruction data.<br>
+///                         <code>seeds::program</code> cannot be used together with init because the creation of an
+///                         account requires its signature which for PDAs only the currently executing program can provide.
 ///                     </li>
 ///                 </ul>
 ///                 Example:
@@ -228,21 +230,42 @@ use syn::parse_macro_input;
 ///         <tr>
 ///             <td>
 ///                 <code>#[account(seeds = &lt;seeds&gt;, bump)]</code><br><br>
-///                 <code>#[account(seeds = &lt;seeds&gt;, bump = &lt;expr&gt;)]</code>
+///                 <code>#[account(seeds = &lt;seeds&gt;, bump, seeds::program = &lt;expr&gt;)]<br><br>
+///                 <code>#[account(seeds = &lt;seeds&gt;, bump = &lt;expr&gt;)]</code><br><br>
+///                 <code>#[account(seeds = &lt;seeds&gt;, bump = &lt;expr&gt;, seeds::program = &lt;expr&gt;)]</code><br><br>
 ///             </td>
 ///             <td>
 ///                 Checks that given account is a PDA derived from the currently executing program,
 ///                 the seeds, and if provided, the bump. If not provided, anchor uses the canonical
-///                 bump. Will be adjusted in the future to allow PDA to be derived from other programs.<br>
+///                 bump. <br>
+///                 Add <code>seeds::program = &lt;expr&gt;</code> to derive the PDA from a different
+///                 program than the currently executing one.<br>
 ///                 This constraint behaves slightly differently when used with <code>init</code>.
 ///                 See its description.
 ///                 <br><br>
 ///                 Example:
 ///                 <pre><code>
-/// #[account(seeds = [b"example_seed], bump)]
-/// pub canonical_pda: AccountInfo<'info>,
-/// #[account(seeds = [b"other_seed], bump = 142)]
-/// pub arbitrary_pda: AccountInfo<'info>
+/// #[derive(Accounts)]
+/// #[instruction(first_bump: u8, second_bump: u8)]
+/// pub struct Example {
+///     #[account(seeds = [b"example_seed], bump)]
+///     pub canonical_pda: AccountInfo<'info>,
+///     #[account(
+///         seeds = [b"example_seed],
+///         bump,
+///         seeds::program = other_program.key()
+///     )]
+///     pub canonical_pda_two: AccountInfo<'info>,
+///     #[account(seeds = [b"other_seed], bump = first_bump)]
+///     pub arbitrary_pda: AccountInfo<'info>
+///     #[account(
+///         seeds = [b"other_seed],
+///         bump = second_bump,
+///         seeds::program = other_program.key()
+///     )]
+///     pub arbitrary_pda_two: AccountInfo<'info>,
+///     pub other_program: Program<'info, OtherProgram>
+/// }
 ///                 </code></pre>
 ///             </td>
 ///         </tr>

+ 12 - 3
lang/syn/src/codegen/accounts/constraints.rs

@@ -327,6 +327,15 @@ fn generate_constraint_init_group(f: &Field, c: &ConstraintInitGroup) -> proc_ma
 fn generate_constraint_seeds(f: &Field, c: &ConstraintSeedsGroup) -> proc_macro2::TokenStream {
     let name = &f.ident;
     let s = &mut c.seeds.clone();
+
+    let deriving_program_id = c
+        .program_seed
+        .clone()
+        // If they specified a seeds::program to use when deriving the PDA, use it.
+        .map(|program_id| quote! { #program_id })
+        // Otherwise fall back to the current program's program_id.
+        .unwrap_or(quote! { program_id });
+
     // If the seeds came with a trailing comma, we need to chop it off
     // before we interpolate them below.
     if let Some(pair) = s.pop() {
@@ -340,7 +349,7 @@ fn generate_constraint_seeds(f: &Field, c: &ConstraintSeedsGroup) -> proc_macro2
         quote! {
             let (__program_signer, __bump) = anchor_lang::solana_program::pubkey::Pubkey::find_program_address(
                 &[#s],
-                program_id,
+                &#deriving_program_id,
             );
             if #name.key() != __program_signer {
                 return Err(anchor_lang::__private::ErrorCode::ConstraintSeeds.into());
@@ -362,7 +371,7 @@ fn generate_constraint_seeds(f: &Field, c: &ConstraintSeedsGroup) -> proc_macro2
                         &[
                             Pubkey::find_program_address(
                                 &[#s],
-                                program_id,
+                                &#deriving_program_id,
                             ).1
                         ][..]
                     ]
@@ -378,7 +387,7 @@ fn generate_constraint_seeds(f: &Field, c: &ConstraintSeedsGroup) -> proc_macro2
         quote! {
             let __program_signer = Pubkey::create_program_address(
                 &#seeds[..],
-                program_id,
+                &#deriving_program_id,
             ).map_err(|_| anchor_lang::__private::ErrorCode::ConstraintSeeds)?;
             if #name.key() != __program_signer {
                 return Err(anchor_lang::__private::ErrorCode::ConstraintSeeds.into());

+ 8 - 1
lang/syn/src/lib.rs

@@ -610,6 +610,7 @@ pub enum ConstraintToken {
     MintFreezeAuthority(Context<ConstraintMintFreezeAuthority>),
     MintDecimals(Context<ConstraintMintDecimals>),
     Bump(Context<ConstraintTokenBump>),
+    ProgramSeed(Context<ConstraintProgramSeed>),
 }
 
 impl Parse for ConstraintToken {
@@ -687,7 +688,8 @@ pub struct ConstraintInitGroup {
 pub struct ConstraintSeedsGroup {
     pub is_init: bool,
     pub seeds: Punctuated<Expr, Token![,]>,
-    pub bump: Option<Expr>, // None => bump was given without a target.
+    pub bump: Option<Expr>,         // None => bump was given without a target.
+    pub program_seed: Option<Expr>, // None => use the current program's program_id
 }
 
 #[derive(Debug, Clone)]
@@ -771,6 +773,11 @@ pub struct ConstraintTokenBump {
     bump: Option<Expr>,
 }
 
+#[derive(Debug, Clone)]
+pub struct ConstraintProgramSeed {
+    program_seed: Expr,
+}
+
 #[derive(Debug, Clone)]
 pub struct ConstraintAssociatedToken {
     pub wallet: Expr,

+ 69 - 10
lang/syn/src/parser/accounts/constraints.rs

@@ -182,6 +182,43 @@ pub fn parse_token(stream: ParseStream) -> ParseResult<ConstraintToken> {
             };
             ConstraintToken::Bump(Context::new(ident.span(), ConstraintTokenBump { bump }))
         }
+        "seeds" => {
+            if stream.peek(Token![:]) {
+                stream.parse::<Token![:]>()?;
+                stream.parse::<Token![:]>()?;
+                let kw = stream.call(Ident::parse_any)?.to_string();
+                stream.parse::<Token![=]>()?;
+
+                let span = ident
+                    .span()
+                    .join(stream.span())
+                    .unwrap_or_else(|| ident.span());
+
+                match kw.as_str() {
+                    "program" => ConstraintToken::ProgramSeed(Context::new(
+                        span,
+                        ConstraintProgramSeed {
+                            program_seed: stream.parse()?,
+                        },
+                    )),
+                    _ => return Err(ParseError::new(ident.span(), "Invalid attribute")),
+                }
+            } else {
+                stream.parse::<Token![=]>()?;
+                let span = ident
+                    .span()
+                    .join(stream.span())
+                    .unwrap_or_else(|| ident.span());
+                let seeds;
+                let bracket = bracketed!(seeds in stream);
+                ConstraintToken::Seeds(Context::new(
+                    span.join(bracket.span).unwrap_or(span),
+                    ConstraintSeeds {
+                        seeds: seeds.parse_terminated(Expr::parse)?,
+                    },
+                ))
+            }
+        }
         _ => {
             stream.parse::<Token![=]>()?;
             let span = ident
@@ -234,16 +271,6 @@ pub fn parse_token(stream: ParseStream) -> ParseResult<ConstraintToken> {
                         space: stream.parse()?,
                     },
                 )),
-                "seeds" => {
-                    let seeds;
-                    let bracket = bracketed!(seeds in stream);
-                    ConstraintToken::Seeds(Context::new(
-                        span.join(bracket.span).unwrap_or(span),
-                        ConstraintSeeds {
-                            seeds: seeds.parse_terminated(Expr::parse)?,
-                        },
-                    ))
-                }
                 "constraint" => ConstraintToken::Raw(Context::new(
                     span,
                     ConstraintRaw {
@@ -308,6 +335,7 @@ pub struct ConstraintGroupBuilder<'ty> {
     pub mint_freeze_authority: Option<Context<ConstraintMintFreezeAuthority>>,
     pub mint_decimals: Option<Context<ConstraintMintDecimals>>,
     pub bump: Option<Context<ConstraintTokenBump>>,
+    pub program_seed: Option<Context<ConstraintProgramSeed>>,
 }
 
 impl<'ty> ConstraintGroupBuilder<'ty> {
@@ -338,6 +366,7 @@ impl<'ty> ConstraintGroupBuilder<'ty> {
             mint_freeze_authority: None,
             mint_decimals: None,
             bump: None,
+            program_seed: None,
         }
     }
 
@@ -494,6 +523,7 @@ impl<'ty> ConstraintGroupBuilder<'ty> {
             mint_freeze_authority,
             mint_decimals,
             bump,
+            program_seed,
         } = self;
 
         // Converts Option<Context<T>> -> Option<T>.
@@ -519,6 +549,7 @@ impl<'ty> ConstraintGroupBuilder<'ty> {
             bump: into_inner!(bump)
                 .map(|b| b.bump)
                 .expect("bump must be provided with seeds"),
+            program_seed: into_inner!(program_seed).map(|id| id.program_seed),
         });
         let associated_token = match (associated_token_mint, associated_token_authority) {
             (Some(mint), Some(auth)) => Some(ConstraintAssociatedToken {
@@ -620,6 +651,7 @@ impl<'ty> ConstraintGroupBuilder<'ty> {
             ConstraintToken::MintFreezeAuthority(c) => self.add_mint_freeze_authority(c),
             ConstraintToken::MintDecimals(c) => self.add_mint_decimals(c),
             ConstraintToken::Bump(c) => self.add_bump(c),
+            ConstraintToken::ProgramSeed(c) => self.add_program_seed(c),
         }
     }
 
@@ -725,6 +757,33 @@ impl<'ty> ConstraintGroupBuilder<'ty> {
         Ok(())
     }
 
+    fn add_program_seed(&mut self, c: Context<ConstraintProgramSeed>) -> ParseResult<()> {
+        if self.program_seed.is_some() {
+            return Err(ParseError::new(c.span(), "seeds::program already provided"));
+        }
+        if self.seeds.is_none() {
+            return Err(ParseError::new(
+                c.span(),
+                "seeds must be provided before seeds::program",
+            ));
+        }
+        if let Some(ref init) = self.init {
+            if init.if_needed {
+                return Err(ParseError::new(
+                    c.span(),
+                    "seeds::program cannot be used with init_if_needed",
+                ));
+            } else {
+                return Err(ParseError::new(
+                    c.span(),
+                    "seeds::program cannot be used with init",
+                ));
+            }
+        }
+        self.program_seed.replace(c);
+        Ok(())
+    }
+
     fn add_token_authority(&mut self, c: Context<ConstraintTokenAuthority>) -> ParseResult<()> {
         if self.token_authority.is_some() {
             return Err(ParseError::new(

+ 3 - 0
tests/misc/package.json

@@ -15,5 +15,8 @@
   },
   "scripts": {
     "test": "anchor test"
+  },
+  "dependencies": {
+    "mocha": "^9.1.3"
   }
 }

+ 22 - 0
tests/misc/programs/misc/src/context.rs

@@ -379,3 +379,25 @@ pub struct InitIfNeededChecksRentExemption<'info> {
     pub system_program: Program<'info, System>
 }
 
+#[derive(Accounts)]
+#[instruction(bump: u8, second_bump: u8)]
+pub struct TestProgramIdConstraint<'info> {
+    // not a real associated token account
+    // just deriving like this for testing purposes
+    #[account(seeds = [b"seed"], bump = bump, seeds::program = anchor_spl::associated_token::ID)]
+    first: AccountInfo<'info>,
+
+    #[account(seeds = [b"seed"], bump = second_bump, seeds::program = crate::ID)]
+    second: AccountInfo<'info>,
+}
+
+#[derive(Accounts)]
+pub struct TestProgramIdConstraintUsingFindPda<'info> {
+    // not a real associated token account
+    // just deriving like this for testing purposes
+    #[account(seeds = [b"seed"], bump, seeds::program = anchor_spl::associated_token::ID)]
+    first: AccountInfo<'info>,
+
+    #[account(seeds = [b"seed"], bump, seeds::program = crate::ID)]
+    second: AccountInfo<'info>,
+}

+ 14 - 0
tests/misc/programs/misc/src/lib.rs

@@ -268,4 +268,18 @@ pub mod misc {
     pub fn init_if_needed_checks_rent_exemption(_ctx: Context<InitIfNeededChecksRentExemption>) -> ProgramResult {
         Ok(())
     }
+        
+    pub fn test_program_id_constraint(
+        _ctx: Context<TestProgramIdConstraint>,
+        _bump: u8,
+        _second_bump: u8
+    ) -> ProgramResult {
+        Ok(())
+    }
+
+    pub fn test_program_id_constraint_find_pda(
+        _ctx: Context<TestProgramIdConstraintUsingFindPda>,
+    ) -> ProgramResult {
+        Ok(())
+    }
 }

+ 100 - 0
tests/misc/tests/misc.js

@@ -1528,4 +1528,104 @@ describe("misc", () => {
       assert.equal("A rent exempt constraint was violated", err.msg);
     }
   });
+
+  describe("Can validate PDAs derived from other program ids", () => {
+    it("With bumps using create_program_address", async () => {
+      const [firstPDA, firstBump] =
+        await anchor.web3.PublicKey.findProgramAddress(
+          [anchor.utils.bytes.utf8.encode("seed")],
+          ASSOCIATED_TOKEN_PROGRAM_ID
+        );
+      const [secondPDA, secondBump] =
+        await anchor.web3.PublicKey.findProgramAddress(
+          [anchor.utils.bytes.utf8.encode("seed")],
+          program.programId
+        );
+
+      // correct bump but wrong address
+      const wrongAddress = anchor.web3.Keypair.generate().publicKey;
+      try {
+        await program.rpc.testProgramIdConstraint(firstBump, secondBump, {
+          accounts: {
+            first: wrongAddress,
+            second: secondPDA,
+          },
+        });
+        assert.ok(false);
+      } catch (err) {
+        assert.equal(err.code, 2006);
+      }
+
+      // matching bump seed for wrong address but derived from wrong program
+      try {
+        await program.rpc.testProgramIdConstraint(secondBump, secondBump, {
+          accounts: {
+            first: secondPDA,
+            second: secondPDA,
+          },
+        });
+        assert.ok(false);
+      } catch (err) {
+        assert.equal(err.code, 2006);
+      }
+
+      // correct inputs should lead to successful tx
+      await program.rpc.testProgramIdConstraint(firstBump, secondBump, {
+        accounts: {
+          first: firstPDA,
+          second: secondPDA,
+        },
+      });
+    });
+
+    it("With bumps using find_program_address", async () => {
+      const firstPDA = (
+        await anchor.web3.PublicKey.findProgramAddress(
+          [anchor.utils.bytes.utf8.encode("seed")],
+          ASSOCIATED_TOKEN_PROGRAM_ID
+        )
+      )[0];
+      const secondPDA = (
+        await anchor.web3.PublicKey.findProgramAddress(
+          [anchor.utils.bytes.utf8.encode("seed")],
+          program.programId
+        )
+      )[0];
+
+      // random wrong address
+      const wrongAddress = anchor.web3.Keypair.generate().publicKey;
+      try {
+        await program.rpc.testProgramIdConstraintFindPda({
+          accounts: {
+            first: wrongAddress,
+            second: secondPDA,
+          },
+        });
+        assert.ok(false);
+      } catch (err) {
+        assert.equal(err.code, 2006);
+      }
+
+      // same seeds but derived from wrong program
+      try {
+        await program.rpc.testProgramIdConstraintFindPda({
+          accounts: {
+            first: secondPDA,
+            second: secondPDA,
+          },
+        });
+        assert.ok(false);
+      } catch (err) {
+        assert.equal(err.code, 2006);
+      }
+
+      // correct inputs should lead to successful tx
+      await program.rpc.testProgramIdConstraintFindPda({
+        accounts: {
+          first: firstPDA,
+          second: secondPDA,
+        },
+      });
+    });
+  });
 });