Tom Linton 3 anni fa
parent
commit
8fb942efd5

+ 9 - 2
tests/cpi-returns/programs/callee/src/lib.rs

@@ -11,7 +11,9 @@ pub mod callee {
         pub value: u64,
     }
 
-    pub fn initialize(_ctx: Context<Initialize>) -> Result<()> {
+    pub fn initialize(ctx: Context<Initialize>) -> Result<()> {
+        let account = &mut ctx.accounts.account;
+        account.value = 10;
         Ok(())
     }
 
@@ -27,6 +29,12 @@ pub mod callee {
     pub fn return_vec(_ctx: Context<CpiReturn>) -> Result<Vec<u8>> {
         Ok(vec![12, 13, 14, 100])
     }
+
+    // Used for testing views
+    pub fn return_u64_from_account(ctx: Context<CpiReturn>) -> Result<u64> {
+        let account = &ctx.accounts.account;
+        Ok(account.value)
+    }
 }
 
 #[derive(Accounts)]
@@ -40,7 +48,6 @@ pub struct Initialize<'info> {
 
 #[derive(Accounts)]
 pub struct CpiReturn<'info> {
-    #[account(mut)]
     pub account: Account<'info, CpiReturnAccount>,
 }
 

+ 21 - 0
tests/cpi-returns/programs/caller/src/lib.rs

@@ -9,6 +9,12 @@ declare_id!("HmbTLCmaGvZhKnn1Zfa1JVnp7vkMV4DYVxPLWBVoN65L");
 pub mod caller {
     use super::*;
 
+    #[derive(AnchorSerialize, AnchorDeserialize)]
+    pub struct Struct {
+        pub a: u64,
+        pub b: u64,
+    }
+
     pub fn cpi_call_return_u64(ctx: Context<CpiReturnContext>) -> Result<()> {
         let cpi_program = ctx.accounts.cpi_return_program.to_account_info();
         let cpi_accounts = CpiReturn {
@@ -44,6 +50,18 @@ pub mod caller {
         anchor_lang::solana_program::log::sol_log_data(&[&solana_return.try_to_vec().unwrap()]);
         Ok(())
     }
+
+    pub fn return_u64(ctx: Context<ReturnContext>) -> Result<u64> {
+        Ok(99)
+    }
+
+    pub fn return_struct(ctx: Context<ReturnContext>) -> Result<Struct> {
+        Ok(Struct { a: 1, b: 2 })
+    }
+
+    pub fn return_vec(ctx: Context<ReturnContext>) -> Result<Vec<u64>> {
+        Ok(vec![1, 2, 3])
+    }
 }
 
 #[derive(Accounts)]
@@ -52,3 +70,6 @@ pub struct CpiReturnContext<'info> {
     pub cpi_return: Account<'info, CpiReturnAccount>,
     pub cpi_return_program: Program<'info, Callee>,
 }
+
+#[derive(Accounts)]
+pub struct ReturnContext {}

+ 59 - 0
tests/cpi-returns/tests/cpi-return.ts

@@ -158,4 +158,63 @@ describe("CPI return", () => {
       defined: "StructReturn",
     });
   });
+
+  it("can return a u64 via view", async () => {
+    assert(new anchor.BN(99).eq(await callerProgram.views.returnU64()));
+    // Via methods API
+    assert(
+      new anchor.BN(99).eq(await callerProgram.methods.returnU64().view())
+    );
+  });
+
+  it("can return a struct via view", async () => {
+    const struct = await callerProgram.views.returnStruct();
+    assert(struct.a.eq(new anchor.BN(1)));
+    assert(struct.b.eq(new anchor.BN(2)));
+    // Via methods API
+    const struct2 = await callerProgram.methods.returnStruct().view();
+    assert(struct2.a.eq(new anchor.BN(1)));
+    assert(struct2.b.eq(new anchor.BN(2)));
+  });
+
+  it("can return a vec via view", async () => {
+    const vec = await callerProgram.views.returnVec();
+    assert(vec[0].eq(new anchor.BN(1)));
+    assert(vec[1].eq(new anchor.BN(2)));
+    assert(vec[2].eq(new anchor.BN(3)));
+    // Via methods API
+    const vec2 = await callerProgram.methods.returnVec().view();
+    assert(vec2[0].eq(new anchor.BN(1)));
+    assert(vec2[1].eq(new anchor.BN(2)));
+    assert(vec2[2].eq(new anchor.BN(3)));
+  });
+
+  it("can return a u64 from an account via view", async () => {
+    const value = new anchor.BN(10);
+    assert(
+      value.eq(
+        await calleeProgram.methods
+          .returnU64FromAccount()
+          .accounts({ account: cpiReturn.publicKey })
+          .view()
+      )
+    );
+  });
+
+  it("cant call view on mutable instruction", async () => {
+    assert.equal(calleeProgram.views.initialize, undefined);
+    try {
+      await calleeProgram.methods
+        .initialize()
+        .accounts({
+          account: cpiReturn.publicKey,
+          user: provider.wallet.publicKey,
+          systemProgram: SystemProgram.programId,
+        })
+        .signers([cpiReturn])
+        .view();
+    } catch (e) {
+      assert(e.message.includes("Method does not support views"));
+    }
+  });
 });

