瀏覽代碼

Error handling

Armani Ferrante 4 年之前
父節點
當前提交
e636cf9721

+ 11 - 0
Cargo.lock

@@ -65,6 +65,16 @@ dependencies = [
  "syn 1.0.57",
 ]
 
+[[package]]
+name = "anchor-attribute-error"
+version = "0.0.0-alpha.0"
+dependencies = [
+ "anchor-syn",
+ "proc-macro2 1.0.24",
+ "quote 1.0.8",
+ "syn 1.0.57",
+]
+
 [[package]]
 name = "anchor-attribute-program"
 version = "0.0.0-alpha.0"
@@ -113,6 +123,7 @@ version = "0.0.0-alpha.0"
 dependencies = [
  "anchor-attribute-access-control",
  "anchor-attribute-account",
+ "anchor-attribute-error",
  "anchor-attribute-program",
  "anchor-derive-accounts",
  "serum-borsh",

+ 1 - 0
Cargo.toml

@@ -14,6 +14,7 @@ default = []
 [dependencies]
 anchor-attribute-access-control = { path = "./attribute/access-control", version = "0.0.0-alpha.0" }
 anchor-attribute-account = { path = "./attribute/account", version = "0.0.0-alpha.0" }
+anchor-attribute-error = { path = "./attribute/error" }
 anchor-attribute-program = { path = "./attribute/program", version = "0.0.0-alpha.0" }
 anchor-derive-accounts = { path = "./derive/accounts", version = "0.0.0-alpha.0" }
 serum-borsh = { version = "0.8.0-serum.1", features = ["serum-program"] }

+ 17 - 0
attribute/error/Cargo.toml

@@ -0,0 +1,17 @@
+[package]
+name = "anchor-attribute-error"
+version = "0.0.0-alpha.0"
+authors = ["Serum Foundation <foundation@projectserum.com>"]
+repository = "https://github.com/project-serum/anchor"
+license = "Apache-2.0"
+description = "Anchor attribute macro for creating error types"
+edition = "2018"
+
+[lib]
+proc-macro = true
+
+[dependencies]
+proc-macro2 = "1.0"
+quote = "1.0"
+syn = { version = "=1.0.57", features = ["full"] }
+anchor-syn = { path = "../../syn" }

+ 16 - 0
attribute/error/src/lib.rs

@@ -0,0 +1,16 @@
+extern crate proc_macro;
+
+use anchor_syn::codegen::error as error_codegen;
+use anchor_syn::parser::error as error_parser;
+use syn::parse_macro_input;
+
+/// Generates an error type from an error code enum.
+#[proc_macro_attribute]
+pub fn error(
+    _args: proc_macro::TokenStream,
+    input: proc_macro::TokenStream,
+) -> proc_macro::TokenStream {
+    let mut error_enum = parse_macro_input!(input as syn::ItemEnum);
+    let error = error_codegen::generate(error_parser::parse(&mut error_enum));
+    proc_macro::TokenStream::from(error)
+}

+ 2 - 0
examples/errors/Anchor.toml

@@ -0,0 +1,2 @@
+cluster = "localnet"
+wallet = "/home/armaniferrante/.config/solana/id.json"

+ 4 - 0
examples/errors/Cargo.toml

@@ -0,0 +1,4 @@
+[workspace]
+members = [
+    "programs/*"
+]

+ 17 - 0
examples/errors/programs/errors/Cargo.toml

@@ -0,0 +1,17 @@
+[package]
+name = "errors"
+version = "0.1.0"
+description = "Created with Anchor"
+edition = "2018"
+
+[lib]
+crate-type = ["cdylib", "lib"]
+name = "errors"
+
+[features]
+no-entrypoint = []
+cpi = ["no-entrypoint"]
+
+[dependencies]
+# anchor-lang = { git = "https://github.com/project-serum/anchor", features = ["derive"] }
+anchor-lang = { path = "/home/armaniferrante/Documents/code/src/github.com/project-serum/anchor", features = ["derive"] }

+ 2 - 0
examples/errors/programs/errors/Xargo.toml

@@ -0,0 +1,2 @@
+[target.bpfel-unknown-unknown.dependencies.std]
+features = []

+ 36 - 0
examples/errors/programs/errors/src/lib.rs

@@ -0,0 +1,36 @@
+#![feature(proc_macro_hygiene)]
+
+use anchor_lang::prelude::*;
+
+#[program]
+mod errors {
+    use super::*;
+    pub fn hello(ctx: Context<Hello>) -> Result<(), Error> {
+        Err(MyError::Hello.into())
+    }
+
+    pub fn hello_no_msg(ctx: Context<HelloNoMsg>) -> Result<(), Error> {
+        Err(MyError::HelloNoMsg.into())
+    }
+
+    pub fn hello_next(ctx: Context<HelloNext>) -> Result<(), Error> {
+        Err(MyError::HelloNext.into())
+    }
+}
+
+#[derive(Accounts)]
+pub struct Hello {}
+
+#[derive(Accounts)]
+pub struct HelloNoMsg {}
+
+#[derive(Accounts)]
+pub struct HelloNext {}
+
+#[error]
+pub enum MyError {
+    #[msg("This is an error message clients will automatically display")]
+    Hello,
+    HelloNoMsg = 123,
+    HelloNext,
+}

+ 47 - 0
examples/errors/tests/errors.js

@@ -0,0 +1,47 @@
+const assert = require("assert");
+//const anchor = require('@project-serum/anchor');
+const anchor = require("/home/armaniferrante/Documents/code/src/github.com/project-serum/anchor/ts");
+
+describe("errors", () => {
+  // Configure the client to use the local cluster.
+  anchor.setProvider(anchor.Provider.local());
+
+  const program = anchor.workspace.Errors;
+
+  it("Emits a Hello error", async () => {
+    try {
+      const tx = await program.rpc.hello();
+      assert.ok(false);
+    } catch (err) {
+      const errMsg =
+        "This is an error message clients will automatically display";
+      assert.equal(err.toString(), errMsg);
+      assert.equal(err.msg, errMsg);
+      assert.equal(err.code, 100);
+    }
+  });
+
+  it("Emits a HelloNoMsg error", async () => {
+    try {
+      const tx = await program.rpc.helloNoMsg();
+      assert.ok(false);
+    } catch (err) {
+      const errMsg = "HelloNoMsg";
+      assert.equal(err.toString(), errMsg);
+      assert.equal(err.msg, errMsg);
+      assert.equal(err.code, 100 + 123);
+    }
+  });
+
+  it("Emits a HelloNext error", async () => {
+    try {
+      const tx = await program.rpc.helloNext();
+      assert.ok(false);
+    } catch (err) {
+      const errMsg = "HelloNext";
+      assert.equal(err.toString(), errMsg);
+      assert.equal(err.msg, errMsg);
+      assert.equal(err.code, 100 + 124);
+    }
+  });
+});

+ 32 - 0
src/error.rs

@@ -0,0 +1,32 @@
+use solana_program::program_error::ProgramError;
+
+#[derive(thiserror::Error, Debug)]
+pub enum Error {
+    #[error(transparent)]
+    ProgramError(#[from] ProgramError),
+    #[error("{0:?}")]
+    ErrorCode(#[from] ErrorCode),
+}
+
+#[derive(Debug, Clone, Copy)]
+#[repr(u32)]
+pub enum ErrorCode {
+    WrongSerialization = 1,
+}
+
+impl std::fmt::Display for ErrorCode {
+    fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
+        <Self as std::fmt::Debug>::fmt(self, fmt)
+    }
+}
+
+impl std::error::Error for ErrorCode {}
+
+impl std::convert::From<Error> for ProgramError {
+    fn from(e: Error) -> ProgramError {
+        match e {
+            Error::ProgramError(e) => e,
+            Error::ErrorCode(c) => ProgramError::Custom(c as u32),
+        }
+    }
+}

+ 5 - 1
src/lib.rs

@@ -30,6 +30,7 @@ use std::io::Write;
 mod account_info;
 mod context;
 mod cpi_account;
+mod error;
 mod program_account;
 mod sysvar;
 
@@ -39,10 +40,12 @@ pub use crate::program_account::ProgramAccount;
 pub use crate::sysvar::Sysvar;
 pub use anchor_attribute_access_control::access_control;
 pub use anchor_attribute_account::account;
+pub use anchor_attribute_error::error;
 pub use anchor_attribute_program::program;
 pub use anchor_derive_accounts::Accounts;
 /// Default serialization format for anchor instructions and accounts.
 pub use borsh::{BorshDeserialize as AnchorDeserialize, BorshSerialize as AnchorSerialize};
+pub use error::Error;
 pub use solana_program;
 
 /// A data structure of accounts that can be deserialized from the input
@@ -115,7 +118,7 @@ pub trait AccountDeserialize: Sized {
 /// All programs should include it via `anchor_lang::prelude::*;`.
 pub mod prelude {
     pub use super::{
-        access_control, account, program, AccountDeserialize, AccountSerialize, Accounts,
+        access_control, account, error, program, AccountDeserialize, AccountSerialize, Accounts,
         AccountsInit, AnchorDeserialize, AnchorSerialize, Context, CpiAccount, CpiContext,
         ProgramAccount, Sysvar, ToAccountInfo, ToAccountInfos, ToAccountMetas,
     };
@@ -138,4 +141,5 @@ pub mod prelude {
     pub use solana_program::sysvar::slot_history::SlotHistory;
     pub use solana_program::sysvar::stake_history::StakeHistory;
     pub use solana_program::sysvar::Sysvar as SolanaSysvar;
+    pub use thiserror;
 }

+ 39 - 0
syn/src/codegen/error.rs

@@ -0,0 +1,39 @@
+use crate::Error;
+use quote::quote;
+
+pub fn generate(error: Error) -> proc_macro2::TokenStream {
+    let error_enum = error.raw_enum;
+    let enum_name = &error.ident;
+    quote! {
+        #[derive(thiserror::Error, Debug)]
+        pub enum Error {
+            #[error(transparent)]
+            ProgramError(#[from] ProgramError),
+            #[error("{0:?}")]
+            ErrorCode(#[from] #enum_name),
+        }
+
+        #[derive(Debug, Clone, Copy)]
+        #[repr(u32)]
+        #error_enum
+
+        impl std::fmt::Display for #enum_name {
+            fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
+                <Self as std::fmt::Debug>::fmt(self, fmt)
+            }
+        }
+
+        impl std::error::Error for #enum_name {}
+
+        impl std::convert::From<Error> for ProgramError {
+            fn from(e: Error) -> ProgramError {
+            // Errors 0-100 are reserved for the framework.
+            let error_offset = 100u32;
+                match e {
+                    Error::ProgramError(e) => e,
+                    Error::ErrorCode(c) => ProgramError::Custom(c as u32 + error_offset),
+                }
+            }
+        }
+    }
+}

+ 1 - 0
syn/src/codegen/mod.rs

@@ -1,2 +1,3 @@
 pub mod accounts;
+pub mod error;
 pub mod program;

+ 10 - 0
syn/src/idl.rs

@@ -10,6 +10,8 @@ pub struct Idl {
     #[serde(skip_serializing_if = "Vec::is_empty", default)]
     pub types: Vec<IdlTypeDef>,
     #[serde(skip_serializing_if = "Option::is_none", default)]
+    pub errors: Option<Vec<IdlErrorCode>>,
+    #[serde(skip_serializing_if = "Option::is_none", default)]
     pub metadata: Option<serde_json::Value>,
 }
 
@@ -132,3 +134,11 @@ impl std::str::FromStr for IdlType {
         Ok(r)
     }
 }
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct IdlErrorCode {
+    pub code: u32,
+    pub name: String,
+    #[serde(skip_serializing_if = "Option::is_none", default)]
+    pub msg: Option<String>,
+}

+ 14 - 0
syn/src/lib.rs

@@ -245,3 +245,17 @@ pub enum ConstraintRentExempt {
     Enforce,
     Skip,
 }
+
+#[derive(Debug)]
+pub struct Error {
+    pub raw_enum: syn::ItemEnum,
+    pub ident: syn::Ident,
+    pub codes: Vec<ErrorCode>,
+}
+
+#[derive(Debug)]
+pub struct ErrorCode {
+    pub id: u32,
+    pub ident: syn::Ident,
+    pub msg: Option<String>,
+}

+ 69 - 0
syn/src/parser/error.rs

@@ -0,0 +1,69 @@
+use crate::{Error, ErrorCode};
+
+// Removes any internal #[msg] attributes, as they are inert.
+pub fn parse(error_enum: &mut syn::ItemEnum) -> Error {
+    let ident = error_enum.ident.clone();
+    let mut last_discriminant = 0;
+    let codes: Vec<ErrorCode> = error_enum
+        .variants
+        .iter_mut()
+        .map(|variant: &mut syn::Variant| {
+            let msg = parse_error_attribute(variant);
+            let ident = variant.ident.clone();
+            let id = match &variant.discriminant {
+                None => last_discriminant,
+                Some((_, disc)) => match disc {
+                    syn::Expr::Lit(expr_lit) => match &expr_lit.lit {
+                        syn::Lit::Int(int) => {
+                            int.base10_parse::<u32>().expect("Must be a base 10 number")
+                        }
+                        _ => panic!("Invalid error discriminant"),
+                    },
+                    _ => panic!("Invalid error discriminant"),
+                },
+            };
+            last_discriminant = id + 1;
+
+            // Remove any attributes on the error variant.
+            variant.attrs = vec![];
+
+            ErrorCode { id, ident, msg }
+        })
+        .collect();
+
+    Error {
+        raw_enum: error_enum.clone(),
+        ident,
+        codes,
+    }
+}
+
+fn parse_error_attribute(variant: &syn::Variant) -> Option<String> {
+    let attrs = &variant.attrs;
+    match attrs.len() {
+        0 => None,
+        1 => {
+            let attr = &attrs[0];
+            let attr_str = attr.path.segments[0].ident.to_string();
+            if &attr_str != "msg" {
+                panic!("Use msg to specify error strings");
+            }
+
+            let mut tts = attr.tokens.clone().into_iter();
+            let g_stream = match tts.next().expect("Must have a token group") {
+                proc_macro2::TokenTree::Group(g) => g.stream(),
+                _ => panic!("Invalid syntax"),
+            };
+
+            let msg = match g_stream.into_iter().next() {
+                None => panic!("Must specify a message string"),
+                Some(msg) => msg.to_string().replace("\"", ""),
+            };
+
+            Some(msg)
+        }
+        _ => {
+            panic!("Too many attributes found. Use `msg` to specify error strings");
+        }
+    }
+}

+ 40 - 2
syn/src/parser/file.rs

@@ -1,6 +1,5 @@
 use crate::idl::*;
-use crate::parser::accounts;
-use crate::parser::program;
+use crate::parser::{accounts, error, program};
 use crate::AccountsStruct;
 use anyhow::Result;
 use heck::MixedCase;
@@ -22,6 +21,17 @@ pub fn parse(filename: impl AsRef<Path>) -> Result<Idl> {
     let f = syn::parse_file(&src).expect("Unable to parse file");
 
     let p = program::parse(parse_program_mod(&f));
+    let errors = parse_error_enum(&f).map(|mut e| {
+        error::parse(&mut e)
+            .codes
+            .iter()
+            .map(|code| IdlErrorCode {
+                code: 100 + code.id,
+                name: code.ident.to_string(),
+                msg: code.msg.clone(),
+            })
+            .collect::<Vec<IdlErrorCode>>()
+    });
 
     let accs = parse_accounts(&f);
 
@@ -83,6 +93,7 @@ pub fn parse(filename: impl AsRef<Path>) -> Result<Idl> {
         instructions,
         types,
         accounts,
+        errors,
         metadata: None,
     })
 }
@@ -117,6 +128,33 @@ fn parse_program_mod(f: &syn::File) -> syn::ItemMod {
     mods[0].clone()
 }
 
+fn parse_error_enum(f: &syn::File) -> Option<syn::ItemEnum> {
+    f.items
+        .iter()
+        .filter_map(|i| match i {
+            syn::Item::Enum(item_enum) => {
+                let attrs = item_enum
+                    .attrs
+                    .iter()
+                    .filter_map(|attr| {
+                        let segment = attr.path.segments.last().unwrap();
+                        if segment.ident.to_string() == "error" {
+                            return Some(attr);
+                        }
+                        None
+                    })
+                    .collect::<Vec<_>>();
+                match attrs.len() {
+                    0 => None,
+                    1 => Some(item_enum),
+                    _ => panic!("Invalid syntax: one error attribute allowed"),
+                }
+            }
+            _ => None,
+        })
+        .next()
+        .cloned()
+}
 // Parse all structs implementing the `Accounts` trait.
 fn parse_accounts(f: &syn::File) -> HashMap<String, AccountsStruct> {
     f.items

+ 1 - 0
syn/src/parser/mod.rs

@@ -1,4 +1,5 @@
 pub mod accounts;
+pub mod error;
 #[cfg(feature = "idl")]
 pub mod file;
 pub mod program;

+ 11 - 0
ts/src/error.ts

@@ -1 +1,12 @@
 export class IdlError extends Error {}
+
+// An error from a user defined program.
+export class ProgramError extends Error {
+  constructor(readonly code: number, readonly msg: string, ...params: any[]) {
+    super(...params);
+  }
+
+  public toString(): string {
+    return this.msg;
+  }
+}

+ 7 - 0
ts/src/idl.ts

@@ -4,6 +4,7 @@ export type Idl = {
   instructions: IdlInstruction[];
   accounts?: IdlTypeDef[];
   types?: IdlTypeDef[];
+  errors?: IdlErrorCode[];
 };
 
 export type IdlInstruction = {
@@ -74,3 +75,9 @@ export type IdlTypeDefined = {
 type IdlEnumVariant = {
   // todo
 };
+
+type IdlErrorCode = {
+  code: number;
+  name: string;
+  msg?: string;
+};

+ 52 - 6
ts/src/rpc.ts

@@ -9,7 +9,7 @@ import {
 } from "@solana/web3.js";
 import { sha256 } from "crypto-hash";
 import { Idl, IdlInstruction } from "./idl";
-import { IdlError } from "./error";
+import { IdlError, ProgramError } from "./error";
 import Coder from "./coder";
 import { getProvider } from "./";
 
@@ -95,11 +95,12 @@ export class RpcFactory {
     const rpcs: Rpcs = {};
     const ixFns: Ixs = {};
     const accountFns: Accounts = {};
+    const idlErrors = parseIdlErrors(idl);
     idl.instructions.forEach((idlIx) => {
       // Function to create a raw `TransactionInstruction`.
       const ix = RpcFactory.buildIx(idlIx, coder, programId);
       // Function to invoke an RPC against a cluster.
-      const rpc = RpcFactory.buildRpc(idlIx, ix);
+      const rpc = RpcFactory.buildRpc(idlIx, ix, idlErrors);
 
       const name = camelCase(idlIx.name);
       rpcs[name] = rpc;
@@ -175,7 +176,11 @@ export class RpcFactory {
     return ix;
   }
 
-  private static buildRpc(idlIx: IdlInstruction, ixFn: IxFn): RpcFn {
+  private static buildRpc(
+    idlIx: IdlInstruction,
+    ixFn: IxFn,
+    idlErrors: Map<number, string>
+  ): RpcFn {
     const rpc = async (...args: any[]): Promise<TransactionSignature> => {
       const [_, ctx] = splitArgsAndCtx(idlIx, [...args]);
       const tx = new Transaction();
@@ -187,15 +192,56 @@ export class RpcFactory {
       if (provider === null) {
         throw new Error("Provider not found");
       }
-
-      const txSig = await provider.send(tx, ctx.signers, ctx.options);
-      return txSig;
+      try {
+        const txSig = await provider.send(tx, ctx.signers, ctx.options);
+        return txSig;
+      } catch (err) {
+        let translatedErr = translateError(idlErrors, err);
+        if (err === null) {
+          throw err;
+        }
+        throw translatedErr;
+      }
     };
 
     return rpc;
   }
 }
 
+function translateError(
+  idlErrors: Map<number, string>,
+  err: any
+): Error | null {
+  // TODO: don't rely on the error string. web3.js should preserve the error
+  //       code information instead of giving us an untyped string.
+  let components = err.toString().split("custom program error: ");
+  if (components.length === 2) {
+    try {
+      const errorCode = parseInt(components[1]);
+      let errorMsg = idlErrors.get(errorCode);
+      if (errorMsg === undefined) {
+        // Unexpected error code so just throw the untranslated error.
+        throw err;
+      }
+      return new ProgramError(errorCode, errorMsg);
+    } catch (parseErr) {
+      // Unable to parse the error. Just return the untranslated error.
+      return null;
+    }
+  }
+}
+
+function parseIdlErrors(idl: Idl): Map<number, string> {
+  const errors = new Map();
+  if (idl.errors) {
+    idl.errors.forEach((e) => {
+      let msg = e.msg ?? e.name;
+      errors.set(e.code, msg);
+    });
+  }
+  return errors;
+}
+
 function splitArgsAndCtx(
   idlIx: IdlInstruction,
   args: any[]