Browse Source

Make zero copy safe by default, add `account(zero_copy(unsafe))` feature. (#2330)

Christian Kamm 2 years ago
parent
commit
ed2769ef28

+ 1 - 0
CHANGELOG.md

@@ -27,6 +27,7 @@ The minor version will be incremented upon a breaking change and the patch versi
 ### Breaking
 
 - lang: Remove `state` and `interface` attributes ([#2285](https://github.com/coral-xyz/anchor/pull/2285)).
+- lang: `account(zero_copy)` and `zero_copy` attributes now derive the `bytemuck::Pod` and `bytemuck::Zeroable` traits instead of using `unsafe impl` ([#2330](https://github.com/coral-xyz/anchor/pull/2330)). This imposes useful restrictions on the type, like not having padding bytes and all fields being `Pod` themselves. See [bytemuck::Pod](https://docs.rs/bytemuck/latest/bytemuck/trait.Pod.html) for details. This change requires adding `bytemuck = { version = "1.4.0", features = ["derive", "min_const_generics"]}` to your `cargo.toml`. Legacy applications can still use `#[account(zero_copy(unsafe))]` and `#[zero_copy(unsafe)]` for the old behavior.
 - ts: Remove `createProgramAddressSync`, `findProgramAddressSync` (now available in `@solana/web3.js`) and update `associatedAddress` to be synchronous ([#2357](https://github.com/coral-xyz/anchor/pull/2357)).
 
 ## [0.26.0] - 2022-12-15

+ 101 - 9
lang/attribute/account/src/lib.rs

@@ -59,6 +59,9 @@ mod id;
 /// [`Pod`](https://docs.rs/bytemuck/latest/bytemuck/trait.Pod.html). Please review the
 /// [`safety`](https://docs.rs/bytemuck/latest/bytemuck/trait.Pod.html#safety)
 /// section before using.
+///
+/// Using `zero_copy` requires adding the following to your `cargo.toml` file:
+/// `bytemuck = { version = "1.4.0", features = ["derive", "min_const_generics"]}`
 #[proc_macro_attribute]
 pub fn account(
     args: proc_macro::TokenStream,
@@ -66,6 +69,7 @@ pub fn account(
 ) -> proc_macro::TokenStream {
     let mut namespace = "".to_string();
     let mut is_zero_copy = false;
+    let mut unsafe_bytemuck = false;
     let args_str = args.to_string();
     let args: Vec<&str> = args_str.split(',').collect();
     if args.len() > 2 {
@@ -80,6 +84,10 @@ pub fn account(
             .collect();
         if ns == "zero_copy" {
             is_zero_copy = true;
+            unsafe_bytemuck = false;
+        } else if ns == "zero_copy(unsafe)" {
+            is_zero_copy = true;
+            unsafe_bytemuck = true;
         } else {
             namespace = ns;
         }
@@ -123,16 +131,38 @@ pub fn account(
         }
     };
 
-    proc_macro::TokenStream::from({
-        if is_zero_copy {
+    let unsafe_bytemuck_impl = {
+        if unsafe_bytemuck {
             quote! {
-                #[zero_copy]
-                #account_strct
-
                 #[automatically_derived]
                 unsafe impl #impl_gen anchor_lang::__private::bytemuck::Pod for #account_name #type_gen #where_clause {}
                 #[automatically_derived]
                 unsafe impl #impl_gen anchor_lang::__private::bytemuck::Zeroable for #account_name #type_gen #where_clause {}
+            }
+        } else {
+            quote! {}
+        }
+    };
+
+    let bytemuck_derives = {
+        if !unsafe_bytemuck {
+            quote! {
+                #[zero_copy]
+            }
+        } else {
+            quote! {
+                #[zero_copy(unsafe)]
+            }
+        }
+    };
+
+    proc_macro::TokenStream::from({
+        if is_zero_copy {
+            quote! {
+                #bytemuck_derives
+                #account_strct
+
+                #unsafe_bytemuck_impl
 
                 #[automatically_derived]
                 impl #impl_gen anchor_lang::ZeroCopy for #account_name #type_gen #where_clause {}
@@ -276,18 +306,44 @@ pub fn derive_zero_copy_accessor(item: proc_macro::TokenStream) -> proc_macro::T
 /// A data structure that can be used as an internal field for a zero copy
 /// deserialized account, i.e., a struct marked with `#[account(zero_copy)]`.
 ///
-/// This is just a convenient alias for
+/// `#[zero_copy]` is just a convenient alias for
 ///
 /// ```ignore
 /// #[derive(Copy, Clone)]
-/// #[repr(packed)]
+/// #[derive(bytemuck::Zeroable)]
+/// #[derive(bytemuck::Pod)]
+/// #[repr(C)]
 /// struct MyStruct {...}
 /// ```
 #[proc_macro_attribute]
 pub fn zero_copy(
-    _args: proc_macro::TokenStream,
+    args: proc_macro::TokenStream,
     item: proc_macro::TokenStream,
 ) -> proc_macro::TokenStream {
+    let mut is_unsafe = false;
+    for arg in args.into_iter() {
+        match arg {
+            proc_macro::TokenTree::Ident(ident) => {
+                if ident.to_string() == "unsafe" {
+                    // `#[zero_copy(unsafe)]` maintains the old behaviour
+                    //
+                    // ```ignore
+                    // #[derive(Copy, Clone)]
+                    // #[repr(packed)]
+                    // struct MyStruct {...}
+                    // ```
+                    is_unsafe = true;
+                } else {
+                    // TODO: how to return a compile error with a span (can't return prase error because expected type TokenStream)
+                    panic!("expected single ident `unsafe`");
+                }
+            }
+            _ => {
+                panic!("expected single ident `unsafe`");
+            }
+        }
+    }
+
     let account_strct = parse_macro_input!(item as syn::ItemStruct);
 
     // Takes the first repr. It's assumed that more than one are not on the
@@ -298,13 +354,49 @@ pub fn zero_copy(
         .find(|attr| anchor_syn::parser::tts_to_string(&attr.path) == "repr");
 
     let repr = match attr {
+        // Users might want to manually specify repr modifiers e.g. repr(C, packed)
         Some(_attr) => quote! {},
-        None => quote! {#[repr(C)]},
+        None => {
+            if is_unsafe {
+                quote! {#[repr(packed)]}
+            } else {
+                quote! {#[repr(C)]}
+            }
+        }
+    };
+
+    let mut has_pod_attr = false;
+    let mut has_zeroable_attr = false;
+    for attr in account_strct.attrs.iter() {
+        let token_string = attr.tokens.to_string();
+        if token_string.contains("bytemuck :: Pod") {
+            has_pod_attr = true;
+        }
+        if token_string.contains("bytemuck :: Zeroable") {
+            has_zeroable_attr = true;
+        }
+    }
+
+    // Once the Pod derive macro is expanded the compiler has to use the local crate's
+    // bytemuck `::bytemuck::Pod` anyway, so we're no longer using the privately
+    // exported anchor bytemuck `__private::bytemuck`, so that there won't be any
+    // possible disparity between the anchor version and the local crate's version.
+    let pod = if has_pod_attr || is_unsafe {
+        quote! {}
+    } else {
+        quote! {#[derive(::bytemuck::Pod)]}
+    };
+    let zeroable = if has_zeroable_attr || is_unsafe {
+        quote! {}
+    } else {
+        quote! {#[derive(::bytemuck::Zeroable)]}
     };
 
     proc_macro::TokenStream::from(quote! {
         #[derive(anchor_lang::__private::ZeroCopyAccessor, Copy, Clone)]
         #repr
+        #pod
+        #zeroable
         #account_strct
     })
 }

+ 1 - 1
lang/tests/generics_test.rs

@@ -22,7 +22,7 @@ where
     pub associated: Account<'info, Associated<U>>,
 }
 
-#[account(zero_copy)]
+#[account(zero_copy(unsafe))]
 pub struct FooAccount<const N: usize> {
     pub data: WrappedU8Array<N>,
 }

+ 1 - 0
tests/chat/programs/chat/Cargo.toml

@@ -17,3 +17,4 @@ default = []
 
 [dependencies]
 anchor-lang = { path = "../../../../lang" }
+bytemuck = {version = "1.4.0", features = ["derive", "min_const_generics"]}

+ 1 - 0
tests/misc/programs/misc-optional/Cargo.toml

@@ -19,3 +19,4 @@ default = []
 anchor-lang = { path = "../../../../lang", features = ["init-if-needed"] }
 anchor-spl = { path = "../../../../spl" }
 spl-associated-token-account = "1.1.1"
+bytemuck = {version = "1.4.0", features = ["derive", "min_const_generics"]}

+ 1 - 0
tests/misc/programs/misc/Cargo.toml

@@ -19,3 +19,4 @@ default = []
 anchor-lang = { path = "../../../../lang", features = ["init-if-needed"] }
 anchor-spl = { path = "../../../../spl" }
 spl-associated-token-account = "1.1.1"
+bytemuck = {version = "1.4.0", features = ["derive", "min_const_generics"]}

+ 1 - 1
tests/zero-copy/programs/zero-copy/Cargo.toml

@@ -18,8 +18,8 @@ test-bpf = []
 
 [dependencies]
 anchor-lang = { path = "../../../../lang" }
+bytemuck = {version = "1.4.0", features = ["derive", "min_const_generics"]}
 
 [dev-dependencies]
 anchor-client = { path = "../../../../client", features = ["debug"] }
-bytemuck = "1.4.0"
 solana-program-test = "1.13.5"

+ 0 - 1
tests/zero-copy/programs/zero-copy/src/lib.rs

@@ -133,7 +133,6 @@ pub struct UpdateLargeAccount<'info> {
 }
 
 #[account(zero_copy)]
-#[repr(packed)]
 #[derive(Default)]
 pub struct Foo {
     pub authority: Pubkey,