Browse Source

ts: Extend filters for `all` method (#788)

Daniel Marin 4 years ago
parent
commit
73f046e0a5

+ 7 - 0
tests/misc/programs/misc/src/account.rs

@@ -29,3 +29,10 @@ pub struct DataZeroCopy {
     pub data: u16,
     pub data: u16,
     pub bump: u8,
     pub bump: u8,
 }
 }
+
+#[account]
+#[derive(Default)]
+pub struct DataWithFilter {
+    pub authority: Pubkey,
+    pub filterable: Pubkey,
+}

+ 8 - 0
tests/misc/programs/misc/src/context.rs

@@ -211,3 +211,11 @@ pub struct TestCompositePayer<'info> {
     pub data: Account<'info, Data>,
     pub data: Account<'info, Data>,
     pub system_program: Program<'info, System>,
     pub system_program: Program<'info, System>,
 }
 }
+
+#[derive(Accounts)]
+pub struct TestFetchAll<'info> {
+    #[account(init, payer = authority)]
+    pub data: Account<'info, DataWithFilter>,
+    pub authority: Signer<'info>,
+    pub system_program: Program<'info, System>,
+}

+ 6 - 0
tests/misc/programs/misc/src/lib.rs

@@ -164,4 +164,10 @@ pub mod misc {
         assert!(ctx.accounts.token.mint == ctx.accounts.mint.key());
         assert!(ctx.accounts.token.mint == ctx.accounts.mint.key());
         Ok(())
         Ok(())
     }
     }
+
+    pub fn test_fetch_all(ctx: Context<TestFetchAll>, filterable: Pubkey) -> ProgramResult {
+        ctx.accounts.data.authority = ctx.accounts.authority.key();
+        ctx.accounts.data.filterable = filterable;
+        Ok(())
+    }
 }
 }

+ 99 - 0
tests/misc/tests/misc.js

@@ -6,6 +6,7 @@ const {
   TOKEN_PROGRAM_ID,
   TOKEN_PROGRAM_ID,
   Token,
   Token,
 } = require("@solana/spl-token");
 } = require("@solana/spl-token");
+const miscIdl = require("../target/idl/misc.json");
 
 
 describe("misc", () => {
 describe("misc", () => {
   // Configure the client to use the local cluster.
   // Configure the client to use the local cluster.
@@ -640,4 +641,102 @@ describe("misc", () => {
     assert.ok(account.owner.equals(program.provider.wallet.publicKey));
     assert.ok(account.owner.equals(program.provider.wallet.publicKey));
     assert.ok(account.mint.equals(mint.publicKey));
     assert.ok(account.mint.equals(mint.publicKey));
   });
   });
+
+  it("Can fetch all accounts of a given type", async () => {
+    // Initialize the accounts.
+    const data1 = anchor.web3.Keypair.generate();
+    const data2 = anchor.web3.Keypair.generate();
+    const data3 = anchor.web3.Keypair.generate();
+    const data4 = anchor.web3.Keypair.generate();
+    // Initialize filterable data.
+    const filterable1 = anchor.web3.Keypair.generate().publicKey;
+    const filterable2 = anchor.web3.Keypair.generate().publicKey;
+    // Set up a secondary wallet and program.
+    const anotherProgram = new anchor.Program(
+      miscIdl,
+      program.programId,
+      new anchor.Provider(
+        program.provider.connection,
+        new anchor.Wallet(anchor.web3.Keypair.generate()),
+        { commitment: program.provider.connection.commitment }
+      )
+    );
+    // Request airdrop for secondary wallet.
+    const signature = await program.provider.connection.requestAirdrop(
+      anotherProgram.provider.wallet.publicKey,
+      anchor.web3.LAMPORTS_PER_SOL
+    );
+    await program.provider.connection.confirmTransaction(signature);
+    // Create all the accounts.
+    await Promise.all([
+      program.rpc.testFetchAll(filterable1, {
+        accounts: {
+          data: data1.publicKey,
+          authority: program.provider.wallet.publicKey,
+          systemProgram: anchor.web3.SystemProgram.programId,
+        },
+        signers: [data1],
+      }),
+      program.rpc.testFetchAll(filterable1, {
+        accounts: {
+          data: data2.publicKey,
+          authority: program.provider.wallet.publicKey,
+          systemProgram: anchor.web3.SystemProgram.programId,
+        },
+        signers: [data2],
+      }),
+      program.rpc.testFetchAll(filterable2, {
+        accounts: {
+          data: data3.publicKey,
+          authority: program.provider.wallet.publicKey,
+          systemProgram: anchor.web3.SystemProgram.programId,
+        },
+        signers: [data3],
+      }),
+      anotherProgram.rpc.testFetchAll(filterable1, {
+        accounts: {
+          data: data4.publicKey,
+          authority: anotherProgram.provider.wallet.publicKey,
+          systemProgram: anchor.web3.SystemProgram.programId,
+        },
+        signers: [data4],
+      }),
+    ]);
+    // Call for multiple kinds of .all.
+    const allAccounts = await program.account.dataWithFilter.all();
+    const allAccountsFilteredByBuffer =
+      await program.account.dataWithFilter.all(
+        program.provider.wallet.publicKey.toBuffer()
+      );
+    const allAccountsFilteredByProgramFilters1 =
+      await program.account.dataWithFilter.all([
+        {
+          memcmp: {
+            offset: 8,
+            bytes: program.provider.wallet.publicKey.toBase58(),
+          },
+        },
+        { memcmp: { offset: 40, bytes: filterable1.toBase58() } },
+      ]);
+    const allAccountsFilteredByProgramFilters2 =
+      await program.account.dataWithFilter.all([
+        {
+          memcmp: {
+            offset: 8,
+            bytes: program.provider.wallet.publicKey.toBase58(),
+          },
+        },
+        { memcmp: { offset: 40, bytes: filterable2.toBase58() } },
+      ]);
+    // Without filters there should be 4 accounts.
+    assert.equal(allAccounts.length, 4);
+    // Filtering by main wallet there should be 3 accounts.
+    assert.equal(allAccountsFilteredByBuffer.length, 3);
+    // Filtering all the main wallet accounts and matching the filterable1 value
+    // results in a 2 accounts.
+    assert.equal(allAccountsFilteredByProgramFilters1.length, 2);
+    // Filtering all the main wallet accounts and matching the filterable2 value
+    // results in 1 account.
+    assert.equal(allAccountsFilteredByProgramFilters2.length, 1);
+  });
 });
 });

