// discriminators.test.ts (14 KB)
  1. import { CodecAndValueVisitors, getCodecAndValueVisitors } from '@codama/dynamic-codecs';
  2. import {
  3. CODAMA_ERROR__DISCRIMINATOR_FIELD_HAS_NO_DEFAULT_VALUE,
  4. CODAMA_ERROR__DISCRIMINATOR_FIELD_NOT_FOUND,
  5. CodamaError,
  6. } from '@codama/errors';
  7. import {
  8. accountNode,
  9. constantDiscriminatorNode,
  10. constantValueNode,
  11. constantValueNodeFromBytes,
  12. definedTypeLinkNode,
  13. definedTypeNode,
  14. fieldDiscriminatorNode,
  15. fixedSizeTypeNode,
  16. numberTypeNode,
  17. numberValueNode,
  18. programLinkNode,
  19. programNode,
  20. rootNode,
  21. sizeDiscriminatorNode,
  22. stringTypeNode,
  23. structFieldTypeNode,
  24. structTypeNode,
  25. } from '@codama/nodes';
  26. import { getRecordLinkablesVisitor, LinkableDictionary, NodeStack, visit } from '@codama/visitors-core';
  27. import { beforeEach, describe, expect, test } from 'vitest';
  28. import { matchDiscriminators } from '../src/discriminators';
  29. import { hex } from './_setup';
  30. describe('matchDiscriminators', () => {
  31. let linkables: LinkableDictionary;
  32. let codecAndValueVisitors: CodecAndValueVisitors;
  33. beforeEach(() => {
  34. linkables = new LinkableDictionary();
  35. codecAndValueVisitors = getCodecAndValueVisitors(linkables);
  36. });
  37. test('it does not match if no discriminators are provided', () => {
  38. const result = matchDiscriminators(hex('ff'), [], structTypeNode([]), codecAndValueVisitors);
  39. expect(result).toBe(false);
  40. });
  41. describe('size discriminators', () => {
  42. test('it returns true if the size matches exactly', () => {
  43. const result = matchDiscriminators(
  44. hex('0102030405'),
  45. [sizeDiscriminatorNode(5)],
  46. structTypeNode([]),
  47. codecAndValueVisitors,
  48. );
  49. expect(result).toBe(true);
  50. });
  51. test('it returns false if the size is lower', () => {
  52. const result = matchDiscriminators(
  53. hex('01020304'),
  54. [sizeDiscriminatorNode(5)],
  55. structTypeNode([]),
  56. codecAndValueVisitors,
  57. );
  58. expect(result).toBe(false);
  59. });
  60. test('it returns false if the size is greater', () => {
  61. const result = matchDiscriminators(
  62. hex('010203040506'),
  63. [sizeDiscriminatorNode(5)],
  64. structTypeNode([]),
  65. codecAndValueVisitors,
  66. );
  67. expect(result).toBe(false);
  68. });
  69. });
  70. describe('constant discriminators', () => {
  71. test('it returns true if the bytes start with the provided constant', () => {
  72. const discriminator = constantDiscriminatorNode(constantValueNodeFromBytes('base16', 'ff'));
  73. const result = matchDiscriminators(
  74. hex('ff0102030405'),
  75. [discriminator],
  76. structTypeNode([]),
  77. codecAndValueVisitors,
  78. );
  79. expect(result).toBe(true);
  80. });
  81. test('it returns false if the bytes do not start with the provided constant', () => {
  82. const discriminator = constantDiscriminatorNode(constantValueNodeFromBytes('base16', 'ff'));
  83. const result = matchDiscriminators(
  84. hex('aa0102030405'),
  85. [discriminator],
  86. structTypeNode([]),
  87. codecAndValueVisitors,
  88. );
  89. expect(result).toBe(false);
  90. });
  91. test('it returns true if the bytes match with the provided constant at the given offset', () => {
  92. const discriminator = constantDiscriminatorNode(
  93. constantValueNodeFromBytes('base16', 'ff'),
  94. 3 /** offset */,
  95. );
  96. const result = matchDiscriminators(
  97. hex('010203ff0405'),
  98. [discriminator],
  99. structTypeNode([]),
  100. codecAndValueVisitors,
  101. );
  102. expect(result).toBe(true);
  103. });
  104. test('it returns false if the bytes do not match with the provided constant at the given offset', () => {
  105. const discriminator = constantDiscriminatorNode(
  106. constantValueNodeFromBytes('base16', 'ff'),
  107. 3 /** offset */,
  108. );
  109. const result = matchDiscriminators(
  110. hex('010203aa0405'),
  111. [discriminator],
  112. structTypeNode([]),
  113. codecAndValueVisitors,
  114. );
  115. expect(result).toBe(false);
  116. });
  117. test('it resolves link nodes correctly', () => {
  118. // Given two link nodes designed so that the path would
  119. // fail if we did not save and restored linked paths.
  120. const discriminator = constantDiscriminatorNode(
  121. constantValueNode(definedTypeLinkNode('typeB1', programLinkNode('programB')), numberValueNode(42)),
  122. );
  123. const programA = programNode({
  124. accounts: [accountNode({ discriminators: [discriminator], name: 'myAccount' })],
  125. definedTypes: [
  126. definedTypeNode({
  127. name: 'typeA',
  128. type: definedTypeLinkNode('typeB1', programLinkNode('programB')),
  129. }),
  130. ],
  131. name: 'programA',
  132. publicKey: '1111',
  133. });
  134. const programB = programNode({
  135. definedTypes: [
  136. definedTypeNode({ name: 'typeB1', type: definedTypeLinkNode('typeB2') }),
  137. definedTypeNode({ name: 'typeB2', type: numberTypeNode('u32') }),
  138. ],
  139. name: 'programB',
  140. publicKey: '2222',
  141. });
  142. const root = rootNode(programA, [programB]);
  143. // And given a recorded linkables dictionary.
  144. const linkables = new LinkableDictionary();
  145. visit(root, getRecordLinkablesVisitor(linkables));
  146. // And a stack keeping track of the current visited nodes.
  147. const stack = new NodeStack([root, programA, programA.accounts[0]]);
  148. codecAndValueVisitors = getCodecAndValueVisitors(linkables, { stack });
  149. // When we match the discriminator which should resolve to a u32 number equal to 42.
  150. const result = matchDiscriminators(
  151. hex('2a0000000102030405'),
  152. [discriminator],
  153. structTypeNode([]),
  154. codecAndValueVisitors,
  155. );
  156. // Then we expect the discriminator to match.
  157. expect(result).toBe(true);
  158. });
  159. });
  160. describe('field discriminators', () => {
  161. test('it returns true if the bytes start with the provided field default value', () => {
  162. const discriminator = fieldDiscriminatorNode('key');
  163. const fields = structTypeNode([
  164. structFieldTypeNode({
  165. defaultValue: numberValueNode(0xff),
  166. name: 'key',
  167. type: numberTypeNode('u8'),
  168. }),
  169. ]);
  170. const result = matchDiscriminators(hex('ff0102030405'), [discriminator], fields, codecAndValueVisitors);
  171. expect(result).toBe(true);
  172. });
  173. test('it returns false if the bytes do not start with the provided field default value', () => {
  174. const discriminator = fieldDiscriminatorNode('key');
  175. const fields = structTypeNode([
  176. structFieldTypeNode({
  177. defaultValue: numberValueNode(0xff),
  178. name: 'key',
  179. type: numberTypeNode('u8'),
  180. }),
  181. ]);
  182. const result = matchDiscriminators(hex('aa0102030405'), [discriminator], fields, codecAndValueVisitors);
  183. expect(result).toBe(false);
  184. });
  185. test('it returns true if the bytes match with the provided field default value at the given offset', () => {
  186. const discriminator = fieldDiscriminatorNode('key', 3 /** offset */);
  187. const fields = structTypeNode([
  188. structFieldTypeNode({ name: 'id', type: fixedSizeTypeNode(stringTypeNode('utf8'), 3) }),
  189. structFieldTypeNode({
  190. defaultValue: numberValueNode(0xff),
  191. name: 'key',
  192. type: numberTypeNode('u8'),
  193. }),
  194. ]);
  195. const result = matchDiscriminators(hex('010203ff0405'), [discriminator], fields, codecAndValueVisitors);
  196. expect(result).toBe(true);
  197. });
  198. test('it returns false if the bytes do not match with the provided field default value at the given offset', () => {
  199. const discriminator = fieldDiscriminatorNode('key', 3 /** offset */);
  200. const fields = structTypeNode([
  201. structFieldTypeNode({ name: 'id', type: fixedSizeTypeNode(stringTypeNode('utf8'), 3) }),
  202. structFieldTypeNode({
  203. defaultValue: numberValueNode(0xff),
  204. name: 'key',
  205. type: numberTypeNode('u8'),
  206. }),
  207. ]);
  208. const result = matchDiscriminators(hex('010203aa0405'), [discriminator], fields, codecAndValueVisitors);
  209. expect(result).toBe(false);
  210. });
  211. test('it throws an error if the discriminator field is not found', () => {
  212. const discriminator = fieldDiscriminatorNode('key');
  213. const fields = structTypeNode([]);
  214. expect(() =>
  215. matchDiscriminators(hex('0102030405'), [discriminator], fields, codecAndValueVisitors),
  216. ).toThrow(new CodamaError(CODAMA_ERROR__DISCRIMINATOR_FIELD_NOT_FOUND, { field: 'key' }));
  217. });
  218. test('it throws an error if the discriminator field does not have a default value', () => {
  219. const discriminator = fieldDiscriminatorNode('key');
  220. const fields = structTypeNode([
  221. structFieldTypeNode({
  222. name: 'key',
  223. type: numberTypeNode('u8'),
  224. }),
  225. ]);
  226. expect(() =>
  227. matchDiscriminators(hex('0102030405'), [discriminator], fields, codecAndValueVisitors),
  228. ).toThrow(new CodamaError(CODAMA_ERROR__DISCRIMINATOR_FIELD_HAS_NO_DEFAULT_VALUE, { field: 'key' }));
  229. });
  230. test('it resolves link nodes correctly', () => {
  231. // Given two link nodes designed so that the path would
  232. // fail if we did not save and restored linked paths.
  233. const discriminator = fieldDiscriminatorNode('key');
  234. const fields = structTypeNode([
  235. structFieldTypeNode({
  236. defaultValue: numberValueNode(42),
  237. name: 'key',
  238. type: definedTypeLinkNode('typeB1', programLinkNode('programB')),
  239. }),
  240. ]);
  241. const programA = programNode({
  242. accounts: [accountNode({ data: fields, discriminators: [discriminator], name: 'myAccount' })],
  243. definedTypes: [
  244. definedTypeNode({
  245. name: 'typeA',
  246. type: definedTypeLinkNode('typeB1', programLinkNode('programB')),
  247. }),
  248. ],
  249. name: 'programA',
  250. publicKey: '1111',
  251. });
  252. const programB = programNode({
  253. definedTypes: [
  254. definedTypeNode({ name: 'typeB1', type: definedTypeLinkNode('typeB2') }),
  255. definedTypeNode({ name: 'typeB2', type: numberTypeNode('u32') }),
  256. ],
  257. name: 'programB',
  258. publicKey: '2222',
  259. });
  260. const root = rootNode(programA, [programB]);
  261. // And given a recorded linkables dictionary.
  262. const linkables = new LinkableDictionary();
  263. visit(root, getRecordLinkablesVisitor(linkables));
  264. // And a stack keeping track of the current visited nodes.
  265. const stack = new NodeStack([root, programA, programA.accounts[0]]);
  266. codecAndValueVisitors = getCodecAndValueVisitors(linkables, { stack });
  267. // When we match the discriminator which should resolve to a u32 number equal to 42.
  268. const result = matchDiscriminators(
  269. hex('2a0000000102030405'),
  270. [discriminator],
  271. fields,
  272. codecAndValueVisitors,
  273. );
  274. // Then we expect the discriminator to match.
  275. expect(result).toBe(true);
  276. });
  277. });
  278. describe('multiple discriminators', () => {
  279. test('it returns true if all discriminators match', () => {
  280. const result = matchDiscriminators(
  281. hex('ff0102030405'),
  282. [constantDiscriminatorNode(constantValueNodeFromBytes('base16', 'ff')), sizeDiscriminatorNode(6)],
  283. structTypeNode([]),
  284. codecAndValueVisitors,
  285. );
  286. expect(result).toBe(true);
  287. });
  288. test('it returns false if any discriminator does not match', () => {
  289. const result = matchDiscriminators(
  290. hex('ff0102030405'),
  291. [constantDiscriminatorNode(constantValueNodeFromBytes('base16', 'ff')), sizeDiscriminatorNode(999)],
  292. structTypeNode([]),
  293. codecAndValueVisitors,
  294. );
  295. expect(result).toBe(false);
  296. });
  297. test('it can match on all discriminator types', () => {
  298. const result = matchDiscriminators(
  299. hex('aabb01020304'),
  300. [
  301. fieldDiscriminatorNode('key'),
  302. constantDiscriminatorNode(constantValueNodeFromBytes('base16', 'bb'), 1),
  303. sizeDiscriminatorNode(6),
  304. ],
  305. structTypeNode([
  306. structFieldTypeNode({
  307. defaultValue: numberValueNode(0xaa),
  308. name: 'key',
  309. type: numberTypeNode('u8'),
  310. }),
  311. ]),
  312. codecAndValueVisitors,
  313. );
  314. expect(result).toBe(true);
  315. });
  316. });
  317. });