浏览代码

Use NodePaths in fillDefaultPdaSeedValuesVisitor (#281)

This PR refactors the `fillDefaultPdaSeedValuesVisitor` to use `NodePaths`.
Loris Leiva 1 年之前
父节点
当前提交
c78f3ca229

+ 5 - 0
.changeset/curly-berries-jog.md

@@ -0,0 +1,5 @@
+---
+'@codama/visitors': minor
+---
+
+Use `NodePaths` in `fillDefaultPdaSeedValuesVisitor`

+ 2 - 2
packages/visitors/README.md

@@ -96,7 +96,7 @@ codama.update(deduplicateIdenticalDefinedTypesVisitor());
 
 ### `fillDefaultPdaSeedValuesVisitor`
 
-This visitor fills any missing `PdaSeedValueNodes` from `PdaValueNodes` using the provided `InstructionNode` such that:
+This visitor fills any missing `PdaSeedValueNodes` from `PdaValueNodes` using the provided `NodePath<InstructionNode>` such that:
 
 -   If a `VariablePdaSeedNode` is of type `PublicKeyTypeNode` and the name of the seed matches the name of an account in the `InstructionNode`, then a new `PdaSeedValueNode` will be added with the matching account.
 -   Otherwise, if a `VariablePdaSeedNode` is of any other type and the name of the seed matches the name of an argument in the `InstructionNode`, then a new `PdaSeedValueNode` will be added with the matching argument.
@@ -107,7 +107,7 @@ It also requires a [`LinkableDictionary`](../visitors-core/README.md#linkable-di
 Note that this visitor is mainly used for internal purposes.
 
 ```ts
-codama.update(fillDefaultPdaSeedValuesVisitor(instructionNode, linkables, strictMode));
+codama.update(fillDefaultPdaSeedValuesVisitor(instructionPath, linkables, strictMode));
 ```
 
 ### `flattenInstructionDataArgumentsVisitor`

+ 12 - 4
packages/visitors/src/fillDefaultPdaSeedValuesVisitor.ts

@@ -14,7 +14,15 @@ import {
     pdaSeedValueNode,
     pdaValueNode,
 } from '@codama/nodes';
-import { extendVisitor, identityVisitor, LinkableDictionary, NodeStack, pipe, Visitor } from '@codama/visitors-core';
+import {
+    extendVisitor,
+    getLastNodeFromPath,
+    identityVisitor,
+    LinkableDictionary,
+    NodePath,
+    pipe,
+    Visitor,
+} from '@codama/visitors-core';
 
 /**
  * Fills in default values for variable PDA seeds that are not explicitly provided.
@@ -29,11 +37,11 @@ import { extendVisitor, identityVisitor, LinkableDictionary, NodeStack, pipe, Vi
  * pdaSeedValueNodes contains invalid seeds or if there aren't enough variable seeds.
  */
 export function fillDefaultPdaSeedValuesVisitor(
-    instruction: InstructionNode,
-    stack: NodeStack,
+    instructionPath: NodePath<InstructionNode>,
     linkables: LinkableDictionary,
     strictMode: boolean = false,
 ) {
+    const instruction = getLastNodeFromPath(instructionPath);
     return pipe(identityVisitor(INSTRUCTION_INPUT_VALUE_NODES), v =>
         extendVisitor(v, {
             visitPdaValue(node, { next }) {
@@ -41,7 +49,7 @@ export function fillDefaultPdaSeedValuesVisitor(
                 assertIsNode(visitedNode, 'pdaValueNode');
                 const foundPda = isNode(visitedNode.pda, 'pdaNode')
                     ? visitedNode.pda
-                    : linkables.get([...stack.getPath(), visitedNode.pda]);
+                    : linkables.get([...instructionPath, visitedNode.pda]);
                 if (!foundPda) return visitedNode;
                 const seeds = addDefaultSeedValuesFromPdaWhenMissing(instruction, foundPda, visitedNode.seeds);
                 if (strictMode && !allSeedsAreValid(instruction, foundPda, seeds)) {

+ 2 - 1
packages/visitors/src/setInstructionAccountDefaultValuesVisitor.ts

@@ -167,6 +167,7 @@ export function setInstructionAccountDefaultValuesVisitor(rules: InstructionAcco
         v =>
             extendVisitor(v, {
                 visitInstruction(node) {
+                    const instructionPath = stack.getPath('instructionNode');
                     const instructionAccounts = node.accounts.map((account): InstructionAccountNode => {
                         const rule = matchRule(node, account);
                         if (!rule) return account;
@@ -180,7 +181,7 @@ export function setInstructionAccountDefaultValuesVisitor(rules: InstructionAcco
                                 ...account,
                                 defaultValue: visit(
                                     rule.defaultValue,
-                                    fillDefaultPdaSeedValuesVisitor(node, stack, linkables, true),
+                                    fillDefaultPdaSeedValuesVisitor(instructionPath, linkables, true),
                                 ),
                             };
                         } catch (error) {

+ 5 - 4
packages/visitors/src/updateInstructionsVisitor.ts

@@ -16,6 +16,7 @@ import {
     BottomUpNodeTransformerWithSelector,
     bottomUpTransformerVisitor,
     LinkableDictionary,
+    NodePath,
     NodeStack,
     pipe,
     recordLinkablesOnFirstVisitVisitor,
@@ -72,10 +73,11 @@ export function updateInstructionsVisitor(map: Record<string, InstructionUpdates
                     return null;
                 }
 
+                const instructionPath = stack.getPath('instructionNode');
                 const { accounts: accountUpdates, arguments: argumentUpdates, ...metadataUpdates } = updates;
                 const { newArguments, newExtraArguments } = handleInstructionArguments(node, argumentUpdates ?? {});
                 const newAccounts = node.accounts.map(account =>
-                    handleInstructionAccount(node, stack, account, accountUpdates ?? {}, linkables),
+                    handleInstructionAccount(instructionPath, account, accountUpdates ?? {}, linkables),
                 );
                 return instructionNode({
                     ...node,
@@ -96,8 +98,7 @@ export function updateInstructionsVisitor(map: Record<string, InstructionUpdates
 }
 
 function handleInstructionAccount(
-    instruction: InstructionNode,
-    stack: NodeStack,
+    instructionPath: NodePath<InstructionNode>,
     account: InstructionAccountNode,
     accountUpdates: InstructionAccountUpdates,
     linkables: LinkableDictionary,
@@ -115,7 +116,7 @@ function handleInstructionAccount(
 
     return instructionAccountNode({
         ...acountWithoutDefault,
-        defaultValue: visit(defaultValue, fillDefaultPdaSeedValuesVisitor(instruction, stack, linkables)),
+        defaultValue: visit(defaultValue, fillDefaultPdaSeedValuesVisitor(instructionPath, linkables)),
     });
 }
 

+ 4 - 7
packages/visitors/test/fillDefaultPdaSeedValuesVisitor.test.ts

@@ -14,7 +14,7 @@ import {
     publicKeyTypeNode,
     variablePdaSeedNode,
 } from '@codama/nodes';
-import { LinkableDictionary, NodeStack, visit } from '@codama/visitors-core';
+import { LinkableDictionary, visit } from '@codama/visitors-core';
 import { expect, test } from 'vitest';
 
 import { fillDefaultPdaSeedValuesVisitor } from '../src';
@@ -56,10 +56,9 @@ test('it fills missing pda seed values with default values', () => {
         arguments: [instructionArgumentNode({ name: 'seed2', type: numberTypeNode('u64') })],
         name: 'myInstruction',
     });
-    const instructionStack = new NodeStack([program, instruction]);
 
     // When we fill the PDA seeds with default values.
-    const result = visit(node, fillDefaultPdaSeedValuesVisitor(instruction, instructionStack, linkables));
+    const result = visit(node, fillDefaultPdaSeedValuesVisitor([program, instruction], linkables));
 
     // Then we expect the following pdaValueNode to be returned.
     expect(result).toEqual(
@@ -111,10 +110,9 @@ test('it fills nested pda value nodes', () => {
         arguments: [instructionArgumentNode({ name: 'seed2', type: numberTypeNode('u64') })],
         name: 'myInstruction',
     });
-    const instructionStack = new NodeStack([program, instruction]);
 
     // When we fill the PDA seeds with default values.
-    const result = visit(node, fillDefaultPdaSeedValuesVisitor(instruction, instructionStack, linkables));
+    const result = visit(node, fillDefaultPdaSeedValuesVisitor([program, instruction], linkables));
 
     // Then we expect the following conditionalValueNode to be returned.
     expect(result).toEqual(
@@ -159,10 +157,9 @@ test('it ignores default seeds missing from the instruction', () => {
         arguments: [instructionArgumentNode({ name: 'seed2', type: numberTypeNode('u64') })],
         name: 'myInstruction',
     });
-    const instructionStack = new NodeStack([program, instruction]);
 
     // When we fill the PDA seeds with default values.
-    const result = visit(node, fillDefaultPdaSeedValuesVisitor(instruction, instructionStack, linkables));
+    const result = visit(node, fillDefaultPdaSeedValuesVisitor([program, instruction], linkables));
 
     // Then we expect the following pdaValueNode to be returned.
     expect(result).toEqual(