+ 24 - 6
ts/src/program/namespace/account.ts

@@ -7,6 +7,7 @@ import {
   SystemProgram,
   SystemProgram,
   TransactionInstruction,
   TransactionInstruction,
   Commitment,
   Commitment,
+  GetProgramAccountsFilter,
 } from "@solana/web3.js";
 } from "@solana/web3.js";
 import Provider from "../../provider";
 import Provider from "../../provider";
 import { Idl, IdlTypeDef } from "../../idl";
 import { Idl, IdlTypeDef } from "../../idl";
@@ -187,12 +188,24 @@ export class AccountClient<T = any> {
 
 
   /**
   /**
    * Returns all instances of this account type for the program.
    * Returns all instances of this account type for the program.
+   *
+   * @param filters User-provided filters to narrow the results from `connection.getProgramAccounts`.
+   *
+   *                When filters are not defined this method returns all
+   *                the account instances.
+   *
+   *                When filters are of type `Buffer`, the filters are appended
+   *                after the discriminator.
+   *
+   *                When filters are of type `GetProgramAccountsFilter[]`,
+   *                filters are appended after the discriminator filter.
    */
    */
-  async all(filter?: Buffer): Promise<ProgramAccount<T>[]> {
-    let bytes = AccountsCoder.accountDiscriminator(this._idlAccount.name);
-    if (filter !== undefined) {
-      bytes = Buffer.concat([bytes, filter]);
-    }
+  async all(
+    filters?: Buffer | GetProgramAccountsFilter[]
+  ): Promise<ProgramAccount<T>[]> {
+    const discriminator = AccountsCoder.accountDiscriminator(
+      this._idlAccount.name
+    );
 
 
     let resp = await this._provider.connection.getProgramAccounts(
     let resp = await this._provider.connection.getProgramAccounts(
       this._programId,
       this._programId,
@@ -202,9 +215,14 @@ export class AccountClient<T = any> {
           {
           {
             memcmp: {
             memcmp: {
               offset: 0,
               offset: 0,
-              bytes: bs58.encode(bytes),
+              bytes: bs58.encode(
+                filters instanceof Buffer
+                  ? Buffer.concat([discriminator, filters])
+                  : discriminator
+              ),
             },
             },
           },
           },
+          ...(Array.isArray(filters) ? filters : []),
         ],
         ],
       }
       }
     );
     );