Quellcode durchsuchen

Allow passing NodeStacks to nested visitors (#285)

This PR adds an `options` objects to most core visitors (or refactors optional arguments into an `options` objects).

Within this `options` object, it offers the ability for any visitor that uses a `NodeStack` to use the provided one instead of a brand new one. This enables visitors to call others visitors during their traversal whilst all sharing the same `NodeStack` and not losing track of where we are in the tree.
Loris Leiva vor 1 Jahr
Ursprung
Commit
ce4936c031
32 geänderte Dateien mit 303 neuen und 203 gelöschten Zeilen
  1. 6 0
      .changeset/spicy-camels-tease.md
  2. 1 1
      packages/library/test/index.test.ts
  3. 1 1
      packages/renderers-js-umi/src/getRenderMapVisitor.ts
  4. 10 8
      packages/renderers-js-umi/src/getTypeManifestVisitor.ts
  5. 3 4
      packages/renderers-js/src/getRenderMapVisitor.ts
  6. 10 8
      packages/renderers-js/src/getTypeManifestVisitor.ts
  7. 3 4
      packages/renderers-rust/src/getRenderMapVisitor.ts
  8. 1 1
      packages/renderers-rust/src/getTypeManifestVisitor.ts
  9. 3 3
      packages/visitors-core/src/bottomUpTransformerVisitor.ts
  10. 2 2
      packages/visitors-core/src/deleteNodesVisitor.ts
  11. 13 9
      packages/visitors-core/src/getByteSizeVisitor.ts
  12. 61 61
      packages/visitors-core/src/identityVisitor.ts
  13. 61 61
      packages/visitors-core/src/mergeVisitor.ts
  14. 3 3
      packages/visitors-core/src/nonNullableIdentityVisitor.ts
  15. 2 2
      packages/visitors-core/src/removeDocsVisitor.ts
  16. 3 2
      packages/visitors-core/src/staticVisitor.ts
  17. 3 3
      packages/visitors-core/src/topDownTransformerVisitor.ts
  18. 4 2
      packages/visitors-core/src/voidVisitor.ts
  19. 50 6
      packages/visitors-core/test/bottomUpTransformerVisitor.test.ts
  20. 3 4
      packages/visitors-core/test/deleteNodesVisitor.test.ts
  21. 1 1
      packages/visitors-core/test/extendVisitor.test.ts
  22. 2 4
      packages/visitors-core/test/getByteSizeVisitor.test.ts
  23. 1 1
      packages/visitors-core/test/identityVisitor.test.ts
  24. 1 1
      packages/visitors-core/test/mapVisitor.test.ts
  25. 1 1
      packages/visitors-core/test/mergeVisitor.test.ts
  26. 1 1
      packages/visitors-core/test/removeDocsVisitor.test.ts
  27. 1 1
      packages/visitors-core/test/staticVisitor.test.ts
  28. 47 3
      packages/visitors-core/test/topDownTransformerVisitor.test.ts
  29. 1 1
      packages/visitors/src/fillDefaultPdaSeedValuesVisitor.ts
  30. 2 2
      packages/visitors/src/setFixedAccountSizesVisitor.ts
  31. 1 1
      packages/visitors/src/setInstructionAccountDefaultValuesVisitor.ts
  32. 1 1
      packages/visitors/src/transformDefinedTypesIntoAccountsVisitor.ts

+ 6 - 0
.changeset/spicy-camels-tease.md

@@ -0,0 +1,6 @@
+---
+'@codama/visitors-core': minor
+'@codama/visitors': minor
+---
+
+Allow passing `NodeStacks` to nested visitors

+ 1 - 1
packages/library/test/index.test.ts

@@ -12,7 +12,7 @@ test('it exports visitors', () => {
 
 test('it accepts visitors', () => {
     const codama = createFromRoot(rootNode(programNode({ name: 'myProgram', publicKey: '1111' })));
-    const visitor = voidVisitor(['rootNode']);
+    const visitor = voidVisitor({ keys: ['rootNode'] });
     const result = codama.accept(visitor) satisfies void;
     expect(typeof result).toBe('undefined');
 });

+ 1 - 1
packages/renderers-js-umi/src/getRenderMapVisitor.ts

@@ -100,7 +100,7 @@ export function getRenderMapVisitor(options: GetRenderMapOptions = {}): Visitor<
         });
     const typeManifestVisitor = getTypeManifestVisitor();
     const resolvedInstructionInputVisitor = getResolvedInstructionInputsVisitor();
-    const byteSizeVisitor = getByteSizeVisitor(linkables, stack);
+    const byteSizeVisitor = getByteSizeVisitor(linkables, { stack });
 
     function getInstructionAccountType(account: ResolvedInstructionAccount): string {
         if (account.isPda && account.isSigner === false) return 'Pda';

+ 10 - 8
packages/renderers-js-umi/src/getTypeManifestVisitor.ts

@@ -85,14 +85,16 @@ export function getTypeManifestVisitor(input: {
                     value: '',
                     valueImports: new ImportMap(),
                 }) as TypeManifest,
-            [
-                ...REGISTERED_TYPE_NODE_KINDS,
-                ...REGISTERED_VALUE_NODE_KINDS,
-                'definedTypeLinkNode',
-                'definedTypeNode',
-                'accountNode',
-                'instructionNode',
-            ],
+            {
+                keys: [
+                    ...REGISTERED_TYPE_NODE_KINDS,
+                    ...REGISTERED_VALUE_NODE_KINDS,
+                    'definedTypeLinkNode',
+                    'definedTypeNode',
+                    'accountNode',
+                    'instructionNode',
+                ],
+            },
         ),
         v =>
             extendVisitor(v, {

+ 3 - 4
packages/renderers-js/src/getRenderMapVisitor.ts

@@ -136,10 +136,9 @@ export function getRenderMapVisitor(options: GetRenderMapOptions = {}) {
     };
 
     return pipe(
-        staticVisitor(
-            () => new RenderMap(),
-            ['rootNode', 'programNode', 'pdaNode', 'accountNode', 'definedTypeNode', 'instructionNode'],
-        ),
+        staticVisitor(() => new RenderMap(), {
+            keys: ['rootNode', 'programNode', 'pdaNode', 'accountNode', 'definedTypeNode', 'instructionNode'],
+        }),
         v =>
             extendVisitor(v, {
                 visitAccount(node) {

+ 10 - 8
packages/renderers-js/src/getTypeManifestVisitor.ts

@@ -58,14 +58,16 @@ export function getTypeManifestVisitor(input: {
                     strictType: fragment(''),
                     value: fragment(''),
                 }) as TypeManifest,
-            [
-                ...REGISTERED_TYPE_NODE_KINDS,
-                ...REGISTERED_VALUE_NODE_KINDS,
-                'definedTypeLinkNode',
-                'definedTypeNode',
-                'accountNode',
-                'instructionNode',
-            ],
+            {
+                keys: [
+                    ...REGISTERED_TYPE_NODE_KINDS,
+                    ...REGISTERED_VALUE_NODE_KINDS,
+                    'definedTypeLinkNode',
+                    'definedTypeNode',
+                    'accountNode',
+                    'instructionNode',
+                ],
+            },
         ),
         visitor =>
             extendVisitor(visitor, {

+ 3 - 4
packages/renderers-rust/src/getRenderMapVisitor.ts

@@ -53,10 +53,9 @@ export function getRenderMapVisitor(options: GetRenderMapOptions = {}) {
     const anchorTraits = options.anchorTraits ?? true;
 
     return pipe(
-        staticVisitor(
-            () => new RenderMap(),
-            ['rootNode', 'programNode', 'instructionNode', 'accountNode', 'definedTypeNode'],
-        ),
+        staticVisitor(() => new RenderMap(), {
+            keys: ['rootNode', 'programNode', 'instructionNode', 'accountNode', 'definedTypeNode'],
+        }),
         v =>
             extendVisitor(v, {
                 visitAccount(node) {

+ 1 - 1
packages/renderers-rust/src/getTypeManifestVisitor.ts

@@ -44,7 +44,7 @@ export function getTypeManifestVisitor(options: {
                 ...mergeManifests(values),
                 type: values.map(v => v.type).join('\n'),
             }),
-            [...REGISTERED_TYPE_NODE_KINDS, 'definedTypeLinkNode', 'definedTypeNode', 'accountNode'],
+            { keys: [...REGISTERED_TYPE_NODE_KINDS, 'definedTypeLinkNode', 'definedTypeNode', 'accountNode'] },
         ),
         v =>
             extendVisitor(v, {

+ 3 - 3
packages/visitors-core/src/bottomUpTransformerVisitor.ts

@@ -17,7 +17,7 @@ export type BottomUpNodeTransformerWithSelector = {
 
 export function bottomUpTransformerVisitor<TNodeKind extends NodeKind = NodeKind>(
     transformers: (BottomUpNodeTransformer | BottomUpNodeTransformerWithSelector)[],
-    nodeKeys?: TNodeKind[],
+    options: { keys?: TNodeKind[]; stack?: NodeStack } = {},
 ): Visitor<Node | null, TNodeKind> {
     const transformerFunctions = transformers.map((transformer): BottomUpNodeTransformer => {
         if (typeof transformer === 'function') return transformer;
@@ -27,9 +27,9 @@ export function bottomUpTransformerVisitor<TNodeKind extends NodeKind = NodeKind
                 : node;
     });
 
-    const stack = new NodeStack();
+    const stack = options.stack ?? new NodeStack();
     return pipe(
-        identityVisitor(nodeKeys),
+        identityVisitor(options),
         v =>
             interceptVisitor(v, (node, next) => {
                 return transformerFunctions.reduce(

+ 2 - 2
packages/visitors-core/src/deleteNodesVisitor.ts

@@ -5,7 +5,7 @@ import { TopDownNodeTransformerWithSelector, topDownTransformerVisitor } from '.
 
 export function deleteNodesVisitor<TNodeKind extends NodeKind = NodeKind>(
     selectors: NodeSelector[],
-    nodeKeys?: TNodeKind[],
+    options?: Parameters<typeof topDownTransformerVisitor<TNodeKind>>[1],
 ) {
     return topDownTransformerVisitor<TNodeKind>(
         selectors.map(
@@ -14,6 +14,6 @@ export function deleteNodesVisitor<TNodeKind extends NodeKind = NodeKind>(
                 transform: () => null,
             }),
         ),
-        nodeKeys,
+        options,
     );
 }

+ 13 - 9
packages/visitors-core/src/getByteSizeVisitor.ts

@@ -19,8 +19,10 @@ export type ByteSizeVisitorKeys =
 
 export function getByteSizeVisitor(
     linkables: LinkableDictionary,
-    stack: NodeStack,
+    options: { stack?: NodeStack } = {},
 ): Visitor<number | null, ByteSizeVisitorKeys> {
+    const stack = options.stack ?? new NodeStack();
+
     const visitedDefinedTypes = new Map<string, number | null>();
     const definedTypeStack: string[] = [];
 
@@ -30,14 +32,16 @@ export function getByteSizeVisitor(
     const baseVisitor = mergeVisitor(
         () => null as number | null,
         (_, values) => sumSizes(values),
-        [
-            ...REGISTERED_TYPE_NODE_KINDS,
-            'definedTypeLinkNode',
-            'definedTypeNode',
-            'accountNode',
-            'instructionNode',
-            'instructionArgumentNode',
-        ],
+        {
+            keys: [
+                ...REGISTERED_TYPE_NODE_KINDS,
+                'definedTypeLinkNode',
+                'definedTypeNode',
+                'accountNode',
+                'instructionNode',
+                'instructionArgumentNode',
+            ],
+        },
     );
 
     return pipe(

+ 61 - 61
packages/visitors-core/src/identityVisitor.ts

@@ -76,16 +76,16 @@ import { staticVisitor } from './staticVisitor';
 import { visit as baseVisit, Visitor } from './visitor';
 
 export function identityVisitor<TNodeKind extends NodeKind = NodeKind>(
-    nodeKeys: TNodeKind[] = REGISTERED_NODE_KINDS as TNodeKind[],
+    options: { keys?: TNodeKind[] } = {},
 ): Visitor<Node | null, TNodeKind> {
-    const castedNodeKeys: NodeKind[] = nodeKeys;
-    const visitor = staticVisitor(node => Object.freeze({ ...node }), castedNodeKeys) as Visitor<Node | null>;
+    const keys: NodeKind[] = options.keys ?? (REGISTERED_NODE_KINDS as TNodeKind[]);
+    const visitor = staticVisitor(node => Object.freeze({ ...node }), { keys }) as Visitor<Node | null>;
     const visit =
         (v: Visitor<Node | null>) =>
         (node: Node): Node | null =>
-            castedNodeKeys.includes(node.kind) ? baseVisit(node, v) : Object.freeze({ ...node });
+            keys.includes(node.kind) ? baseVisit(node, v) : Object.freeze({ ...node });
 
-    if (castedNodeKeys.includes('rootNode')) {
+    if (keys.includes('rootNode')) {
         visitor.visitRoot = function visitRoot(node) {
             const program = visit(this)(node.program);
             if (program === null) return null;
@@ -97,7 +97,7 @@ export function identityVisitor<TNodeKind extends NodeKind = NodeKind>(
         };
     }
 
-    if (castedNodeKeys.includes('programNode')) {
+    if (keys.includes('programNode')) {
         visitor.visitProgram = function visitProgram(node) {
             return programNode({
                 ...node,
@@ -114,7 +114,7 @@ export function identityVisitor<TNodeKind extends NodeKind = NodeKind>(
         };
     }
 
-    if (castedNodeKeys.includes('pdaNode')) {
+    if (keys.includes('pdaNode')) {
         visitor.visitPda = function visitPda(node) {
             return pdaNode({
                 ...node,
@@ -123,7 +123,7 @@ export function identityVisitor<TNodeKind extends NodeKind = NodeKind>(
         };
     }
 
-    if (castedNodeKeys.includes('accountNode')) {
+    if (keys.includes('accountNode')) {
         visitor.visitAccount = function visitAccount(node) {
             const data = visit(this)(node.data);
             if (data === null) return null;
@@ -134,7 +134,7 @@ export function identityVisitor<TNodeKind extends NodeKind = NodeKind>(
         };
     }
 
-    if (castedNodeKeys.includes('instructionNode')) {
+    if (keys.includes('instructionNode')) {
         visitor.visitInstruction = function visitInstruction(node) {
             return instructionNode({
                 ...node,
@@ -169,7 +169,7 @@ export function identityVisitor<TNodeKind extends NodeKind = NodeKind>(
         };
     }
 
-    if (castedNodeKeys.includes('instructionAccountNode')) {
+    if (keys.includes('instructionAccountNode')) {
         visitor.visitInstructionAccount = function visitInstructionAccount(node) {
             const defaultValue = node.defaultValue ? (visit(this)(node.defaultValue) ?? undefined) : undefined;
             if (defaultValue) assertIsNode(defaultValue, INSTRUCTION_INPUT_VALUE_NODES);
@@ -177,7 +177,7 @@ export function identityVisitor<TNodeKind extends NodeKind = NodeKind>(
         };
     }
 
-    if (castedNodeKeys.includes('instructionArgumentNode')) {
+    if (keys.includes('instructionArgumentNode')) {
         visitor.visitInstructionArgument = function visitInstructionArgument(node) {
             const type = visit(this)(node.type);
             if (type === null) return null;
@@ -188,7 +188,7 @@ export function identityVisitor<TNodeKind extends NodeKind = NodeKind>(
         };
     }
 
-    if (castedNodeKeys.includes('instructionRemainingAccountsNode')) {
+    if (keys.includes('instructionRemainingAccountsNode')) {
         visitor.visitInstructionRemainingAccounts = function visitInstructionRemainingAccounts(node) {
             const value = visit(this)(node.value);
             if (value === null) return null;
@@ -197,7 +197,7 @@ export function identityVisitor<TNodeKind extends NodeKind = NodeKind>(
         };
     }
 
-    if (castedNodeKeys.includes('instructionByteDeltaNode')) {
+    if (keys.includes('instructionByteDeltaNode')) {
         visitor.visitInstructionByteDelta = function visitInstructionByteDelta(node) {
             const value = visit(this)(node.value);
             if (value === null) return null;
@@ -206,7 +206,7 @@ export function identityVisitor<TNodeKind extends NodeKind = NodeKind>(
         };
     }
 
-    if (castedNodeKeys.includes('definedTypeNode')) {
+    if (keys.includes('definedTypeNode')) {
         visitor.visitDefinedType = function visitDefinedType(node) {
             const type = visit(this)(node.type);
             if (type === null) return null;
@@ -215,7 +215,7 @@ export function identityVisitor<TNodeKind extends NodeKind = NodeKind>(
         };
     }
 
-    if (castedNodeKeys.includes('arrayTypeNode')) {
+    if (keys.includes('arrayTypeNode')) {
         visitor.visitArrayType = function visitArrayType(node) {
             const size = visit(this)(node.count);
             if (size === null) return null;
@@ -227,7 +227,7 @@ export function identityVisitor<TNodeKind extends NodeKind = NodeKind>(
         };
     }
 
-    if (castedNodeKeys.includes('enumTypeNode')) {
+    if (keys.includes('enumTypeNode')) {
         visitor.visitEnumType = function visitEnumType(node) {
             return enumTypeNode(
                 node.variants.map(visit(this)).filter(removeNullAndAssertIsNodeFilter(ENUM_VARIANT_TYPE_NODES)),
@@ -236,7 +236,7 @@ export function identityVisitor<TNodeKind extends NodeKind = NodeKind>(
         };
     }
 
-    if (castedNodeKeys.includes('enumStructVariantTypeNode')) {
+    if (keys.includes('enumStructVariantTypeNode')) {
         visitor.visitEnumStructVariantType = function visitEnumStructVariantType(node) {
             const newStruct = visit(this)(node.struct);
             if (!newStruct) {
@@ -250,7 +250,7 @@ export function identityVisitor<TNodeKind extends NodeKind = NodeKind>(
         };
     }
 
-    if (castedNodeKeys.includes('enumTupleVariantTypeNode')) {
+    if (keys.includes('enumTupleVariantTypeNode')) {
         visitor.visitEnumTupleVariantType = function visitEnumTupleVariantType(node) {
             const newTuple = visit(this)(node.tuple);
             if (!newTuple) {
@@ -264,7 +264,7 @@ export function identityVisitor<TNodeKind extends NodeKind = NodeKind>(
         };
     }
 
-    if (castedNodeKeys.includes('mapTypeNode')) {
+    if (keys.includes('mapTypeNode')) {
         visitor.visitMapType = function visitMapType(node) {
             const size = visit(this)(node.count);
             if (size === null) return null;
@@ -279,7 +279,7 @@ export function identityVisitor<TNodeKind extends NodeKind = NodeKind>(
         };
     }
 
-    if (castedNodeKeys.includes('optionTypeNode')) {
+    if (keys.includes('optionTypeNode')) {
         visitor.visitOptionType = function visitOptionType(node) {
             const prefix = visit(this)(node.prefix);
             if (prefix === null) return null;
@@ -291,7 +291,7 @@ export function identityVisitor<TNodeKind extends NodeKind = NodeKind>(
         };
     }
 
-    if (castedNodeKeys.includes('zeroableOptionTypeNode')) {
+    if (keys.includes('zeroableOptionTypeNode')) {
         visitor.visitZeroableOptionType = function visitZeroableOptionType(node) {
             const item = visit(this)(node.item);
             if (item === null) return null;
@@ -302,7 +302,7 @@ export function identityVisitor<TNodeKind extends NodeKind = NodeKind>(
         };
     }
 
-    if (castedNodeKeys.includes('remainderOptionTypeNode')) {
+    if (keys.includes('remainderOptionTypeNode')) {
         visitor.visitRemainderOptionType = function visitRemainderOptionType(node) {
             const item = visit(this)(node.item);
             if (item === null) return null;
@@ -311,7 +311,7 @@ export function identityVisitor<TNodeKind extends NodeKind = NodeKind>(
         };
     }
 
-    if (castedNodeKeys.includes('booleanTypeNode')) {
+    if (keys.includes('booleanTypeNode')) {
         visitor.visitBooleanType = function visitBooleanType(node) {
             const size = visit(this)(node.size);
             if (size === null) return null;
@@ -320,7 +320,7 @@ export function identityVisitor<TNodeKind extends NodeKind = NodeKind>(
         };
     }
 
-    if (castedNodeKeys.includes('setTypeNode')) {
+    if (keys.includes('setTypeNode')) {
         visitor.visitSetType = function visitSetType(node) {
             const size = visit(this)(node.count);
             if (size === null) return null;
@@ -332,14 +332,14 @@ export function identityVisitor<TNodeKind extends NodeKind = NodeKind>(
         };
     }
 
-    if (castedNodeKeys.includes('structTypeNode')) {
+    if (keys.includes('structTypeNode')) {
         visitor.visitStructType = function visitStructType(node) {
             const fields = node.fields.map(visit(this)).filter(removeNullAndAssertIsNodeFilter('structFieldTypeNode'));
             return structTypeNode(fields);
         };
     }
 
-    if (castedNodeKeys.includes('structFieldTypeNode')) {
+    if (keys.includes('structFieldTypeNode')) {
         visitor.visitStructFieldType = function visitStructFieldType(node) {
             const type = visit(this)(node.type);
             if (type === null) return null;
@@ -350,14 +350,14 @@ export function identityVisitor<TNodeKind extends NodeKind = NodeKind>(
         };
     }
 
-    if (castedNodeKeys.includes('tupleTypeNode')) {
+    if (keys.includes('tupleTypeNode')) {
         visitor.visitTupleType = function visitTupleType(node) {
             const items = node.items.map(visit(this)).filter(removeNullAndAssertIsNodeFilter(TYPE_NODES));
             return tupleTypeNode(items);
         };
     }
 
-    if (castedNodeKeys.includes('amountTypeNode')) {
+    if (keys.includes('amountTypeNode')) {
         visitor.visitAmountType = function visitAmountType(node) {
             const number = visit(this)(node.number);
             if (number === null) return null;
@@ -366,7 +366,7 @@ export function identityVisitor<TNodeKind extends NodeKind = NodeKind>(
         };
     }
 
-    if (castedNodeKeys.includes('dateTimeTypeNode')) {
+    if (keys.includes('dateTimeTypeNode')) {
         visitor.visitDateTimeType = function visitDateTimeType(node) {
             const number = visit(this)(node.number);
             if (number === null) return null;
@@ -375,7 +375,7 @@ export function identityVisitor<TNodeKind extends NodeKind = NodeKind>(
         };
     }
 
-    if (castedNodeKeys.includes('solAmountTypeNode')) {
+    if (keys.includes('solAmountTypeNode')) {
         visitor.visitSolAmountType = function visitSolAmountType(node) {
             const number = visit(this)(node.number);
             if (number === null) return null;
@@ -384,7 +384,7 @@ export function identityVisitor<TNodeKind extends NodeKind = NodeKind>(
         };
     }
 
-    if (castedNodeKeys.includes('prefixedCountNode')) {
+    if (keys.includes('prefixedCountNode')) {
         visitor.visitPrefixedCount = function visitPrefixedCount(node) {
             const prefix = visit(this)(node.prefix);
             if (prefix === null) return null;
@@ -393,13 +393,13 @@ export function identityVisitor<TNodeKind extends NodeKind = NodeKind>(
         };
     }
 
-    if (castedNodeKeys.includes('arrayValueNode')) {
+    if (keys.includes('arrayValueNode')) {
         visitor.visitArrayValue = function visitArrayValue(node) {
             return arrayValueNode(node.items.map(visit(this)).filter(removeNullAndAssertIsNodeFilter(VALUE_NODES)));
         };
     }
 
-    if (castedNodeKeys.includes('constantValueNode')) {
+    if (keys.includes('constantValueNode')) {
         visitor.visitConstantValue = function visitConstantValue(node) {
             const type = visit(this)(node.type);
             if (type === null) return null;
@@ -411,7 +411,7 @@ export function identityVisitor<TNodeKind extends NodeKind = NodeKind>(
         };
     }
 
-    if (castedNodeKeys.includes('enumValueNode')) {
+    if (keys.includes('enumValueNode')) {
         visitor.visitEnumValue = function visitEnumValue(node) {
             const enumLink = visit(this)(node.enum);
             if (enumLink === null) return null;
@@ -422,7 +422,7 @@ export function identityVisitor<TNodeKind extends NodeKind = NodeKind>(
         };
     }
 
-    if (castedNodeKeys.includes('mapValueNode')) {
+    if (keys.includes('mapValueNode')) {
         visitor.visitMapValue = function visitMapValue(node) {
             return mapValueNode(
                 node.entries.map(visit(this)).filter(removeNullAndAssertIsNodeFilter('mapEntryValueNode')),
@@ -430,7 +430,7 @@ export function identityVisitor<TNodeKind extends NodeKind = NodeKind>(
         };
     }
 
-    if (castedNodeKeys.includes('mapEntryValueNode')) {
+    if (keys.includes('mapEntryValueNode')) {
         visitor.visitMapEntryValue = function visitMapEntryValue(node) {
             const key = visit(this)(node.key);
             if (key === null) return null;
@@ -442,13 +442,13 @@ export function identityVisitor<TNodeKind extends NodeKind = NodeKind>(
         };
     }
 
-    if (castedNodeKeys.includes('setValueNode')) {
+    if (keys.includes('setValueNode')) {
         visitor.visitSetValue = function visitSetValue(node) {
             return setValueNode(node.items.map(visit(this)).filter(removeNullAndAssertIsNodeFilter(VALUE_NODES)));
         };
     }
 
-    if (castedNodeKeys.includes('someValueNode')) {
+    if (keys.includes('someValueNode')) {
         visitor.visitSomeValue = function visitSomeValue(node) {
             const value = visit(this)(node.value);
             if (value === null) return null;
@@ -457,7 +457,7 @@ export function identityVisitor<TNodeKind extends NodeKind = NodeKind>(
         };
     }
 
-    if (castedNodeKeys.includes('structValueNode')) {
+    if (keys.includes('structValueNode')) {
         visitor.visitStructValue = function visitStructValue(node) {
             return structValueNode(
                 node.fields.map(visit(this)).filter(removeNullAndAssertIsNodeFilter('structFieldValueNode')),
@@ -465,7 +465,7 @@ export function identityVisitor<TNodeKind extends NodeKind = NodeKind>(
         };
     }
 
-    if (castedNodeKeys.includes('structFieldValueNode')) {
+    if (keys.includes('structFieldValueNode')) {
         visitor.visitStructFieldValue = function visitStructFieldValue(node) {
             const value = visit(this)(node.value);
             if (value === null) return null;
@@ -474,13 +474,13 @@ export function identityVisitor<TNodeKind extends NodeKind = NodeKind>(
         };
     }
 
-    if (castedNodeKeys.includes('tupleValueNode')) {
+    if (keys.includes('tupleValueNode')) {
         visitor.visitTupleValue = function visitTupleValue(node) {
             return tupleValueNode(node.items.map(visit(this)).filter(removeNullAndAssertIsNodeFilter(VALUE_NODES)));
         };
     }
 
-    if (castedNodeKeys.includes('constantPdaSeedNode')) {
+    if (keys.includes('constantPdaSeedNode')) {
         visitor.visitConstantPdaSeed = function visitConstantPdaSeed(node) {
             const type = visit(this)(node.type);
             if (type === null) return null;
@@ -492,7 +492,7 @@ export function identityVisitor<TNodeKind extends NodeKind = NodeKind>(
         };
     }
 
-    if (castedNodeKeys.includes('variablePdaSeedNode')) {
+    if (keys.includes('variablePdaSeedNode')) {
         visitor.visitVariablePdaSeed = function visitVariablePdaSeed(node) {
             const type = visit(this)(node.type);
             if (type === null) return null;
@@ -501,7 +501,7 @@ export function identityVisitor<TNodeKind extends NodeKind = NodeKind>(
         };
     }
 
-    if (castedNodeKeys.includes('resolverValueNode')) {
+    if (keys.includes('resolverValueNode')) {
         visitor.visitResolverValue = function visitResolverValue(node) {
             const dependsOn = (node.dependsOn ?? [])
                 .map(visit(this))
@@ -513,7 +513,7 @@ export function identityVisitor<TNodeKind extends NodeKind = NodeKind>(
         };
     }
 
-    if (castedNodeKeys.includes('conditionalValueNode')) {
+    if (keys.includes('conditionalValueNode')) {
         visitor.visitConditionalValue = function visitConditionalValue(node) {
             const condition = visit(this)(node.condition);
             if (condition === null) return null;
@@ -529,7 +529,7 @@ export function identityVisitor<TNodeKind extends NodeKind = NodeKind>(
         };
     }
 
-    if (castedNodeKeys.includes('pdaValueNode')) {
+    if (keys.includes('pdaValueNode')) {
         visitor.visitPdaValue = function visitPdaValue(node) {
             const pda = visit(this)(node.pda);
             if (pda === null) return null;
@@ -539,7 +539,7 @@ export function identityVisitor<TNodeKind extends NodeKind = NodeKind>(
         };
     }
 
-    if (castedNodeKeys.includes('pdaSeedValueNode')) {
+    if (keys.includes('pdaSeedValueNode')) {
         visitor.visitPdaSeedValue = function visitPdaSeedValue(node) {
             const value = visit(this)(node.value);
             if (value === null) return null;
@@ -548,7 +548,7 @@ export function identityVisitor<TNodeKind extends NodeKind = NodeKind>(
         };
     }
 
-    if (castedNodeKeys.includes('fixedSizeTypeNode')) {
+    if (keys.includes('fixedSizeTypeNode')) {
         visitor.visitFixedSizeType = function visitFixedSizeType(node) {
             const type = visit(this)(node.type);
             if (type === null) return null;
@@ -557,7 +557,7 @@ export function identityVisitor<TNodeKind extends NodeKind = NodeKind>(
         };
     }
 
-    if (castedNodeKeys.includes('sizePrefixTypeNode')) {
+    if (keys.includes('sizePrefixTypeNode')) {
         visitor.visitSizePrefixType = function visitSizePrefixType(node) {
             const prefix = visit(this)(node.prefix);
             if (prefix === null) return null;
@@ -569,7 +569,7 @@ export function identityVisitor<TNodeKind extends NodeKind = NodeKind>(
         };
     }
 
-    if (castedNodeKeys.includes('preOffsetTypeNode')) {
+    if (keys.includes('preOffsetTypeNode')) {
         visitor.visitPreOffsetType = function visitPreOffsetType(node) {
             const type = visit(this)(node.type);
             if (type === null) return null;
@@ -578,7 +578,7 @@ export function identityVisitor<TNodeKind extends NodeKind = NodeKind>(
         };
     }
 
-    if (castedNodeKeys.includes('postOffsetTypeNode')) {
+    if (keys.includes('postOffsetTypeNode')) {
         visitor.visitPostOffsetType = function visitPostOffsetType(node) {
             const type = visit(this)(node.type);
             if (type === null) return null;
@@ -587,7 +587,7 @@ export function identityVisitor<TNodeKind extends NodeKind = NodeKind>(
         };
     }
 
-    if (castedNodeKeys.includes('sentinelTypeNode')) {
+    if (keys.includes('sentinelTypeNode')) {
         visitor.visitSentinelType = function visitSentinelType(node) {
             const sentinel = visit(this)(node.sentinel);
             if (sentinel === null) return null;
@@ -599,7 +599,7 @@ export function identityVisitor<TNodeKind extends NodeKind = NodeKind>(
         };
     }
 
-    if (castedNodeKeys.includes('hiddenPrefixTypeNode')) {
+    if (keys.includes('hiddenPrefixTypeNode')) {
         visitor.visitHiddenPrefixType = function visitHiddenPrefixType(node) {
             const type = visit(this)(node.type);
             if (type === null) return null;
@@ -610,7 +610,7 @@ export function identityVisitor<TNodeKind extends NodeKind = NodeKind>(
         };
     }
 
-    if (castedNodeKeys.includes('hiddenSuffixTypeNode')) {
+    if (keys.includes('hiddenSuffixTypeNode')) {
         visitor.visitHiddenSuffixType = function visitHiddenSuffixType(node) {
             const type = visit(this)(node.type);
             if (type === null) return null;
@@ -621,7 +621,7 @@ export function identityVisitor<TNodeKind extends NodeKind = NodeKind>(
         };
     }
 
-    if (castedNodeKeys.includes('constantDiscriminatorNode')) {
+    if (keys.includes('constantDiscriminatorNode')) {
         visitor.visitConstantDiscriminator = function visitConstantDiscriminator(node) {
             const constant = visit(this)(node.constant);
             if (constant === null) return null;
@@ -630,7 +630,7 @@ export function identityVisitor<TNodeKind extends NodeKind = NodeKind>(
         };
     }
 
-    if (castedNodeKeys.includes('accountLinkNode')) {
+    if (keys.includes('accountLinkNode')) {
         visitor.visitAccountLink = function visitAccountLink(node) {
             const program = node.program ? (visit(this)(node.program) ?? undefined) : undefined;
             if (program) assertIsNode(program, 'programLinkNode');
@@ -638,7 +638,7 @@ export function identityVisitor<TNodeKind extends NodeKind = NodeKind>(
         };
     }
 
-    if (castedNodeKeys.includes('definedTypeLinkNode')) {
+    if (keys.includes('definedTypeLinkNode')) {
         visitor.visitDefinedTypeLink = function visitDefinedTypeLink(node) {
             const program = node.program ? (visit(this)(node.program) ?? undefined) : undefined;
             if (program) assertIsNode(program, 'programLinkNode');
@@ -646,7 +646,7 @@ export function identityVisitor<TNodeKind extends NodeKind = NodeKind>(
         };
     }
 
-    if (castedNodeKeys.includes('instructionLinkNode')) {
+    if (keys.includes('instructionLinkNode')) {
         visitor.visitInstructionLink = function visitInstructionLink(node) {
             const program = node.program ? (visit(this)(node.program) ?? undefined) : undefined;
             if (program) assertIsNode(program, 'programLinkNode');
@@ -654,7 +654,7 @@ export function identityVisitor<TNodeKind extends NodeKind = NodeKind>(
         };
     }
 
-    if (castedNodeKeys.includes('instructionAccountLinkNode')) {
+    if (keys.includes('instructionAccountLinkNode')) {
         visitor.visitInstructionAccountLink = function visitInstructionAccountLink(node) {
             const instruction = node.instruction ? (visit(this)(node.instruction) ?? undefined) : undefined;
             if (instruction) assertIsNode(instruction, 'instructionLinkNode');
@@ -662,7 +662,7 @@ export function identityVisitor<TNodeKind extends NodeKind = NodeKind>(
         };
     }
 
-    if (castedNodeKeys.includes('instructionArgumentLinkNode')) {
+    if (keys.includes('instructionArgumentLinkNode')) {
         visitor.visitInstructionArgumentLink = function visitInstructionArgumentLink(node) {
             const instruction = node.instruction ? (visit(this)(node.instruction) ?? undefined) : undefined;
             if (instruction) assertIsNode(instruction, 'instructionLinkNode');
@@ -670,7 +670,7 @@ export function identityVisitor<TNodeKind extends NodeKind = NodeKind>(
         };
     }
 
-    if (castedNodeKeys.includes('pdaLinkNode')) {
+    if (keys.includes('pdaLinkNode')) {
         visitor.visitPdaLink = function visitPdaLink(node) {
             const program = node.program ? (visit(this)(node.program) ?? undefined) : undefined;
             if (program) assertIsNode(program, 'programLinkNode');

+ 61 - 61
packages/visitors-core/src/mergeVisitor.ts

@@ -6,22 +6,22 @@ import { visit as baseVisit, Visitor } from './visitor';
 export function mergeVisitor<TReturn, TNodeKind extends NodeKind = NodeKind>(
     leafValue: (node: Node) => TReturn,
     merge: (node: Node, values: TReturn[]) => TReturn,
-    nodeKeys: TNodeKind[] = REGISTERED_NODE_KINDS as TNodeKind[],
+    options: { keys?: TNodeKind[] } = {},
 ): Visitor<TReturn, TNodeKind> {
-    const castedNodeKeys: NodeKind[] = nodeKeys;
-    const visitor = staticVisitor(leafValue, castedNodeKeys) as Visitor<TReturn>;
+    const keys: NodeKind[] = options.keys ?? (REGISTERED_NODE_KINDS as NodeKind[]);
+    const visitor = staticVisitor(leafValue, { keys }) as Visitor<TReturn>;
     const visit =
         (v: Visitor<TReturn>) =>
         (node: Node): TReturn[] =>
-            castedNodeKeys.includes(node.kind) ? [baseVisit(node, v)] : [];
+            keys.includes(node.kind) ? [baseVisit(node, v)] : [];
 
-    if (castedNodeKeys.includes('rootNode')) {
+    if (keys.includes('rootNode')) {
         visitor.visitRoot = function visitRoot(node) {
             return merge(node, getAllPrograms(node).flatMap(visit(this)));
         };
     }
 
-    if (castedNodeKeys.includes('programNode')) {
+    if (keys.includes('programNode')) {
         visitor.visitProgram = function visitProgram(node) {
             return merge(node, [
                 ...node.pdas.flatMap(visit(this)),
@@ -33,13 +33,13 @@ export function mergeVisitor<TReturn, TNodeKind extends NodeKind = NodeKind>(
         };
     }
 
-    if (castedNodeKeys.includes('pdaNode')) {
+    if (keys.includes('pdaNode')) {
         visitor.visitPda = function visitPda(node) {
             return merge(node, node.seeds.flatMap(visit(this)));
         };
     }
 
-    if (castedNodeKeys.includes('accountNode')) {
+    if (keys.includes('accountNode')) {
         visitor.visitAccount = function visitAccount(node) {
             return merge(node, [
                 ...visit(this)(node.data),
@@ -49,7 +49,7 @@ export function mergeVisitor<TReturn, TNodeKind extends NodeKind = NodeKind>(
         };
     }
 
-    if (castedNodeKeys.includes('instructionNode')) {
+    if (keys.includes('instructionNode')) {
         visitor.visitInstruction = function visitInstruction(node) {
             return merge(node, [
                 ...node.accounts.flatMap(visit(this)),
@@ -63,13 +63,13 @@ export function mergeVisitor<TReturn, TNodeKind extends NodeKind = NodeKind>(
         };
     }
 
-    if (castedNodeKeys.includes('instructionAccountNode')) {
+    if (keys.includes('instructionAccountNode')) {
         visitor.visitInstructionAccount = function visitInstructionAccount(node) {
             return merge(node, [...(node.defaultValue ? visit(this)(node.defaultValue) : [])]);
         };
     }
 
-    if (castedNodeKeys.includes('instructionArgumentNode')) {
+    if (keys.includes('instructionArgumentNode')) {
         visitor.visitInstructionArgument = function visitInstructionArgument(node) {
             return merge(node, [
                 ...visit(this)(node.type),
@@ -78,91 +78,91 @@ export function mergeVisitor<TReturn, TNodeKind extends NodeKind = NodeKind>(
         };
     }
 
-    if (castedNodeKeys.includes('instructionRemainingAccountsNode')) {
+    if (keys.includes('instructionRemainingAccountsNode')) {
         visitor.visitInstructionRemainingAccounts = function visitInstructionRemainingAccounts(node) {
             return merge(node, visit(this)(node.value));
         };
     }
 
-    if (castedNodeKeys.includes('instructionByteDeltaNode')) {
+    if (keys.includes('instructionByteDeltaNode')) {
         visitor.visitInstructionByteDelta = function visitInstructionByteDelta(node) {
             return merge(node, visit(this)(node.value));
         };
     }
 
-    if (castedNodeKeys.includes('definedTypeNode')) {
+    if (keys.includes('definedTypeNode')) {
         visitor.visitDefinedType = function visitDefinedType(node) {
             return merge(node, visit(this)(node.type));
         };
     }
 
-    if (castedNodeKeys.includes('arrayTypeNode')) {
+    if (keys.includes('arrayTypeNode')) {
         visitor.visitArrayType = function visitArrayType(node) {
             return merge(node, [...visit(this)(node.count), ...visit(this)(node.item)]);
         };
     }
 
-    if (castedNodeKeys.includes('enumTypeNode')) {
+    if (keys.includes('enumTypeNode')) {
         visitor.visitEnumType = function visitEnumType(node) {
             return merge(node, [...visit(this)(node.size), ...node.variants.flatMap(visit(this))]);
         };
     }
 
-    if (castedNodeKeys.includes('enumStructVariantTypeNode')) {
+    if (keys.includes('enumStructVariantTypeNode')) {
         visitor.visitEnumStructVariantType = function visitEnumStructVariantType(node) {
             return merge(node, visit(this)(node.struct));
         };
     }
 
-    if (castedNodeKeys.includes('enumTupleVariantTypeNode')) {
+    if (keys.includes('enumTupleVariantTypeNode')) {
         visitor.visitEnumTupleVariantType = function visitEnumTupleVariantType(node) {
             return merge(node, visit(this)(node.tuple));
         };
     }
 
-    if (castedNodeKeys.includes('mapTypeNode')) {
+    if (keys.includes('mapTypeNode')) {
         visitor.visitMapType = function visitMapType(node) {
             return merge(node, [...visit(this)(node.count), ...visit(this)(node.key), ...visit(this)(node.value)]);
         };
     }
 
-    if (castedNodeKeys.includes('optionTypeNode')) {
+    if (keys.includes('optionTypeNode')) {
         visitor.visitOptionType = function visitOptionType(node) {
             return merge(node, [...visit(this)(node.prefix), ...visit(this)(node.item)]);
         };
     }
 
-    if (castedNodeKeys.includes('zeroableOptionTypeNode')) {
+    if (keys.includes('zeroableOptionTypeNode')) {
         visitor.visitZeroableOptionType = function visitZeroableOptionType(node) {
             return merge(node, [...visit(this)(node.item), ...(node.zeroValue ? visit(this)(node.zeroValue) : [])]);
         };
     }
 
-    if (castedNodeKeys.includes('remainderOptionTypeNode')) {
+    if (keys.includes('remainderOptionTypeNode')) {
         visitor.visitRemainderOptionType = function visitRemainderOptionType(node) {
             return merge(node, visit(this)(node.item));
         };
     }
 
-    if (castedNodeKeys.includes('booleanTypeNode')) {
+    if (keys.includes('booleanTypeNode')) {
         visitor.visitBooleanType = function visitBooleanType(node) {
             return merge(node, visit(this)(node.size));
         };
     }
 
-    if (castedNodeKeys.includes('setTypeNode')) {
+    if (keys.includes('setTypeNode')) {
         visitor.visitSetType = function visitSetType(node) {
             return merge(node, [...visit(this)(node.count), ...visit(this)(node.item)]);
         };
     }
 
-    if (castedNodeKeys.includes('structTypeNode')) {
+    if (keys.includes('structTypeNode')) {
         visitor.visitStructType = function visitStructType(node) {
             return merge(node, node.fields.flatMap(visit(this)));
         };
     }
 
-    if (castedNodeKeys.includes('structFieldTypeNode')) {
+    if (keys.includes('structFieldTypeNode')) {
         visitor.visitStructFieldType = function visitStructFieldType(node) {
             return merge(node, [
                 ...visit(this)(node.type),
@@ -171,115 +171,115 @@ export function mergeVisitor<TReturn, TNodeKind extends NodeKind = NodeKind>(
         };
     }
 
-    if (castedNodeKeys.includes('tupleTypeNode')) {
+    if (keys.includes('tupleTypeNode')) {
         visitor.visitTupleType = function visitTupleType(node) {
             return merge(node, node.items.flatMap(visit(this)));
         };
     }
 
-    if (castedNodeKeys.includes('amountTypeNode')) {
+    if (keys.includes('amountTypeNode')) {
         visitor.visitAmountType = function visitAmountType(node) {
             return merge(node, visit(this)(node.number));
         };
     }
 
-    if (castedNodeKeys.includes('dateTimeTypeNode')) {
+    if (keys.includes('dateTimeTypeNode')) {
         visitor.visitDateTimeType = function visitDateTimeType(node) {
             return merge(node, visit(this)(node.number));
         };
     }
 
-    if (castedNodeKeys.includes('solAmountTypeNode')) {
+    if (keys.includes('solAmountTypeNode')) {
         visitor.visitSolAmountType = function visitSolAmountType(node) {
             return merge(node, visit(this)(node.number));
         };
     }
 
-    if (castedNodeKeys.includes('prefixedCountNode')) {
+    if (keys.includes('prefixedCountNode')) {
         visitor.visitPrefixedCount = function visitPrefixedCount(node) {
             return merge(node, visit(this)(node.prefix));
         };
     }
 
-    if (castedNodeKeys.includes('arrayValueNode')) {
+    if (keys.includes('arrayValueNode')) {
         visitor.visitArrayValue = function visitArrayValue(node) {
             return merge(node, node.items.flatMap(visit(this)));
         };
     }
 
-    if (castedNodeKeys.includes('constantValueNode')) {
+    if (keys.includes('constantValueNode')) {
         visitor.visitConstantValue = function visitConstantValue(node) {
             return merge(node, [...visit(this)(node.type), ...visit(this)(node.value)]);
         };
     }
 
-    if (castedNodeKeys.includes('enumValueNode')) {
+    if (keys.includes('enumValueNode')) {
         visitor.visitEnumValue = function visitEnumValue(node) {
             return merge(node, [...visit(this)(node.enum), ...(node.value ? visit(this)(node.value) : [])]);
         };
     }
 
-    if (castedNodeKeys.includes('mapValueNode')) {
+    if (keys.includes('mapValueNode')) {
         visitor.visitMapValue = function visitMapValue(node) {
             return merge(node, node.entries.flatMap(visit(this)));
         };
     }
 
-    if (castedNodeKeys.includes('mapEntryValueNode')) {
+    if (keys.includes('mapEntryValueNode')) {
         visitor.visitMapEntryValue = function visitMapEntryValue(node) {
             return merge(node, [...visit(this)(node.key), ...visit(this)(node.value)]);
         };
     }
 
-    if (castedNodeKeys.includes('setValueNode')) {
+    if (keys.includes('setValueNode')) {
         visitor.visitSetValue = function visitSetValue(node) {
             return merge(node, node.items.flatMap(visit(this)));
         };
     }
 
-    if (castedNodeKeys.includes('someValueNode')) {
+    if (keys.includes('someValueNode')) {
         visitor.visitSomeValue = function visitSomeValue(node) {
             return merge(node, visit(this)(node.value));
         };
     }
 
-    if (castedNodeKeys.includes('structValueNode')) {
+    if (keys.includes('structValueNode')) {
         visitor.visitStructValue = function visitStructValue(node) {
             return merge(node, node.fields.flatMap(visit(this)));
         };
     }
 
-    if (castedNodeKeys.includes('structFieldValueNode')) {
+    if (keys.includes('structFieldValueNode')) {
         visitor.visitStructFieldValue = function visitStructFieldValue(node) {
             return merge(node, visit(this)(node.value));
         };
     }
 
-    if (castedNodeKeys.includes('tupleValueNode')) {
+    if (keys.includes('tupleValueNode')) {
         visitor.visitTupleValue = function visitTupleValue(node) {
             return merge(node, node.items.flatMap(visit(this)));
         };
     }
 
-    if (castedNodeKeys.includes('constantPdaSeedNode')) {
+    if (keys.includes('constantPdaSeedNode')) {
         visitor.visitConstantPdaSeed = function visitConstantPdaSeed(node) {
             return merge(node, [...visit(this)(node.type), ...visit(this)(node.value)]);
         };
     }
 
-    if (castedNodeKeys.includes('variablePdaSeedNode')) {
+    if (keys.includes('variablePdaSeedNode')) {
         visitor.visitVariablePdaSeed = function visitVariablePdaSeed(node) {
             return merge(node, visit(this)(node.type));
         };
     }
 
-    if (castedNodeKeys.includes('resolverValueNode')) {
+    if (keys.includes('resolverValueNode')) {
         visitor.visitResolverValue = function visitResolverValue(node) {
             return merge(node, (node.dependsOn ?? []).flatMap(visit(this)));
         };
     }
 
-    if (castedNodeKeys.includes('conditionalValueNode')) {
+    if (keys.includes('conditionalValueNode')) {
         visitor.visitConditionalValue = function visitConditionalValue(node) {
             return merge(node, [
                 ...visit(this)(node.condition),
@@ -290,97 +290,97 @@ export function mergeVisitor<TReturn, TNodeKind extends NodeKind = NodeKind>(
         };
     }
 
-    if (castedNodeKeys.includes('pdaValueNode')) {
+    if (keys.includes('pdaValueNode')) {
         visitor.visitPdaValue = function visitPdaValue(node) {
             return merge(node, [...visit(this)(node.pda), ...node.seeds.flatMap(visit(this))]);
         };
     }
 
-    if (castedNodeKeys.includes('pdaSeedValueNode')) {
+    if (keys.includes('pdaSeedValueNode')) {
         visitor.visitPdaSeedValue = function visitPdaSeedValue(node) {
             return merge(node, visit(this)(node.value));
         };
     }
 
-    if (castedNodeKeys.includes('fixedSizeTypeNode')) {
+    if (keys.includes('fixedSizeTypeNode')) {
         visitor.visitFixedSizeType = function visitFixedSizeType(node) {
             return merge(node, visit(this)(node.type));
         };
     }
 
-    if (castedNodeKeys.includes('sizePrefixTypeNode')) {
+    if (keys.includes('sizePrefixTypeNode')) {
         visitor.visitSizePrefixType = function visitSizePrefixType(node) {
             return merge(node, [...visit(this)(node.prefix), ...visit(this)(node.type)]);
         };
     }
 
-    if (castedNodeKeys.includes('preOffsetTypeNode')) {
+    if (keys.includes('preOffsetTypeNode')) {
         visitor.visitPreOffsetType = function visitPreOffsetType(node) {
             return merge(node, visit(this)(node.type));
         };
     }
 
-    if (castedNodeKeys.includes('postOffsetTypeNode')) {
+    if (keys.includes('postOffsetTypeNode')) {
         visitor.visitPostOffsetType = function visitPostOffsetType(node) {
             return merge(node, visit(this)(node.type));
         };
     }
 
-    if (castedNodeKeys.includes('sentinelTypeNode')) {
+    if (keys.includes('sentinelTypeNode')) {
         visitor.visitSentinelType = function visitSentinelType(node) {
             return merge(node, [...visit(this)(node.sentinel), ...visit(this)(node.type)]);
         };
     }
 
-    if (castedNodeKeys.includes('hiddenPrefixTypeNode')) {
+    if (keys.includes('hiddenPrefixTypeNode')) {
         visitor.visitHiddenPrefixType = function visitHiddenPrefixType(node) {
             return merge(node, [...node.prefix.flatMap(visit(this)), ...visit(this)(node.type)]);
         };
     }
 
-    if (castedNodeKeys.includes('hiddenSuffixTypeNode')) {
+    if (keys.includes('hiddenSuffixTypeNode')) {
         visitor.visitHiddenSuffixType = function visitHiddenSuffixType(node) {
             return merge(node, [...visit(this)(node.type), ...node.suffix.flatMap(visit(this))]);
         };
     }
 
-    if (castedNodeKeys.includes('constantDiscriminatorNode')) {
+    if (keys.includes('constantDiscriminatorNode')) {
         visitor.visitConstantDiscriminator = function visitConstantDiscriminator(node) {
             return merge(node, visit(this)(node.constant));
         };
     }
 
-    if (castedNodeKeys.includes('accountLinkNode')) {
+    if (keys.includes('accountLinkNode')) {
         visitor.visitAccountLink = function visitAccountLink(node) {
             return merge(node, node.program ? visit(this)(node.program) : []);
         };
     }
 
-    if (castedNodeKeys.includes('definedTypeLinkNode')) {
+    if (keys.includes('definedTypeLinkNode')) {
         visitor.visitDefinedTypeLink = function visitDefinedTypeLink(node) {
             return merge(node, node.program ? visit(this)(node.program) : []);
         };
     }
 
-    if (castedNodeKeys.includes('instructionLinkNode')) {
+    if (keys.includes('instructionLinkNode')) {
         visitor.visitInstructionLink = function visitInstructionLink(node) {
             return merge(node, node.program ? visit(this)(node.program) : []);
         };
     }
 
-    if (castedNodeKeys.includes('instructionAccountLinkNode')) {
+    if (keys.includes('instructionAccountLinkNode')) {
         visitor.visitInstructionAccountLink = function visitInstructionAccountLink(node) {
             return merge(node, node.instruction ? visit(this)(node.instruction) : []);
         };
     }
 
-    if (castedNodeKeys.includes('instructionArgumentLinkNode')) {
+    if (keys.includes('instructionArgumentLinkNode')) {
         visitor.visitInstructionArgumentLink = function visitInstructionArgumentLink(node) {
             return merge(node, node.instruction ? visit(this)(node.instruction) : []);
         };
     }
 
-    if (castedNodeKeys.includes('pdaLinkNode')) {
+    if (keys.includes('pdaLinkNode')) {
         visitor.visitPdaLink = function visitPdaLink(node) {
             return merge(node, node.program ? visit(this)(node.program) : []);
         };

+ 3 - 3
packages/visitors-core/src/nonNullableIdentityVisitor.ts

@@ -1,10 +1,10 @@
-import { Node, NodeKind, REGISTERED_NODE_KINDS } from '@codama/nodes';
+import { Node, NodeKind } from '@codama/nodes';
 
 import { identityVisitor } from './identityVisitor';
 import { Visitor } from './visitor';
 
 export function nonNullableIdentityVisitor<TNodeKind extends NodeKind = NodeKind>(
-    nodeKeys: TNodeKind[] = REGISTERED_NODE_KINDS as TNodeKind[],
+    options: { keys?: TNodeKind[] } = {},
 ): Visitor<Node, TNodeKind> {
-    return identityVisitor<TNodeKind>(nodeKeys) as Visitor<Node, TNodeKind>;
+    return identityVisitor<TNodeKind>(options) as Visitor<Node, TNodeKind>;
 }

+ 2 - 2
packages/visitors-core/src/removeDocsVisitor.ts

@@ -3,8 +3,8 @@ import { NodeKind } from '@codama/nodes';
 import { interceptVisitor } from './interceptVisitor';
 import { nonNullableIdentityVisitor } from './nonNullableIdentityVisitor';
 
-export function removeDocsVisitor<TNodeKind extends NodeKind = NodeKind>(nodeKeys?: TNodeKind[]) {
-    return interceptVisitor(nonNullableIdentityVisitor(nodeKeys), (node, next) => {
+export function removeDocsVisitor<TNodeKind extends NodeKind = NodeKind>(options: { keys?: TNodeKind[] } = {}) {
+    return interceptVisitor(nonNullableIdentityVisitor(options), (node, next) => {
         if ('docs' in node) {
             return next({ ...node, docs: [] });
         }

+ 3 - 2
packages/visitors-core/src/staticVisitor.ts

@@ -4,10 +4,11 @@ import { getVisitFunctionName, Visitor } from './visitor';
 
 export function staticVisitor<TReturn, TNodeKind extends NodeKind = NodeKind>(
     fn: (node: Node) => TReturn,
-    nodeKeys: TNodeKind[] = REGISTERED_NODE_KINDS as TNodeKind[],
+    options: { keys?: TNodeKind[] } = {},
 ): Visitor<TReturn, TNodeKind> {
+    const keys = options.keys ?? (REGISTERED_NODE_KINDS as TNodeKind[]);
     const visitor = {} as Visitor<TReturn>;
-    nodeKeys.forEach(key => {
+    keys.forEach(key => {
         visitor[getVisitFunctionName(key)] = fn.bind(visitor);
     });
     return visitor;

+ 3 - 3
packages/visitors-core/src/topDownTransformerVisitor.ts

@@ -17,7 +17,7 @@ export type TopDownNodeTransformerWithSelector = {
 
 export function topDownTransformerVisitor<TNodeKind extends NodeKind = NodeKind>(
     transformers: (TopDownNodeTransformer | TopDownNodeTransformerWithSelector)[],
-    nodeKeys?: TNodeKind[],
+    options: { keys?: TNodeKind[]; stack?: NodeStack } = {},
 ): Visitor<Node | null, TNodeKind> {
     const transformerFunctions = transformers.map((transformer): TopDownNodeTransformer => {
         if (typeof transformer === 'function') return transformer;
@@ -27,9 +27,9 @@ export function topDownTransformerVisitor<TNodeKind extends NodeKind = NodeKind>
                 : node;
     });
 
-    const stack = new NodeStack();
+    const stack = options.stack ?? new NodeStack();
     return pipe(
-        identityVisitor(nodeKeys),
+        identityVisitor(options),
         v =>
             interceptVisitor(v, (node, next) => {
                 const appliedNode = transformerFunctions.reduce(

+ 4 - 2
packages/visitors-core/src/voidVisitor.ts

@@ -3,10 +3,12 @@ import type { NodeKind } from '@codama/nodes';
 import { mergeVisitor } from './mergeVisitor';
 import { Visitor } from './visitor';
 
-export function voidVisitor<TNodeKind extends NodeKind = NodeKind>(nodeKeys?: TNodeKind[]): Visitor<void, TNodeKind> {
+export function voidVisitor<TNodeKind extends NodeKind = NodeKind>(
+    options: { keys?: TNodeKind[] } = {},
+): Visitor<void, TNodeKind> {
     return mergeVisitor(
         () => undefined,
         () => undefined,
-        nodeKeys,
+        options,
     );
 }

+ 50 - 6
packages/visitors-core/test/bottomUpTransformerVisitor.test.ts

@@ -1,7 +1,22 @@
-import { isNode, numberTypeNode, publicKeyTypeNode, stringTypeNode, tupleTypeNode, TYPE_NODES } from '@codama/nodes';
+import {
+    definedTypeNode,
+    isNode,
+    numberTypeNode,
+    programNode,
+    publicKeyTypeNode,
+    stringTypeNode,
+    tupleTypeNode,
+    TYPE_NODES,
+} from '@codama/nodes';
 import { expect, test } from 'vitest';
 
-import { bottomUpTransformerVisitor, visit } from '../src';
+import {
+    BottomUpNodeTransformerWithSelector,
+    bottomUpTransformerVisitor,
+    findProgramNodeFromPath,
+    NodeStack,
+    visit,
+} from '../src';
 
 test('it can transform nodes into other nodes', () => {
     // Given the following tree.
@@ -48,10 +63,9 @@ test('it can create partial transformer visitors', () => {
 
     // And a transformer visitor that wraps every node into another tuple node
     // but that does not transform public key nodes.
-    const visitor = bottomUpTransformerVisitor(
-        [node => (isNode(node, TYPE_NODES) ? tupleTypeNode([node]) : node)],
-        ['tupleTypeNode', 'numberTypeNode'],
-    );
+    const visitor = bottomUpTransformerVisitor([node => (isNode(node, TYPE_NODES) ? tupleTypeNode([node]) : node)], {
+        keys: ['tupleTypeNode', 'numberTypeNode'],
+    });
 
     // When we visit the tree using that visitor.
     const result = visit(node, visitor);
@@ -107,3 +121,33 @@ test('it can transform nodes using multiple node selectors', () => {
         tupleTypeNode([numberTypeNode('u32'), tupleTypeNode([stringTypeNode('utf8'), publicKeyTypeNode()])]),
     );
 });
+
+test('it can start from an existing stack', () => {
+    // Given the following tuple node inside a program node.
+    const tuple = tupleTypeNode([numberTypeNode('u32'), publicKeyTypeNode()]);
+    const program = programNode({
+        definedTypes: [definedTypeNode({ name: 'myTuple', type: tuple })],
+        name: 'myProgram',
+        publicKey: '1111',
+    });
+
+    // And a transformer that removes all number nodes
+    // from programs whose public key is '1111'.
+    const transformer: BottomUpNodeTransformerWithSelector = {
+        select: ['[numberTypeNode]', path => findProgramNodeFromPath(path)?.publicKey === '1111'],
+        transform: () => null,
+    };
+
+    // When we visit the tuple with an existing stack that contains the program node.
+    const stack = new NodeStack([program, program.definedTypes[0]]);
+    const resultWithStack = visit(tuple, bottomUpTransformerVisitor([transformer], { stack }));
+
+    // Then we expect the number node to have been removed.
+    expect(resultWithStack).toStrictEqual(tupleTypeNode([publicKeyTypeNode()]));
+
+    // But when we visit the tuple without the stack.
+    const resultWithoutStack = visit(tuple, bottomUpTransformerVisitor([transformer]));
+
+    // Then we expect the number node to have been kept.
+    expect(resultWithoutStack).toStrictEqual(tuple);
+});

+ 3 - 4
packages/visitors-core/test/deleteNodesVisitor.test.ts

@@ -23,10 +23,9 @@ test('it can create partial visitors', () => {
 
     // And a visitor that deletes all number nodes and public key nodes
     // but does not support public key nodes.
-    const visitor = deleteNodesVisitor(
-        ['[numberTypeNode]', '[publicKeyTypeNode]'],
-        ['tupleTypeNode', 'numberTypeNode'],
-    );
+    const visitor = deleteNodesVisitor(['[numberTypeNode]', '[publicKeyTypeNode]'], {
+        keys: ['tupleTypeNode', 'numberTypeNode'],
+    });
 
     // When we visit the tree using that visitor.
     const result = visit(node, visitor);

+ 1 - 1
packages/visitors-core/test/extendVisitor.test.ts

@@ -50,7 +50,7 @@ test('it can visit itself using the exposed self argument', () => {
 
 test('it cannot extends nodes that are not supported by the base visitor', () => {
     // Given a base visitor that only supports tuple nodes.
-    const baseVisitor = voidVisitor(['tupleTypeNode']);
+    const baseVisitor = voidVisitor({ keys: ['tupleTypeNode'] });
 
     // Then we expect an error when we try to extend other nodes for that visitor.
     expect(() =>

+ 2 - 4
packages/visitors-core/test/getByteSizeVisitor.test.ts

@@ -23,9 +23,7 @@ import { expect, test } from 'vitest';
 import { getByteSizeVisitor, getRecordLinkablesVisitor, LinkableDictionary, NodeStack, visit, Visitor } from '../src';
 
 const expectSize = (node: Node, expectedSize: number | null) => {
-    expect(visit(node, getByteSizeVisitor(new LinkableDictionary(), new NodeStack()) as Visitor<number | null>)).toBe(
-        expectedSize,
-    );
+    expect(visit(node, getByteSizeVisitor(new LinkableDictionary()) as Visitor<number | null>)).toBe(expectedSize);
 };
 
 test.each([
@@ -138,7 +136,7 @@ test('it follows linked nodes using the correct paths', () => {
     visit(root, getRecordLinkablesVisitor(linkables));
 
     // When we visit the first defined type.
-    const visitor = getByteSizeVisitor(linkables, new NodeStack([root, programA]));
+    const visitor = getByteSizeVisitor(linkables, { stack: new NodeStack([root, programA]) });
     const result = visit(programA.definedTypes[0], visitor);
 
     // Then we expect the final linkable to be resolved.

+ 1 - 1
packages/visitors-core/test/identityVisitor.test.ts

@@ -42,7 +42,7 @@ test('it can create partial visitors', () => {
     // And an identity visitor that only supports 2 of these nodes
     // whilst using an interceptor to record the events that happened.
     const events: string[] = [];
-    const visitor = interceptVisitor(identityVisitor(['tupleTypeNode', 'numberTypeNode']), (node, next) => {
+    const visitor = interceptVisitor(identityVisitor({ keys: ['tupleTypeNode', 'numberTypeNode'] }), (node, next) => {
         events.push(`visiting:${node.kind}`);
         return next(node);
     });

+ 1 - 1
packages/visitors-core/test/mapVisitor.test.ts

@@ -27,7 +27,7 @@ test('it creates partial visitors from partial visitors', () => {
     const node = tupleTypeNode([numberTypeNode('u32'), publicKeyTypeNode()]);
 
     // And partial static visitor A that supports only 2 of these nodes.
-    const visitorA = staticVisitor(node => node.kind, ['tupleTypeNode', 'numberTypeNode']);
+    const visitorA = staticVisitor(node => node.kind, { keys: ['tupleTypeNode', 'numberTypeNode'] });
 
     // And a mapped visitor B that returns the number of characters returned by visitor A.
     const visitorB = mapVisitor(visitorA, value => value.length);

+ 1 - 1
packages/visitors-core/test/mergeVisitor.test.ts

@@ -44,7 +44,7 @@ test('it can create partial visitors', () => {
     const visitor = mergeVisitor(
         node => node.kind as string,
         (node, values) => `${node.kind}(${values.join(',')})`,
-        ['tupleTypeNode', 'numberTypeNode'],
+        { keys: ['tupleTypeNode', 'numberTypeNode'] },
     );
 
     // When we visit the tree using that visitor.

+ 1 - 1
packages/visitors-core/test/removeDocsVisitor.test.ts

@@ -76,7 +76,7 @@ test('it can create partial visitors', () => {
     ]);
 
     // And a remove docs visitor that only supports struct type nodes.
-    const visitor = removeDocsVisitor(['structTypeNode']);
+    const visitor = removeDocsVisitor({ keys: ['structTypeNode'] });
 
     // When we use it on our struct node.
     const result = visit(node, visitor);

+ 1 - 1
packages/visitors-core/test/staticVisitor.test.ts

@@ -21,7 +21,7 @@ test('it can create partial visitor', () => {
     const node = tupleTypeNode([numberTypeNode('u32'), publicKeyTypeNode()]);
 
     // And a static visitor that supports only 2 of these nodes.
-    const visitor = staticVisitor(node => node.kind, ['tupleTypeNode', 'numberTypeNode']);
+    const visitor = staticVisitor(node => node.kind, { keys: ['tupleTypeNode', 'numberTypeNode'] });
 
     // Then we expect the following results when visiting supported nodes.
     expect(visit(node, visitor)).toBe('tupleTypeNode');

+ 47 - 3
packages/visitors-core/test/topDownTransformerVisitor.test.ts

@@ -1,7 +1,21 @@
-import { assertIsNode, isNode, numberTypeNode, publicKeyTypeNode, tupleTypeNode } from '@codama/nodes';
+import {
+    assertIsNode,
+    definedTypeNode,
+    isNode,
+    numberTypeNode,
+    programNode,
+    publicKeyTypeNode,
+    tupleTypeNode,
+} from '@codama/nodes';
 import { expect, test } from 'vitest';
 
-import { topDownTransformerVisitor, visit } from '../src';
+import {
+    findProgramNodeFromPath,
+    NodeStack,
+    TopDownNodeTransformerWithSelector,
+    topDownTransformerVisitor,
+    visit,
+} from '../src';
 
 test('it can transform nodes to the same kind of node', () => {
     // Given the following tree.
@@ -57,7 +71,7 @@ test('it can create partial transformer visitors', () => {
                 },
             },
         ],
-        ['tupleTypeNode'],
+        { keys: ['tupleTypeNode'] },
     );
 
     // When we visit the tree using that visitor.
@@ -115,3 +129,33 @@ test('it can transform nodes using multiple node selectors', () => {
         tupleTypeNode([numberTypeNode('u32'), tupleTypeNode([numberTypeNode('u64'), publicKeyTypeNode()])]),
     );
 });
+
+test('it can start from an existing stack', () => {
+    // Given the following tuple node inside a program node.
+    const tuple = tupleTypeNode([numberTypeNode('u32'), publicKeyTypeNode()]);
+    const program = programNode({
+        definedTypes: [definedTypeNode({ name: 'myTuple', type: tuple })],
+        name: 'myProgram',
+        publicKey: '1111',
+    });
+
+    // And a transformer that removes all number nodes
+    // from programs whose public key is '1111'.
+    const transformer: TopDownNodeTransformerWithSelector = {
+        select: ['[numberTypeNode]', path => findProgramNodeFromPath(path)?.publicKey === '1111'],
+        transform: () => null,
+    };
+
+    // When we visit the tuple with an existing stack that contains the program node.
+    const stack = new NodeStack([program, program.definedTypes[0]]);
+    const resultWithStack = visit(tuple, topDownTransformerVisitor([transformer], { stack }));
+
+    // Then we expect the number node to have been removed.
+    expect(resultWithStack).toStrictEqual(tupleTypeNode([publicKeyTypeNode()]));
+
+    // But when we visit the tuple without the stack.
+    const resultWithoutStack = visit(tuple, topDownTransformerVisitor([transformer]));
+
+    // Then we expect the number node to have been kept.
+    expect(resultWithoutStack).toStrictEqual(tuple);
+});

+ 1 - 1
packages/visitors/src/fillDefaultPdaSeedValuesVisitor.ts

@@ -42,7 +42,7 @@ export function fillDefaultPdaSeedValuesVisitor(
     strictMode: boolean = false,
 ) {
     const instruction = getLastNodeFromPath(instructionPath);
-    return pipe(identityVisitor(INSTRUCTION_INPUT_VALUE_NODES), v =>
+    return pipe(identityVisitor({ keys: INSTRUCTION_INPUT_VALUE_NODES }), v =>
         extendVisitor(v, {
             visitPdaValue(node, { next }) {
                 const visitedNode = next(node);

+ 2 - 2
packages/visitors/src/setFixedAccountSizesVisitor.ts

@@ -19,13 +19,13 @@ export function setFixedAccountSizesVisitor() {
                 select: path => isNodePath(path, 'accountNode') && getLastNodeFromPath(path).size === undefined,
                 transform: (node, stack) => {
                     assertIsNode(node, 'accountNode');
-                    const size = visit(node.data, getByteSizeVisitor(linkables, stack));
+                    const size = visit(node.data, getByteSizeVisitor(linkables, { stack }));
                     if (size === null) return node;
                     return accountNode({ ...node, size }) as typeof node;
                 },
             },
         ],
-        ['rootNode', 'programNode', 'accountNode'],
+        { keys: ['rootNode', 'programNode', 'accountNode'] },
     );
 
     return pipe(visitor, v => recordLinkablesOnFirstVisitVisitor(v, linkables));

+ 1 - 1
packages/visitors/src/setInstructionAccountDefaultValuesVisitor.ts

@@ -163,7 +163,7 @@ export function setInstructionAccountDefaultValuesVisitor(rules: InstructionAcco
     }
 
     return pipe(
-        nonNullableIdentityVisitor(['rootNode', 'programNode', 'instructionNode']),
+        nonNullableIdentityVisitor({ keys: ['rootNode', 'programNode', 'instructionNode'] }),
         v =>
             extendVisitor(v, {
                 visitInstruction(node) {

+ 1 - 1
packages/visitors/src/transformDefinedTypesIntoAccountsVisitor.ts

@@ -2,7 +2,7 @@ import { accountNode, assertIsNode, programNode } from '@codama/nodes';
 import { extendVisitor, nonNullableIdentityVisitor, pipe } from '@codama/visitors-core';
 
 export function transformDefinedTypesIntoAccountsVisitor(definedTypes: string[]) {
-    return pipe(nonNullableIdentityVisitor(['rootNode', 'programNode']), v =>
+    return pipe(nonNullableIdentityVisitor({ keys: ['rootNode', 'programNode'] }), v =>
         extendVisitor(v, {
             visitProgram(program) {
                 const typesToExtract = program.definedTypes.filter(node => definedTypes.includes(node.name));