Browse Source

Fix LinkNode paths for getByteSizeVisitor (#282)

This PR fixes an issue in the `getByteSizeVisitor` where complex link node paths would be incorrectly resolved due to the fact that the `NodeStack` would follow in invalid path in the tree. The new methods to save and restore `NodePaths` inside the `NodeStack` help us fix this.
Loris Leiva 1 year ago
parent
commit
1f52f00ba2

+ 5 - 0
.changeset/soft-beds-jam.md

@@ -0,0 +1,5 @@
+---
+'@codama/visitors-core': minor
+---
+
+Fix LinkNode paths for `getByteSizeVisitor`

+ 1 - 1
packages/renderers-js-umi/src/getRenderMapVisitor.ts

@@ -63,7 +63,6 @@ export type GetRenderMapOptions = {
 export function getRenderMapVisitor(options: GetRenderMapOptions = {}): Visitor<RenderMap> {
     const linkables = new LinkableDictionary();
     const stack = new NodeStack();
-    const byteSizeVisitor = getByteSizeVisitor(linkables, stack);
     let program: ProgramNode | null = null;
 
     const renderParentInstructions = options.renderParentInstructions ?? false;
@@ -100,6 +99,7 @@ export function getRenderMapVisitor(options: GetRenderMapOptions = {}): Visitor<
         });
     const typeManifestVisitor = getTypeManifestVisitor();
     const resolvedInstructionInputVisitor = getResolvedInstructionInputsVisitor();
+    const byteSizeVisitor = getByteSizeVisitor(linkables, stack);
 
     function getInstructionAccountType(account: ResolvedInstructionAccount): string {
         if (account.isPda && account.isSigner === false) return 'Pda';

+ 8 - 6
packages/visitors-core/src/getByteSizeVisitor.ts

@@ -3,6 +3,7 @@ import { isNode, isScalarEnum, REGISTERED_TYPE_NODE_KINDS, RegisteredTypeNode }
 import { extendVisitor } from './extendVisitor';
 import { LinkableDictionary } from './LinkableDictionary';
 import { mergeVisitor } from './mergeVisitor';
+import { getLastNodeFromPath } from './NodePath';
 import { NodeStack } from './NodeStack';
 import { pipe } from './pipe';
 import { recordNodeStackVisitor } from './recordNodeStackVisitor';
@@ -69,11 +70,9 @@ export function getByteSizeVisitor(
                 visitDefinedTypeLink(node, { self }) {
                     // Fetch the linked type and return null if not found.
                     // The validator visitor will throw a proper error later on.
-                    // FIXME: Keep track of our own internal stack within this visitor (starting from a provided NodePath).
-                    const linkedDefinedType = linkables.get([...stack.getPath(), node]);
-                    if (!linkedDefinedType) {
-                        return null;
-                    }
+                    const linkedDefinedPath = linkables.getPath(stack.getPath(node.kind));
+                    if (!linkedDefinedPath) return null;
+                    const linkedDefinedType = getLastNodeFromPath(linkedDefinedPath);
 
                     // This prevents infinite recursion by using assuming
                     // cyclic types don't have a fixed size.
@@ -81,7 +80,10 @@ export function getByteSizeVisitor(
                         return null;
                     }
 
-                    return visit(linkedDefinedType, self);
+                    stack.pushPath(linkedDefinedPath);
+                    const result = visit(linkedDefinedType, self);
+                    stack.popPath();
+                    return result;
                 },
 
                 visitEnumEmptyVariantType() {

+ 41 - 1
packages/visitors-core/test/getByteSizeVisitor.test.ts

@@ -1,4 +1,6 @@
 import {
+    definedTypeLinkNode,
+    definedTypeNode,
     enumEmptyVariantTypeNode,
     enumStructVariantTypeNode,
     enumTupleVariantTypeNode,
@@ -7,7 +9,10 @@ import {
     Node,
     NumberFormat,
     numberTypeNode,
+    programLinkNode,
+    programNode,
     publicKeyTypeNode,
+    rootNode,
     stringTypeNode,
     structFieldTypeNode,
     structTypeNode,
@@ -15,7 +20,7 @@ import {
 } from '@codama/nodes';
 import { expect, test } from 'vitest';
 
-import { getByteSizeVisitor, LinkableDictionary, NodeStack, visit, Visitor } from '../src';
+import { getByteSizeVisitor, getRecordLinkablesVisitor, LinkableDictionary, NodeStack, visit, Visitor } from '../src';
 
 const expectSize = (node: Node, expectedSize: number | null) => {
     expect(visit(node, getByteSizeVisitor(new LinkableDictionary(), new NodeStack()) as Visitor<number | null>)).toBe(
@@ -104,3 +109,38 @@ test('it gets the size of variable data enums', () => {
         null,
     );
 });
+
+test('it follows linked nodes using the correct paths', () => {
+    // Given two link nodes designed so that the path would
+    // fail if we did not save and restored linked paths.
+    const programA = programNode({
+        definedTypes: [
+            definedTypeNode({
+                name: 'typeA',
+                type: definedTypeLinkNode('typeB1', programLinkNode('programB')),
+            }),
+        ],
+        name: 'programA',
+        publicKey: '1111',
+    });
+    const programB = programNode({
+        definedTypes: [
+            definedTypeNode({ name: 'typeB1', type: definedTypeLinkNode('typeB2') }),
+            definedTypeNode({ name: 'typeB2', type: numberTypeNode('u64') }),
+        ],
+        name: 'programB',
+        publicKey: '2222',
+    });
+    const root = rootNode(programA, [programB]);
+
+    // And given a recorded linkables dictionary.
+    const linkables = new LinkableDictionary();
+    visit(root, getRecordLinkablesVisitor(linkables));
+
+    // When we visit the first defined type.
+    const visitor = getByteSizeVisitor(linkables, new NodeStack([root, programA]));
+    const result = visit(programA.definedTypes[0], visitor);
+
+    // Then we expect the final linkable to be resolved.
+    expect(result).toBe(8);
+});

+ 3 - 11
packages/visitors/src/setFixedAccountSizesVisitor.ts

@@ -4,26 +4,22 @@ import {
     getLastNodeFromPath,
     isNodePath,
     LinkableDictionary,
-    NodeStack,
     pipe,
     recordLinkablesOnFirstVisitVisitor,
-    recordNodeStackVisitor,
     topDownTransformerVisitor,
     visit,
 } from '@codama/visitors-core';
 
 export function setFixedAccountSizesVisitor() {
     const linkables = new LinkableDictionary();
-    const stack = new NodeStack();
-    const byteSizeVisitor = getByteSizeVisitor(linkables, stack);
 
     const visitor = topDownTransformerVisitor(
         [
             {
                 select: path => isNodePath(path, 'accountNode') && getLastNodeFromPath(path).size === undefined,
-                transform: node => {
+                transform: (node, stack) => {
                     assertIsNode(node, 'accountNode');
-                    const size = visit(node.data, byteSizeVisitor);
+                    const size = visit(node.data, getByteSizeVisitor(linkables, stack));
                     if (size === null) return node;
                     return accountNode({ ...node, size }) as typeof node;
                 },
@@ -32,9 +28,5 @@ export function setFixedAccountSizesVisitor() {
         ['rootNode', 'programNode', 'accountNode'],
     );
 
-    return pipe(
-        visitor,
-        v => recordNodeStackVisitor(v, stack),
-        v => recordLinkablesOnFirstVisitVisitor(v, linkables),
-    );
+    return pipe(visitor, v => recordLinkablesOnFirstVisitVisitor(v, linkables));
 }