Browse Source

lang: Require `Discriminator` trait impl when using the `zero` constraint (#3118)

acheron 1 year ago
parent
commit
293ee9142b

+ 1 - 0
CHANGELOG.md

@@ -43,6 +43,7 @@ The minor version will be incremented upon a breaking change and the patch versi
 - lang: Remove `EventData` trait ([#3083](https://github.com/coral-xyz/anchor/pull/3083)).
 - client: Remove `async_rpc` method ([#3053](https://github.com/coral-xyz/anchor/pull/3053)).
 - lang: Make discriminator type unsized ([#3098](https://github.com/coral-xyz/anchor/pull/3098)).
+- lang: Require `Discriminator` trait impl when using the `zero` constraint ([#3118](https://github.com/coral-xyz/anchor/pull/3118)).
 
 ## [0.30.1] - 2024-06-20
 

+ 6 - 4
lang/syn/src/codegen/accounts/constraints.rs

@@ -198,6 +198,9 @@ pub fn generate_constraint_init(
 }
 
 pub fn generate_constraint_zeroed(f: &Field, _c: &ConstraintZeroed) -> proc_macro2::TokenStream {
+    let account_ty = f.account_ty();
+    let discriminator = quote! { #account_ty::DISCRIMINATOR };
+
     let field = &f.ident;
     let name_str = field.to_string();
     let ty_decl = f.ty_decl(true);
@@ -205,10 +208,9 @@ pub fn generate_constraint_zeroed(f: &Field, _c: &ConstraintZeroed) -> proc_macr
     quote! {
         let #field: #ty_decl = {
             let mut __data: &[u8] = &#field.try_borrow_data()?;
-            let mut __disc_bytes = [0u8; 8];
-            __disc_bytes.copy_from_slice(&__data[..8]);
-            let __discriminator = u64::from_le_bytes(__disc_bytes);
-            if __discriminator != 0 {
+            let __disc = &__data[..#discriminator.len()];
+            let __has_disc = __disc.iter().any(|b| *b != 0);
+            if __has_disc {
                 return Err(anchor_lang::error::Error::from(anchor_lang::error::ErrorCode::ConstraintZero).with_account_name(#name_str));
             }
             #from_account_info

+ 9 - 0
lang/syn/src/parser/accounts/constraints.rs

@@ -1174,6 +1174,15 @@ impl<'ty> ConstraintGroupBuilder<'ty> {
         if self.init.is_some() {
             return Err(ParseError::new(c.span(), "init already provided"));
         }
+
+        // Require a known account type that implements the `Discriminator` trait so that we can
+        // get the discriminator length dynamically
+        if !matches!(&self.f_ty, Some(Ty::Account(_) | Ty::AccountLoader(_))) {
+            return Err(ParseError::new(
+                c.span(),
+                "`zero` constraint requires the type to implement the `Discriminator` trait",
+            ));
+        }
         self.zeroed.replace(c);
         Ok(())
     }