Эх сурвалжийг харах

Allow PdaValueNode to inline their own PdaNode definition (#43)

Loris Leiva 1 жил өмнө
parent
commit
668b550aa2

+ 9 - 0
.changeset/tasty-rocks-sleep.md

@@ -0,0 +1,9 @@
+---
+"@kinobi-so/renderers-js-umi": patch
+"@kinobi-so/renderers-js": patch
+"@kinobi-so/node-types": patch
+"@kinobi-so/visitors": patch
+"@kinobi-so/nodes": patch
+---
+
+Allow PdaValueNode to inline their own PdaNode definition

+ 1 - 0
packages/node-types/src/PdaNode.ts

@@ -7,6 +7,7 @@ export interface PdaNode<TSeeds extends PdaSeedNode[] = PdaSeedNode[]> {
     // Data.
     readonly name: CamelCaseString;
     readonly docs: Docs;
+    readonly programId?: string;
 
     // Children.
     readonly seeds: TSeeds;

+ 2 - 1
packages/node-types/src/contextualValueNodes/PdaValueNode.ts

@@ -1,10 +1,11 @@
 import type { PdaLinkNode } from '../linkNodes';
+import type { PdaNode } from '../PdaNode';
 import type { PdaSeedValueNode } from './PdaSeedValueNode';
 
 export interface PdaValueNode<TSeeds extends PdaSeedValueNode[] = PdaSeedValueNode[]> {
     readonly kind: 'pdaValueNode';
 
     // Children.
-    readonly pda: PdaLinkNode;
+    readonly pda: PdaLinkNode | PdaNode;
     readonly seeds: TSeeds;
 }

+ 1 - 0
packages/nodes/src/PdaNode.ts

@@ -17,6 +17,7 @@ export function pdaNode<const TSeeds extends PdaSeedNode[]>(input: PdaNodeInput<
         // Data.
         name: camelCase(input.name),
         docs: parseDocs(input.docs),
+        ...(input.programId && { programId: input.programId }),
 
         // Children.
         seeds: input.seeds,

+ 2 - 2
packages/nodes/src/contextualValueNodes/PdaValueNode.ts

@@ -1,9 +1,9 @@
-import type { PdaLinkNode, PdaSeedValueNode, PdaValueNode } from '@kinobi-so/node-types';
+import type { PdaLinkNode, PdaNode, PdaSeedValueNode, PdaValueNode } from '@kinobi-so/node-types';
 
 import { pdaLinkNode } from '../linkNodes';
 
 export function pdaValueNode<const TSeeds extends PdaSeedValueNode[] = []>(
-    pda: PdaLinkNode | string,
+    pda: PdaLinkNode | PdaNode | string,
     seeds: TSeeds = [] as PdaSeedValueNode[] as TSeeds,
 ): PdaValueNode<TSeeds> {
     return Object.freeze({

+ 49 - 0
packages/renderers-js-umi/src/renderInstructionDefaults.ts

@@ -76,6 +76,55 @@ export function renderInstructionDefaults(
             imports.add('shared', 'expectPublicKey');
             return render(`expectPublicKey(resolvedAccounts.${name}.value)`);
         case 'pdaValueNode':
+            // Inlined PDA value.
+            if (isNode(defaultValue.pda, 'pdaNode')) {
+                const pdaProgram = defaultValue.pda.programId
+                    ? `context.programs.getPublicKey('${defaultValue.pda.programId}', '${defaultValue.pda.programId}')`
+                    : 'programId';
+                const pdaSeeds = defaultValue.pda.seeds.flatMap((seed): string[] => {
+                    if (isNode(seed, 'constantPdaSeedNode') && isNode(seed.value, 'programIdValueNode')) {
+                        imports
+                            .add('umiSerializers', 'publicKey')
+                            .addAlias('umiSerializers', 'publicKey', 'publicKeySerializer');
+                        return [`publicKeySerializer().serialize(${pdaProgram})`];
+                    }
+                    if (isNode(seed, 'constantPdaSeedNode') && !isNode(seed.value, 'programIdValueNode')) {
+                        const typeManifest = visit(seed.type, typeManifestVisitor);
+                        const valueManifest = visit(seed.value, typeManifestVisitor);
+                        imports.mergeWith(typeManifest.serializerImports);
+                        imports.mergeWith(valueManifest.valueImports);
+                        return [`${typeManifest.serializer}.serialize(${valueManifest.value})`];
+                    }
+                    if (isNode(seed, 'variablePdaSeedNode')) {
+                        const typeManifest = visit(seed.type, typeManifestVisitor);
+                        const valueSeed = defaultValue.seeds.find(s => s.name === seed.name)?.value;
+                        if (!valueSeed) return [];
+                        if (isNode(valueSeed, 'accountValueNode')) {
+                            imports.mergeWith(typeManifest.serializerImports);
+                            imports.add('shared', 'expectPublicKey');
+                            return [
+                                `${typeManifest.serializer}.serialize(expectPublicKey(resolvedAccounts.${camelCase(valueSeed.name)}.value))`,
+                            ];
+                        }
+                        if (isNode(valueSeed, 'argumentValueNode')) {
+                            imports.mergeWith(typeManifest.serializerImports);
+                            imports.add('shared', 'expectSome');
+                            return [
+                                `${typeManifest.serializer}.serialize(expectSome(${argObject}.${camelCase(valueSeed.name)}))`,
+                            ];
+                        }
+                        const valueManifest = visit(valueSeed, typeManifestVisitor);
+                        imports.mergeWith(typeManifest.serializerImports);
+                        imports.mergeWith(valueManifest.valueImports);
+                        return [`${typeManifest.serializer}.serialize(${valueManifest.value})`];
+                    }
+                    return [];
+                });
+
+                return render(`context.eddsa.findPda(${pdaProgram}, [${pdaSeeds.join(', ')}])`);
+            }
+
+            // Linked PDA value.
             const pdaFunction = `find${pascalCase(defaultValue.pda.name)}Pda`;
             const pdaImportFrom = defaultValue.pda.importFrom ?? 'generatedAccounts';
             imports.add(pdaImportFrom, pdaFunction);

+ 153 - 0
packages/renderers-js-umi/test/instructionsPage.test.ts

@@ -0,0 +1,153 @@
+import {
+    accountValueNode,
+    constantPdaSeedNodeFromString,
+    instructionAccountNode,
+    instructionNode,
+    pdaNode,
+    pdaSeedValueNode,
+    pdaValueNode,
+    programNode,
+    publicKeyTypeNode,
+    variablePdaSeedNode,
+} from '@kinobi-so/nodes';
+import { visit } from '@kinobi-so/visitors-core';
+import { test } from 'vitest';
+
+import { getRenderMapVisitor } from '../src';
+import { renderMapContains, renderMapContainsImports } from './_setup';
+
+test('it renders instruction accounts with linked PDAs as default value', async () => {
+    // Given the following program with a PDA node and an instruction account using it as default value.
+    const node = programNode({
+        instructions: [
+            instructionNode({
+                accounts: [
+                    instructionAccountNode({ isSigner: true, isWritable: false, name: 'authority' }),
+                    instructionAccountNode({
+                        defaultValue: pdaValueNode('counter', [
+                            pdaSeedValueNode('authority', accountValueNode('authority')),
+                        ]),
+                        isSigner: false,
+                        isWritable: false,
+                        name: 'counter',
+                    }),
+                ],
+                name: 'increment',
+            }),
+        ],
+        name: 'counter',
+        pdas: [
+            pdaNode({
+                name: 'counter',
+                seeds: [
+                    constantPdaSeedNodeFromString('utf8', 'counter'),
+                    variablePdaSeedNode('authority', publicKeyTypeNode()),
+                ],
+            }),
+        ],
+        publicKey: '1111',
+    });
+
+    // When we render it.
+    const renderMap = visit(node, getRenderMapVisitor());
+
+    // Then we expect the following default value to be rendered.
+    await renderMapContains(renderMap, 'instructions/increment.ts', [
+        'if (!resolvedAccounts.counter.value) { ' +
+            'resolvedAccounts.counter.value = findCounterPda( context, { authority: expectPublicKey ( resolvedAccounts.authority.value ) } ); ' +
+            '}',
+    ]);
+    renderMapContainsImports(renderMap, 'instructions/increment.ts', { '../accounts': ['findCounterPda'] });
+});
+
+test('it renders instruction accounts with inlined PDAs as default value', async () => {
+    // Given the following instruction with an inlined PDA default value.
+    const node = programNode({
+        instructions: [
+            instructionNode({
+                accounts: [
+                    instructionAccountNode({ isSigner: true, isWritable: false, name: 'authority' }),
+                    instructionAccountNode({
+                        defaultValue: pdaValueNode(
+                            pdaNode({
+                                name: 'counter',
+                                seeds: [
+                                    constantPdaSeedNodeFromString('utf8', 'counter'),
+                                    variablePdaSeedNode('authority', publicKeyTypeNode()),
+                                ],
+                            }),
+                            [pdaSeedValueNode('authority', accountValueNode('authority'))],
+                        ),
+                        isSigner: false,
+                        isWritable: false,
+                        name: 'counter',
+                    }),
+                ],
+                name: 'increment',
+            }),
+        ],
+        name: 'counter',
+        publicKey: '1111',
+    });
+
+    // When we render it.
+    const renderMap = visit(node, getRenderMapVisitor());
+
+    // Then we expect the following default value to be rendered.
+    await renderMapContains(renderMap, 'instructions/increment.ts', [
+        'if (!resolvedAccounts.counter.value) { ' +
+            'resolvedAccounts.counter.value = context.eddsa.findPda( programId, [ ' +
+            "  string({ size: 'variable' }).serialize( 'counter' ), " +
+            '  publicKeySerializer().serialize( expectPublicKey( resolvedAccounts.authority.value ) ) ' +
+            '] ); ' +
+            '}',
+    ]);
+});
+
+test('it renders instruction accounts with inlined PDAs from another program as default value', async () => {
+    // Given the following instruction with an inlined PDA default value from another program.
+    const node = programNode({
+        instructions: [
+            instructionNode({
+                accounts: [
+                    instructionAccountNode({ isSigner: true, isWritable: false, name: 'authority' }),
+                    instructionAccountNode({
+                        defaultValue: pdaValueNode(
+                            pdaNode({
+                                name: 'counter',
+                                programId: '2222',
+                                seeds: [
+                                    constantPdaSeedNodeFromString('utf8', 'counter'),
+                                    variablePdaSeedNode('authority', publicKeyTypeNode()),
+                                ],
+                            }),
+                            [pdaSeedValueNode('authority', accountValueNode('authority'))],
+                        ),
+                        isSigner: false,
+                        isWritable: false,
+                        name: 'counter',
+                    }),
+                ],
+                name: 'increment',
+            }),
+        ],
+        name: 'counter',
+        publicKey: '1111',
+    });
+
+    // When we render it.
+    const renderMap = visit(node, getRenderMapVisitor());
+
+    // Then we expect the following default value to be rendered.
+    await renderMapContains(renderMap, 'instructions/increment.ts', [
+        'if (!resolvedAccounts.counter.value) { ' +
+            'resolvedAccounts.counter.value = context.eddsa.findPda( ' +
+            "  context.programs.getPublicKey('2222', '2222'), " +
+            '  [ ' +
+            "    string({ size: 'variable' }).serialize( 'counter' ), " +
+            '    publicKeySerializer().serialize( expectPublicKey( resolvedAccounts.authority.value ) ) ' +
+            '  ] ' +
+            '); ' +
+            '}',
+    ]);
+});

+ 65 - 0
packages/renderers-js/src/fragments/instructionInputDefault.ts

@@ -55,6 +55,71 @@ export function getInstructionInputDefaultFragment(
             return defaultFragment(`expectAddress(accounts.${name}.value)`).addImports('shared', 'expectAddress');
 
         case 'pdaValueNode':
+            // Inlined PDA value.
+            if (isNode(defaultValue.pda, 'pdaNode')) {
+                const pdaProgram = defaultValue.pda.programId
+                    ? fragment(
+                          `'${defaultValue.pda.programId}' as Address<'${defaultValue.pda.programId}'>`,
+                      ).addImports('solanaAddresses', 'Address')
+                    : fragment('programAddress');
+                const pdaSeeds = defaultValue.pda.seeds.flatMap((seed): Fragment[] => {
+                    if (isNode(seed, 'constantPdaSeedNode') && isNode(seed.value, 'programIdValueNode')) {
+                        return [
+                            fragment(`getAddressEncoder().encode(${pdaProgram})`)
+                                .mergeImportsWith(pdaProgram)
+                                .addImports('solanaAddresses', 'getAddressEncoder'),
+                        ];
+                    }
+                    if (isNode(seed, 'constantPdaSeedNode') && !isNode(seed.value, 'programIdValueNode')) {
+                        const typeManifest = visit(seed.type, typeManifestVisitor);
+                        const valueManifest = visit(seed.value, typeManifestVisitor);
+                        return [
+                            fragment(`${typeManifest.encoder}.encode(${valueManifest.value})`).mergeImportsWith(
+                                typeManifest.encoder,
+                                valueManifest.value,
+                            ),
+                        ];
+                    }
+                    if (isNode(seed, 'variablePdaSeedNode')) {
+                        const typeManifest = visit(seed.type, typeManifestVisitor);
+                        const valueSeed = defaultValue.seeds.find(s => s.name === seed.name)?.value;
+                        if (!valueSeed) return [];
+                        if (isNode(valueSeed, 'accountValueNode')) {
+                            return [
+                                fragment(
+                                    `${typeManifest.encoder}.encode(expectAddress(accounts.${camelCase(valueSeed.name)}.value))`,
+                                )
+                                    .mergeImportsWith(typeManifest.encoder)
+                                    .addImports('shared', 'expectAddress'),
+                            ];
+                        }
+                        if (isNode(valueSeed, 'argumentValueNode')) {
+                            return [
+                                fragment(
+                                    `${typeManifest.encoder}.encode(expectSome(args.${camelCase(valueSeed.name)}))`,
+                                )
+                                    .mergeImportsWith(typeManifest.encoder)
+                                    .addImports('shared', 'expectSome'),
+                            ];
+                        }
+                        const valueManifest = visit(valueSeed, typeManifestVisitor);
+                        return [
+                            fragment(`${typeManifest.encoder}.encode(${valueManifest.value})`).mergeImportsWith(
+                                typeManifest.encoder,
+                                valueManifest.value,
+                            ),
+                        ];
+                    }
+                    return [];
+                });
+                const pdaStatement = mergeFragments([pdaProgram, ...pdaSeeds], ([p, ...s]) => {
+                    const programAddress = p === 'programAddress' ? p : `programAddress: ${p}`;
+                    return `await getProgramDerivedAddress({ ${programAddress}, seeds: [${s.join(', ')}] })`;
+                }).addImports('solanaAddresses', 'getProgramDerivedAddress');
+                return defaultFragment(pdaStatement.render).mergeImportsWith(pdaStatement);
+            }
+
+            // Linked PDA value.
             const pdaFunction = nameApi.pdaFindFunction(defaultValue.pda.name);
             const pdaImportFrom = defaultValue.pda.importFrom ?? 'generatedPdas';
             const pdaArgs = [];

+ 153 - 1
packages/renderers-js/test/instructionsPage.test.ts

@@ -1,17 +1,24 @@
 import {
+    accountValueNode,
     argumentValueNode,
+    constantPdaSeedNodeFromString,
     instructionAccountNode,
     instructionArgumentNode,
     instructionNode,
     numberTypeNode,
+    pdaNode,
+    pdaSeedValueNode,
+    pdaValueNode,
     programNode,
+    publicKeyTypeNode,
     resolverValueNode,
+    variablePdaSeedNode,
 } from '@kinobi-so/nodes';
 import { visit } from '@kinobi-so/visitors-core';
 import { test } from 'vitest';
 
 import { getRenderMapVisitor } from '../src';
-import { codeContains, codeDoesNotContain, renderMapContains } from './_setup';
+import { codeContains, codeDoesNotContain, renderMapContains, renderMapContainsImports } from './_setup';
 
 test('it renders instruction accounts that can either be signer or non-signer', async () => {
     // Given the following instruction with a signer or non-signer account.
@@ -100,3 +107,148 @@ test('it only renders the args variable on the async function if the sync functi
     await codeContains(asyncFunction, ['// Original args.', 'const args = { ...input }']);
     await codeDoesNotContain(syncFunction, ['// Original args.', 'const args = { ...input }']);
 });
+
+test('it renders instruction accounts with linked PDAs as default value', async () => {
+    // Given the following program with a PDA node and an instruction account using it as default value.
+    const node = programNode({
+        instructions: [
+            instructionNode({
+                accounts: [
+                    instructionAccountNode({ isSigner: true, isWritable: false, name: 'authority' }),
+                    instructionAccountNode({
+                        defaultValue: pdaValueNode('counter', [
+                            pdaSeedValueNode('authority', accountValueNode('authority')),
+                        ]),
+                        isSigner: false,
+                        isWritable: false,
+                        name: 'counter',
+                    }),
+                ],
+                name: 'increment',
+            }),
+        ],
+        name: 'counter',
+        pdas: [
+            pdaNode({
+                name: 'counter',
+                seeds: [
+                    constantPdaSeedNodeFromString('utf8', 'counter'),
+                    variablePdaSeedNode('authority', publicKeyTypeNode()),
+                ],
+            }),
+        ],
+        publicKey: '1111',
+    });
+
+    // When we render it.
+    const renderMap = visit(node, getRenderMapVisitor());
+
+    // Then we expect the following default value to be rendered.
+    await renderMapContains(renderMap, 'instructions/increment.ts', [
+        'if (!accounts.counter.value) { ' +
+            'accounts.counter.value = await findCounterPda( { authority: expectAddress ( accounts.authority.value ) } ); ' +
+            '}',
+    ]);
+    renderMapContainsImports(renderMap, 'instructions/increment.ts', { '../pdas': ['findCounterPda'] });
+});
+
+test('it renders instruction accounts with inlined PDAs as default value', async () => {
+    // Given the following instruction with an inlined PDA default value.
+    const node = programNode({
+        instructions: [
+            instructionNode({
+                accounts: [
+                    instructionAccountNode({ isSigner: true, isWritable: false, name: 'authority' }),
+                    instructionAccountNode({
+                        defaultValue: pdaValueNode(
+                            pdaNode({
+                                name: 'counter',
+                                seeds: [
+                                    constantPdaSeedNodeFromString('utf8', 'counter'),
+                                    variablePdaSeedNode('authority', publicKeyTypeNode()),
+                                ],
+                            }),
+                            [pdaSeedValueNode('authority', accountValueNode('authority'))],
+                        ),
+                        isSigner: false,
+                        isWritable: false,
+                        name: 'counter',
+                    }),
+                ],
+                name: 'increment',
+            }),
+        ],
+        name: 'counter',
+        publicKey: '1111',
+    });
+
+    // When we render it.
+    const renderMap = visit(node, getRenderMapVisitor());
+
+    // Then we expect the following default value to be rendered.
+    await renderMapContains(renderMap, 'instructions/increment.ts', [
+        'if (!accounts.counter.value) { ' +
+            'accounts.counter.value = await getProgramDerivedAddress( { ' +
+            '  programAddress, ' +
+            '  seeds: [ ' +
+            "    getUtf8Encoder().encode('counter'), " +
+            '    getAddressEncoder().encode(expectAddress(accounts.authority.value)) ' +
+            '  ] ' +
+            '} ); ' +
+            '}',
+    ]);
+    renderMapContainsImports(renderMap, 'instructions/increment.ts', {
+        '@solana/web3.js': ['getProgramDerivedAddress'],
+    });
+});
+
+test('it renders instruction accounts with inlined PDAs from another program as default value', async () => {
+    // Given the following instruction with an inlined PDA default value from another program.
+    const node = programNode({
+        instructions: [
+            instructionNode({
+                accounts: [
+                    instructionAccountNode({ isSigner: true, isWritable: false, name: 'authority' }),
+                    instructionAccountNode({
+                        defaultValue: pdaValueNode(
+                            pdaNode({
+                                name: 'counter',
+                                programId: '2222',
+                                seeds: [
+                                    constantPdaSeedNodeFromString('utf8', 'counter'),
+                                    variablePdaSeedNode('authority', publicKeyTypeNode()),
+                                ],
+                            }),
+                            [pdaSeedValueNode('authority', accountValueNode('authority'))],
+                        ),
+                        isSigner: false,
+                        isWritable: false,
+                        name: 'counter',
+                    }),
+                ],
+                name: 'increment',
+            }),
+        ],
+        name: 'counter',
+        publicKey: '1111',
+    });
+
+    // When we render it.
+    const renderMap = visit(node, getRenderMapVisitor());
+
+    // Then we expect the following default value to be rendered.
+    await renderMapContains(renderMap, 'instructions/increment.ts', [
+        'if (!accounts.counter.value) { ' +
+            'accounts.counter.value = await getProgramDerivedAddress( { ' +
+            "  programAddress: '2222' as Address<'2222'>, " +
+            '  seeds: [ ' +
+            "    getUtf8Encoder().encode('counter'), " +
+            '    getAddressEncoder().encode(expectAddress(accounts.authority.value)) ' +
+            '  ] ' +
+            '} ); ' +
+            '}',
+    ]);
+    renderMapContainsImports(renderMap, 'instructions/increment.ts', {
+        '@solana/web3.js': ['Address', 'getProgramDerivedAddress'],
+    });
+});

+ 1 - 1
packages/visitors/src/fillDefaultPdaSeedValuesVisitor.ts

@@ -38,7 +38,7 @@ export function fillDefaultPdaSeedValuesVisitor(
             visitPdaValue(node, { next }) {
                 const visitedNode = next(node);
                 assertIsNode(visitedNode, 'pdaValueNode');
-                const foundPda = linkables.get(visitedNode.pda);
+                const foundPda = isNode(visitedNode.pda, 'pdaNode') ? visitedNode.pda : linkables.get(visitedNode.pda);
                 if (!foundPda) return visitedNode;
                 const seeds = addDefaultSeedValuesFromPdaWhenMissing(instruction, foundPda, visitedNode.seeds);
                 if (strictMode && !allSeedsAreValid(instruction, foundPda, seeds)) {