+ 14 - 2
ts/src/program/index.ts

@@ -11,6 +11,7 @@ import NamespaceFactory, {
   StateClient,
   SimulateNamespace,
   MethodsNamespace,
+  ViewNamespace,
 } from "./namespace/index.js";
 import { utf8 } from "../utils/bytes/index.js";
 import { EventManager } from "./event.js";
@@ -217,6 +218,8 @@ export class Program<IDL extends Idl = Idl> {
    */
   readonly methods: MethodsNamespace<IDL>;
 
+  readonly views?: ViewNamespace<IDL>;
+
   /**
    * Address of the program.
    */
@@ -280,8 +283,16 @@ export class Program<IDL extends Idl = Idl> {
     this._events = new EventManager(this._programId, provider, this._coder);
 
     // Dynamic namespaces.
-    const [rpc, instruction, transaction, account, simulate, methods, state] =
-      NamespaceFactory.build(idl, this._coder, programId, provider);
+    const [
+      rpc,
+      instruction,
+      transaction,
+      account,
+      simulate,
+      methods,
+      state,
+      views,
+    ] = NamespaceFactory.build(idl, this._coder, programId, provider);
     this.rpc = rpc;
     this.instruction = instruction;
     this.transaction = transaction;
@@ -289,6 +300,7 @@ export class Program<IDL extends Idl = Idl> {
     this.simulate = simulate;
     this.methods = methods;
     this.state = state;
+    this.views = views;
   }
 
   /**

+ 11 - 2
ts/src/program/namespace/index.ts

@@ -11,6 +11,7 @@ import AccountFactory, { AccountNamespace } from "./account.js";
 import SimulateFactory, { SimulateNamespace } from "./simulate.js";
 import { parseIdlErrors } from "../common.js";
 import { MethodsBuilderFactory, MethodsNamespace } from "./methods";
+import ViewFactory, { ViewNamespace } from "./views";
 
 // Re-exports.
 export { StateClient } from "./state.js";
@@ -21,6 +22,7 @@ export { AccountNamespace, AccountClient, ProgramAccount } from "./account.js";
 export { SimulateNamespace, SimulateFn } from "./simulate.js";
 export { IdlAccounts, IdlTypes } from "./types.js";
 export { MethodsBuilderFactory, MethodsNamespace } from "./methods";
+export { ViewNamespace, ViewFn } from "./views";
 
 export default class NamespaceFactory {
   /**
@@ -38,13 +40,15 @@ export default class NamespaceFactory {
     AccountNamespace<IDL>,
     SimulateNamespace<IDL>,
     MethodsNamespace<IDL>,
-    StateClient<IDL> | undefined
+    StateClient<IDL> | undefined,
+    ViewNamespace<IDL> | undefined
   ] {
     const rpc: RpcNamespace = {};
     const instruction: InstructionNamespace = {};
     const transaction: TransactionNamespace = {};
     const simulate: SimulateNamespace = {};
     const methods: MethodsNamespace = {};
+    const view: ViewNamespace = {};
 
     const idlErrors = parseIdlErrors(idl);
 
@@ -71,6 +75,7 @@ export default class NamespaceFactory {
         programId,
         idl
       );
+      const viewItem = ViewFactory.build(programId, idlIx, simulateItem, idl);
       const methodItem = MethodsBuilderFactory.build<IDL, typeof idlIx>(
         provider,
         programId,
@@ -79,9 +84,9 @@ export default class NamespaceFactory {
         txItem,
         rpcItem,
         simulateItem,
+        viewItem,
         account
       );
-
       const name = camelCase(idlIx.name);
 
       instruction[name] = ixItem;
@@ -89,6 +94,9 @@ export default class NamespaceFactory {
       rpc[name] = rpcItem;
       simulate[name] = simulateItem;
       methods[name] = methodItem;
+      if (viewItem) {
+        view[name] = viewItem;
+      }
     });
 
     return [
@@ -99,6 +107,7 @@ export default class NamespaceFactory {
       simulate as SimulateNamespace<IDL>,
       methods as MethodsNamespace<IDL>,
       state,
+      view as ViewNamespace<IDL>,
     ];
   }
 }

+ 20 - 0
ts/src/program/namespace/methods.ts

@@ -14,6 +14,7 @@ import { AllInstructions, MethodsFn, MakeMethodsNamespace } from "./types.js";
 import { InstructionFn } from "./instruction.js";
 import { RpcFn } from "./rpc.js";
 import { SimulateFn } from "./simulate.js";
+import { ViewFn } from "./views.js";
 import Provider from "../../provider.js";
 import { AccountNamespace } from "./account.js";
 import { AccountsResolver } from "../accounts-resolver.js";
@@ -33,6 +34,7 @@ export class MethodsBuilderFactory {
     txFn: TransactionFn<IDL>,
     rpcFn: RpcFn<IDL>,
     simulateFn: SimulateFn<IDL>,
+    viewFn: ViewFn<IDL> | undefined,
     accountNamespace: AccountNamespace<IDL>
   ): MethodsFn<IDL, I, MethodsBuilder<IDL, I>> {
     return (...args) =>
@@ -42,6 +44,7 @@ export class MethodsBuilderFactory {
         txFn,
         rpcFn,
         simulateFn,
+        viewFn,
         provider,
         programId,
         idlIx,
@@ -64,6 +67,7 @@ export class MethodsBuilder<IDL extends Idl, I extends AllInstructions<IDL>> {
     private _txFn: TransactionFn<IDL>,
     private _rpcFn: RpcFn<IDL>,
     private _simulateFn: SimulateFn<IDL>,
+    private _viewFn: ViewFn<IDL> | undefined,
     _provider: Provider,
     _programId: PublicKey,
     _idlIx: AllInstructions<IDL>,
@@ -125,6 +129,22 @@ export class MethodsBuilder<IDL extends Idl, I extends AllInstructions<IDL>> {
     });
   }
 
+  public async view(options?: ConfirmOptions): Promise<any> {
+    await this._accountsResolver.resolve();
+    if (!this._viewFn) {
+      throw new Error("Method does not support views");
+    }
+    // @ts-ignore
+    return this._viewFn(...this._args, {
+      accounts: this._accounts,
+      signers: this._signers,
+      remainingAccounts: this._remainingAccounts,
+      preInstructions: this._preInstructions,
+      postInstructions: this._postInstructions,
+      options: options,
+    });
+  }
+
   public async simulate(
     options?: ConfirmOptions
   ): Promise<SimulateResponse<any, any>> {

+ 60 - 0
ts/src/program/namespace/views.ts

@@ -0,0 +1,60 @@
+import { PublicKey } from "@solana/web3.js";
+import { Idl, IdlAccount } from "../../idl.js";
+import { SimulateFn } from "./simulate.js";
+import {
+  AllInstructions,
+  InstructionContextFn,
+  MakeInstructionsNamespace,
+} from "./types";
+import { IdlCoder } from "../../coder/borsh/idl";
+import { decode } from "../../utils/bytes/base64";
+
+export default class ViewFactory {
+  public static build<IDL extends Idl, I extends AllInstructions<IDL>>(
+    programId: PublicKey,
+    idlIx: AllInstructions<IDL>,
+    simulateFn: SimulateFn<IDL>,
+    idl: IDL
+  ): ViewFn<IDL, I> | undefined {
+    const isMut = idlIx.accounts.find((a: IdlAccount) => a.isMut);
+    const hasReturn = !!idlIx.returns;
+    if (isMut || !hasReturn) return;
+
+    const view: ViewFn<IDL> = async (...args) => {
+      let simulationResult = await simulateFn(...args);
+      const returnPrefix = `Program return: ${programId} `;
+      let returnLog = simulationResult.raw.find((l) =>
+        l.startsWith(returnPrefix)
+      );
+      if (!returnLog) {
+        throw new Error("View expected return log");
+      }
+      let returnData = decode(returnLog.slice(returnPrefix.length));
+      let returnType = idlIx.returns;
+      if (!returnType) {
+        throw new Error("View expected return type");
+      }
+      const coder = IdlCoder.fieldLayout(
+        { type: returnType },
+        Array.from([...(idl.accounts ?? []), ...(idl.types ?? [])])
+      );
+      return coder.decode(returnData);
+    };
+    return view;
+  }
+}
+
+export type ViewNamespace<
+  IDL extends Idl = Idl,
+  I extends AllInstructions<IDL> = AllInstructions<IDL>
+> = MakeInstructionsNamespace<IDL, I, Promise<any>>;
+
+/**
+ * ViewFn is a single method generated from an IDL. It simulates a method
+ * against a cluster configured by the provider, and then parses the events
+ * and extracts return data from the raw logs emitted during the simulation.
+ */
+export type ViewFn<
+  IDL extends Idl = Idl,
+  I extends AllInstructions<IDL> = AllInstructions<IDL>
+> = InstructionContextFn<IDL, I, Promise<any>>;