coder.ts 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553
  1. import camelCase from "camelcase";
  2. import { snakeCase } from "snake-case";
  3. import { Layout } from "buffer-layout";
  4. import * as sha256 from "js-sha256";
  5. import * as borsh from "@project-serum/borsh";
  6. import {
  7. Idl,
  8. IdlField,
  9. IdlTypeDef,
  10. IdlEnumVariant,
  11. IdlType,
  12. IdlStateMethod,
  13. } from "./idl";
  14. import { IdlError } from "./error";
  15. /**
  16. * Number of bytes of the account discriminator.
  17. */
  18. export const ACCOUNT_DISCRIMINATOR_SIZE = 8;
  19. /**
  20. * Namespace for state method function signatures.
  21. */
  22. export const SIGHASH_STATE_NAMESPACE = "state";
  23. /**
  24. * Namespace for global instruction function signatures (i.e. functions
  25. * that aren't namespaced by the state or any of its trait implementations).
  26. */
  27. export const SIGHASH_GLOBAL_NAMESPACE = "global";
  28. /**
  29. * Coder provides a facade for encoding and decoding all IDL related objects.
  30. */
  31. export default class Coder {
  32. /**
  33. * Instruction coder.
  34. */
  35. readonly instruction: InstructionCoder;
  36. /**
  37. * Account coder.
  38. */
  39. readonly accounts: AccountsCoder;
  40. /**
  41. * Types coder.
  42. */
  43. readonly types: TypesCoder;
  44. /**
  45. * Coder for state structs.
  46. */
  47. readonly state: StateCoder;
  48. /**
  49. * Coder for events.
  50. */
  51. readonly events: EventCoder;
  52. constructor(idl: Idl) {
  53. this.instruction = new InstructionCoder(idl);
  54. this.accounts = new AccountsCoder(idl);
  55. this.types = new TypesCoder(idl);
  56. this.events = new EventCoder(idl);
  57. if (idl.state) {
  58. this.state = new StateCoder(idl);
  59. }
  60. }
  61. public sighash(nameSpace: string, ixName: string): Buffer {
  62. return sighash(nameSpace, ixName);
  63. }
  64. }
  65. /**
  66. * Encodes and decodes program instructions.
  67. */
  68. class InstructionCoder {
  69. /**
  70. * Instruction args layout. Maps namespaced method
  71. */
  72. private ixLayout: Map<string, Layout>;
  73. public constructor(idl: Idl) {
  74. this.ixLayout = InstructionCoder.parseIxLayout(idl);
  75. }
  76. /**
  77. * Encodes a program instruction.
  78. */
  79. public encode(ixName: string, ix: any) {
  80. return this._encode(SIGHASH_GLOBAL_NAMESPACE, ixName, ix);
  81. }
  82. /**
  83. * Encodes a program state instruction.
  84. */
  85. public encodeState(ixName: string, ix: any) {
  86. return this._encode(SIGHASH_STATE_NAMESPACE, ixName, ix);
  87. }
  88. private _encode(nameSpace: string, ixName: string, ix: any): Buffer {
  89. const buffer = Buffer.alloc(1000); // TODO: use a tighter buffer.
  90. const methodName = camelCase(ixName);
  91. const len = this.ixLayout.get(methodName).encode(ix, buffer);
  92. const data = buffer.slice(0, len);
  93. return Buffer.concat([sighash(nameSpace, ixName), data]);
  94. }
  95. private static parseIxLayout(idl: Idl): Map<string, Layout> {
  96. const stateMethods = idl.state ? idl.state.methods : [];
  97. const ixLayouts = stateMethods
  98. .map((m: IdlStateMethod) => {
  99. let fieldLayouts = m.args.map((arg: IdlField) => {
  100. return IdlCoder.fieldLayout(arg, idl.types);
  101. });
  102. const name = camelCase(m.name);
  103. return [name, borsh.struct(fieldLayouts, name)];
  104. })
  105. .concat(
  106. idl.instructions.map((ix) => {
  107. let fieldLayouts = ix.args.map((arg: IdlField) =>
  108. IdlCoder.fieldLayout(arg, idl.types)
  109. );
  110. const name = camelCase(ix.name);
  111. return [name, borsh.struct(fieldLayouts, name)];
  112. })
  113. );
  114. // @ts-ignore
  115. return new Map(ixLayouts);
  116. }
  117. }
  118. /**
  119. * Encodes and decodes account objects.
  120. */
  121. class AccountsCoder {
  122. /**
  123. * Maps account type identifier to a layout.
  124. */
  125. private accountLayouts: Map<string, Layout>;
  126. public constructor(idl: Idl) {
  127. if (idl.accounts === undefined) {
  128. this.accountLayouts = new Map();
  129. return;
  130. }
  131. const layouts: [string, Layout][] = idl.accounts.map((acc) => {
  132. return [acc.name, IdlCoder.typeDefLayout(acc, idl.types)];
  133. });
  134. this.accountLayouts = new Map(layouts);
  135. }
  136. public async encode<T = any>(
  137. accountName: string,
  138. account: T
  139. ): Promise<Buffer> {
  140. const buffer = Buffer.alloc(1000); // TODO: use a tighter buffer.
  141. const layout = this.accountLayouts.get(accountName);
  142. const len = layout.encode(account, buffer);
  143. let accountData = buffer.slice(0, len);
  144. let discriminator = await accountDiscriminator(accountName);
  145. return Buffer.concat([discriminator, accountData]);
  146. }
  147. public decode<T = any>(accountName: string, ix: Buffer): T {
  148. // Chop off the discriminator before decoding.
  149. const data = ix.slice(8);
  150. const layout = this.accountLayouts.get(accountName);
  151. return layout.decode(data);
  152. }
  153. }
  154. /**
  155. * Encodes and decodes user defined types.
  156. */
  157. class TypesCoder {
  158. /**
  159. * Maps account type identifier to a layout.
  160. */
  161. private layouts: Map<string, Layout>;
  162. public constructor(idl: Idl) {
  163. if (idl.types === undefined) {
  164. this.layouts = new Map();
  165. return;
  166. }
  167. const layouts = idl.types.map((acc) => {
  168. return [acc.name, IdlCoder.typeDefLayout(acc, idl.types)];
  169. });
  170. // @ts-ignore
  171. this.layouts = new Map(layouts);
  172. }
  173. public encode<T = any>(accountName: string, account: T): Buffer {
  174. const buffer = Buffer.alloc(1000); // TODO: use a tighter buffer.
  175. const layout = this.layouts.get(accountName);
  176. const len = layout.encode(account, buffer);
  177. return buffer.slice(0, len);
  178. }
  179. public decode<T = any>(accountName: string, ix: Buffer): T {
  180. const layout = this.layouts.get(accountName);
  181. return layout.decode(ix);
  182. }
  183. }
  184. class EventCoder {
  185. /**
  186. * Maps account type identifier to a layout.
  187. */
  188. private layouts: Map<string, Layout>;
  189. public constructor(idl: Idl) {
  190. if (idl.events === undefined) {
  191. this.layouts = new Map();
  192. return;
  193. }
  194. const layouts = idl.events.map((event) => {
  195. let eventTypeDef: IdlTypeDef = {
  196. name: event.name,
  197. type: {
  198. kind: "struct",
  199. fields: event.fields.map((f) => {
  200. return { name: f.name, type: f.type };
  201. }),
  202. },
  203. };
  204. return [event.name, IdlCoder.typeDefLayout(eventTypeDef, idl.types)];
  205. });
  206. // @ts-ignore
  207. this.layouts = new Map(layouts);
  208. }
  209. public encode<T = any>(eventName: string, account: T): Buffer {
  210. const buffer = Buffer.alloc(1000); // TODO: use a tighter buffer.
  211. const layout = this.layouts.get(eventName);
  212. const len = layout.encode(account, buffer);
  213. return buffer.slice(0, len);
  214. }
  215. public decode<T = any>(eventName: string, ix: Buffer): T {
  216. const layout = this.layouts.get(eventName);
  217. return layout.decode(ix);
  218. }
  219. }
  220. class StateCoder {
  221. private layout: Layout;
  222. public constructor(idl: Idl) {
  223. if (idl.state === undefined) {
  224. throw new Error("Idl state not defined.");
  225. }
  226. this.layout = IdlCoder.typeDefLayout(idl.state.struct, idl.types);
  227. }
  228. public async encode<T = any>(name: string, account: T): Promise<Buffer> {
  229. const buffer = Buffer.alloc(1000); // TODO: use a tighter buffer.
  230. const len = this.layout.encode(account, buffer);
  231. const disc = await stateDiscriminator(name);
  232. const accData = buffer.slice(0, len);
  233. return Buffer.concat([disc, accData]);
  234. }
  235. public decode<T = any>(ix: Buffer): T {
  236. // Chop off discriminator.
  237. const data = ix.slice(8);
  238. return this.layout.decode(data);
  239. }
  240. }
  241. class IdlCoder {
  242. public static fieldLayout(field: IdlField, types?: IdlTypeDef[]): Layout {
  243. const fieldName =
  244. field.name !== undefined ? camelCase(field.name) : undefined;
  245. switch (field.type) {
  246. case "bool": {
  247. return borsh.bool(fieldName);
  248. }
  249. case "u8": {
  250. return borsh.u8(fieldName);
  251. }
  252. case "i8": {
  253. return borsh.i8(fieldName);
  254. }
  255. case "u16": {
  256. return borsh.u16(fieldName);
  257. }
  258. case "i16": {
  259. return borsh.i16(fieldName);
  260. }
  261. case "u32": {
  262. return borsh.u32(fieldName);
  263. }
  264. case "i32": {
  265. return borsh.i32(fieldName);
  266. }
  267. case "u64": {
  268. return borsh.u64(fieldName);
  269. }
  270. case "i64": {
  271. return borsh.i64(fieldName);
  272. }
  273. case "u128": {
  274. return borsh.u128(fieldName);
  275. }
  276. case "i128": {
  277. return borsh.i128(fieldName);
  278. }
  279. case "bytes": {
  280. return borsh.vecU8(fieldName);
  281. }
  282. case "string": {
  283. return borsh.str(fieldName);
  284. }
  285. case "publicKey": {
  286. return borsh.publicKey(fieldName);
  287. }
  288. // TODO: all the other types that need to be exported by the borsh package.
  289. default: {
  290. // @ts-ignore
  291. if (field.type.vec) {
  292. return borsh.vec(
  293. IdlCoder.fieldLayout(
  294. {
  295. name: undefined,
  296. // @ts-ignore
  297. type: field.type.vec,
  298. },
  299. types
  300. ),
  301. fieldName
  302. );
  303. // @ts-ignore
  304. } else if (field.type.option) {
  305. return borsh.option(
  306. IdlCoder.fieldLayout(
  307. {
  308. name: undefined,
  309. // @ts-ignore
  310. type: field.type.option,
  311. },
  312. types
  313. ),
  314. fieldName
  315. );
  316. // @ts-ignore
  317. } else if (field.type.defined) {
  318. // User defined type.
  319. if (types === undefined) {
  320. throw new IdlError("User defined types not provided");
  321. }
  322. // @ts-ignore
  323. const filtered = types.filter((t) => t.name === field.type.defined);
  324. if (filtered.length !== 1) {
  325. throw new IdlError(`Type not found: ${JSON.stringify(field)}`);
  326. }
  327. return IdlCoder.typeDefLayout(filtered[0], types, fieldName);
  328. // @ts-ignore
  329. } else if (field.type.array) {
  330. // @ts-ignore
  331. let arrayTy = field.type.array[0];
  332. // @ts-ignore
  333. let arrayLen = field.type.array[1];
  334. let innerLayout = IdlCoder.fieldLayout(
  335. {
  336. name: undefined,
  337. type: arrayTy,
  338. },
  339. types
  340. );
  341. return borsh.array(innerLayout, arrayLen, fieldName);
  342. } else {
  343. throw new Error(`Not yet implemented: ${field}`);
  344. }
  345. }
  346. }
  347. }
  348. public static typeDefLayout(
  349. typeDef: IdlTypeDef,
  350. types: IdlTypeDef[],
  351. name?: string
  352. ): Layout {
  353. if (typeDef.type.kind === "struct") {
  354. const fieldLayouts = typeDef.type.fields.map((field) => {
  355. const x = IdlCoder.fieldLayout(field, types);
  356. return x;
  357. });
  358. return borsh.struct(fieldLayouts, name);
  359. } else if (typeDef.type.kind === "enum") {
  360. let variants = typeDef.type.variants.map((variant: IdlEnumVariant) => {
  361. const name = camelCase(variant.name);
  362. if (variant.fields === undefined) {
  363. return borsh.struct([], name);
  364. }
  365. // @ts-ignore
  366. const fieldLayouts = variant.fields.map((f: IdlField | IdlType) => {
  367. // @ts-ignore
  368. if (f.name === undefined) {
  369. throw new Error("Tuple enum variants not yet implemented.");
  370. }
  371. // @ts-ignore
  372. return IdlCoder.fieldLayout(f, types);
  373. });
  374. return borsh.struct(fieldLayouts, name);
  375. });
  376. if (name !== undefined) {
  377. // Buffer-layout lib requires the name to be null (on construction)
  378. // when used as a field.
  379. return borsh.rustEnum(variants).replicate(name);
  380. }
  381. return borsh.rustEnum(variants, name);
  382. } else {
  383. throw new Error(`Unknown type kint: ${typeDef}`);
  384. }
  385. }
  386. }
  387. // Calculates unique 8 byte discriminator prepended to all anchor accounts.
  388. export async function accountDiscriminator(name: string): Promise<Buffer> {
  389. // @ts-ignore
  390. return Buffer.from(sha256.digest(`account:${name}`)).slice(0, 8);
  391. }
  392. // Calculates unique 8 byte discriminator prepended to all anchor state accounts.
  393. export async function stateDiscriminator(name: string): Promise<Buffer> {
  394. // @ts-ignore
  395. return Buffer.from(sha256.digest(`account:${name}`)).slice(0, 8);
  396. }
  397. export function eventDiscriminator(name: string): Buffer {
  398. // @ts-ignore
  399. return Buffer.from(sha256.digest(`event:${name}`)).slice(0, 8);
  400. }
  401. // Returns the size of the type in bytes. For variable length types, just return
  402. // 1. Users should override this value in such cases.
  403. function typeSize(idl: Idl, ty: IdlType): number {
  404. switch (ty) {
  405. case "bool":
  406. return 1;
  407. case "u8":
  408. return 1;
  409. case "i8":
  410. return 1;
  411. case "i16":
  412. return 2;
  413. case "u16":
  414. return 2;
  415. case "u32":
  416. return 4;
  417. case "i32":
  418. return 4;
  419. case "u64":
  420. return 8;
  421. case "i64":
  422. return 8;
  423. case "u128":
  424. return 16;
  425. case "i128":
  426. return 16;
  427. case "bytes":
  428. return 1;
  429. case "string":
  430. return 1;
  431. case "publicKey":
  432. return 32;
  433. default:
  434. // @ts-ignore
  435. if (ty.vec !== undefined) {
  436. return 1;
  437. }
  438. // @ts-ignore
  439. if (ty.option !== undefined) {
  440. // @ts-ignore
  441. return 1 + typeSize(idl, ty.option);
  442. }
  443. // @ts-ignore
  444. if (ty.defined !== undefined) {
  445. // @ts-ignore
  446. const filtered = idl.types.filter((t) => t.name === ty.defined);
  447. if (filtered.length !== 1) {
  448. throw new IdlError(`Type not found: ${JSON.stringify(ty)}`);
  449. }
  450. let typeDef = filtered[0];
  451. return accountSize(idl, typeDef);
  452. }
  453. // @ts-ignore
  454. if (ty.array !== undefined) {
  455. // @ts-ignore
  456. let arrayTy = ty.array[0];
  457. // @ts-ignore
  458. let arraySize = ty.array[1];
  459. // @ts-ignore
  460. return typeSize(idl, arrayTy) * arraySize;
  461. }
  462. throw new Error(`Invalid type ${JSON.stringify(ty)}`);
  463. }
  464. }
  465. export function accountSize(
  466. idl: Idl,
  467. idlAccount: IdlTypeDef
  468. ): number | undefined {
  469. if (idlAccount.type.kind === "enum") {
  470. let variantSizes = idlAccount.type.variants.map(
  471. (variant: IdlEnumVariant) => {
  472. if (variant.fields === undefined) {
  473. return 0;
  474. }
  475. // @ts-ignore
  476. return (
  477. variant.fields
  478. // @ts-ignore
  479. .map((f: IdlField | IdlType) => {
  480. // @ts-ignore
  481. if (f.name === undefined) {
  482. throw new Error("Tuple enum variants not yet implemented.");
  483. }
  484. // @ts-ignore
  485. return typeSize(idl, f.type);
  486. })
  487. .reduce((a: number, b: number) => a + b)
  488. );
  489. }
  490. );
  491. return Math.max(...variantSizes) + 1;
  492. }
  493. if (idlAccount.type.fields === undefined) {
  494. return 0;
  495. }
  496. return idlAccount.type.fields
  497. .map((f) => typeSize(idl, f.type))
  498. .reduce((a, b) => a + b);
  499. }
  500. // Not technically sighash, since we don't include the arguments, as Rust
  501. // doesn't allow function overloading.
  502. function sighash(nameSpace: string, ixName: string): Buffer {
  503. let name = snakeCase(ixName);
  504. let preimage = `${nameSpace}::${name}`;
  505. // @ts-ignore
  506. return Buffer.from(sha256.digest(preimage)).slice(0, 8);
  507. }