Browse Source

lang, ts: account header

Armani Ferrante 3 years ago
parent
commit
53477541c2

+ 2 - 0
CHANGELOG.md

@@ -21,6 +21,8 @@ incremented for features.
 
 * lang: rename `loader_account` module to `account_loader` module ([#1279](https://github.com/project-serum/anchor/pull/1279))
 * ts: `Coder` is now an interface and the existing class has been renamed to `BorshCoder`. This change allows the generation of Anchor clients for non anchor programs  ([#1259](https://github.com/project-serum/anchor/pull/1259/files)).
+* ts: `BorshAccountsCoder.accountDiscriminator` method has been replaced with `BorshAccountHeader.discriminator` ([#]()).
+* lang, ts: 8 byte account discriminator has been replaced with a versioned account header ([#]()).
 
 ## [0.20.1] - 2022-01-09
 

+ 10 - 0
client/src/lib.rs

@@ -291,12 +291,22 @@ fn handle_program_log<T: anchor_lang::Event + anchor_lang::AnchorDeserialize>(
         };
 
         let mut slice: &[u8] = &borsh_bytes[..];
+
+        #[cfg(feature = "deprecated-layout")]
         let disc: [u8; 8] = {
             let mut disc = [0; 8];
             disc.copy_from_slice(&borsh_bytes[..8]);
             slice = &slice[8..];
             disc
         };
+        #[cfg(not(feature = "deprecated-layout"))]
+        let disc: [u8; 4] = {
+            let mut disc = [0; 4];
+            disc.copy_from_slice(&borsh_bytes[2..6]);
+            slice = &slice[8..];
+            disc
+        };
+
         let mut event = None;
         if disc == T::discriminator() {
             let e: T = anchor_lang::AnchorDeserialize::deserialize(&mut slice)

+ 55 - 22
lang/attribute/account/src/lib.rs

@@ -88,6 +88,21 @@ pub fn account(
     let account_name = &account_strct.ident;
     let (impl_gen, type_gen, where_clause) = account_strct.generics.split_for_impl();
 
+    let owner_impl = {
+        if namespace.is_empty() {
+            quote! {
+                #[automatically_derived]
+                impl #impl_gen anchor_lang::Owner for #account_name #type_gen #where_clause {
+                    fn owner() -> Pubkey {
+                        crate::ID
+                    }
+                }
+            }
+        } else {
+            quote! {}
+        }
+    };
+
     let discriminator: proc_macro2::TokenStream = {
         // Namespace the discriminator to prevent collisions.
         let discriminator_preimage = {
@@ -99,25 +114,46 @@ pub fn account(
             }
         };
 
-        let mut discriminator = [0u8; 8];
-        discriminator.copy_from_slice(
-            &anchor_syn::hash::hash(discriminator_preimage.as_bytes()).to_bytes()[..8],
-        );
-        format!("{:?}", discriminator).parse().unwrap()
+        if cfg!(feature = "deprecated-layout") {
+            let mut discriminator = [0u8; 8];
+            discriminator.copy_from_slice(
+                &anchor_syn::hash::hash(discriminator_preimage.as_bytes()).to_bytes()[..8],
+            );
+            format!("{:?}", discriminator).parse().unwrap()
+        } else {
+            let mut discriminator = [0u8; 4];
+            discriminator.copy_from_slice(
+                &anchor_syn::hash::hash(discriminator_preimage.as_bytes()).to_bytes()[..4],
+            );
+            format!("{:?}", discriminator).parse().unwrap()
+        }
     };
 
-    let owner_impl = {
-        if namespace.is_empty() {
+    let disc_bytes = {
+        if cfg!(feature = "deprecated-layout") {
             quote! {
-                #[automatically_derived]
-                impl #impl_gen anchor_lang::Owner for #account_name #type_gen #where_clause {
-                    fn owner() -> Pubkey {
-                        crate::ID
-                    }
+                let given_disc = &buf[..8];
+            }
+        } else {
+            quote! {
+                let given_disc = &buf[2..6];
+            }
+        }
+    };
+
+    let disc_fn = {
+        if cfg!(feature = "deprecated-layout") {
+            quote! {
+                fn discriminator() -> [u8; 8] {
+                    #discriminator
                 }
             }
         } else {
-            quote! {}
+            quote! {
+                fn discriminator() -> [u8; 4] {
+                    #discriminator
+                }
+            }
         }
     };
 
@@ -137,9 +173,7 @@ pub fn account(
 
                 #[automatically_derived]
                 impl #impl_gen anchor_lang::Discriminator for #account_name #type_gen #where_clause {
-                    fn discriminator() -> [u8; 8] {
-                        #discriminator
-                    }
+                    #disc_fn
                 }
 
                 // This trait is useful for clients deserializing accounts.
@@ -147,10 +181,11 @@ pub fn account(
                 #[automatically_derived]
                 impl #impl_gen anchor_lang::AccountDeserialize for #account_name #type_gen #where_clause {
                     fn try_deserialize(buf: &mut &[u8]) -> std::result::Result<Self, ProgramError> {
-                        if buf.len() < #discriminator.len() {
+                        // Header is always 8 bytes.
+                        if buf.len() < 8 {
                             return Err(anchor_lang::__private::ErrorCode::AccountDiscriminatorNotFound.into());
                         }
-                        let given_disc = &buf[..8];
+                        #disc_bytes
                         if &#discriminator != given_disc {
                             return Err(anchor_lang::__private::ErrorCode::AccountDiscriminatorMismatch.into());
                         }
@@ -192,7 +227,7 @@ pub fn account(
                         if buf.len() < #discriminator.len() {
                             return Err(anchor_lang::__private::ErrorCode::AccountDiscriminatorNotFound.into());
                         }
-                        let given_disc = &buf[..8];
+                        #disc_bytes
                         if &#discriminator != given_disc {
                             return Err(anchor_lang::__private::ErrorCode::AccountDiscriminatorMismatch.into());
                         }
@@ -208,9 +243,7 @@ pub fn account(
 
                 #[automatically_derived]
                 impl #impl_gen anchor_lang::Discriminator for #account_name #type_gen #where_clause {
-                    fn discriminator() -> [u8; 8] {
-                        #discriminator
-                    }
+                    #disc_fn
                 }
 
                 #owner_impl

+ 40 - 2
lang/src/accounts/account_loader.rs

@@ -2,8 +2,8 @@
 
 use crate::error::ErrorCode;
 use crate::{
-    Accounts, AccountsClose, AccountsExit, Owner, ToAccountInfo, ToAccountInfos, ToAccountMetas,
-    ZeroCopy,
+    Accounts, AccountsClose, AccountsExit, Bump, Owner, ToAccountInfo, ToAccountInfos,
+    ToAccountMetas, ZeroCopy,
 };
 use arrayref::array_ref;
 use solana_program::account_info::AccountInfo;
@@ -126,7 +126,11 @@ impl<'info, T: ZeroCopy + Owner> AccountLoader<'info, T> {
         }
         let data: &[u8] = &acc_info.try_borrow_data()?;
         // Discriminator must match.
+        #[cfg(feature = "deprecated-layout")]
         let disc_bytes = array_ref![data, 0, 8];
+        #[cfg(not(feature = "deprecated-layout"))]
+        let disc_bytes = array_ref![data, 2, 4];
+
         if disc_bytes != &T::discriminator() {
             return Err(ErrorCode::AccountDiscriminatorMismatch.into());
         }
@@ -150,7 +154,11 @@ impl<'info, T: ZeroCopy + Owner> AccountLoader<'info, T> {
     pub fn load(&self) -> Result<Ref<T>, ProgramError> {
         let data = self.acc_info.try_borrow_data()?;
 
+        #[cfg(feature = "deprecated-layout")]
         let disc_bytes = array_ref![data, 0, 8];
+        #[cfg(not(feature = "deprecated-layout"))]
+        let disc_bytes = array_ref![data, 2, 4];
+
         if disc_bytes != &T::discriminator() {
             return Err(ErrorCode::AccountDiscriminatorMismatch.into());
         }
@@ -170,7 +178,11 @@ impl<'info, T: ZeroCopy + Owner> AccountLoader<'info, T> {
 
         let data = self.acc_info.try_borrow_mut_data()?;
 
+        #[cfg(feature = "deprecated-layout")]
         let disc_bytes = array_ref![data, 0, 8];
+        #[cfg(not(feature = "deprecated-layout"))]
+        let disc_bytes = array_ref![data, 2, 4];
+
         if disc_bytes != &T::discriminator() {
             return Err(ErrorCode::AccountDiscriminatorMismatch.into());
         }
@@ -192,9 +204,20 @@ impl<'info, T: ZeroCopy + Owner> AccountLoader<'info, T> {
         let data = self.acc_info.try_borrow_mut_data()?;
 
         // The discriminator should be zero, since we're initializing.
+        #[cfg(feature = "deprecated-layout")]
         let mut disc_bytes = [0u8; 8];
+        #[cfg(feature = "deprecated-layout")]
         disc_bytes.copy_from_slice(&data[..8]);
+        #[cfg(feature = "deprecated-layout")]
         let discriminator = u64::from_le_bytes(disc_bytes);
+
+        #[cfg(not(feature = "deprecated-layout"))]
+        let mut disc_bytes = [0u8; 4];
+        #[cfg(not(feature = "deprecated-layout"))]
+        disc_bytes.copy_from_slice(&data[2..6]);
+        #[cfg(not(feature = "deprecated-layout"))]
+        let discriminator = u32::from_le_bytes(disc_bytes);
+
         if discriminator != 0 {
             return Err(ErrorCode::AccountDiscriminatorAlreadySet.into());
         }
@@ -226,7 +249,12 @@ impl<'info, T: ZeroCopy + Owner> AccountsExit<'info> for AccountLoader<'info, T>
     // The account *cannot* be loaded when this is called.
     fn exit(&self, _program_id: &Pubkey) -> ProgramResult {
         let mut data = self.acc_info.try_borrow_mut_data()?;
+
+        #[cfg(feature = "deprecated-layout")]
         let dst: &mut [u8] = &mut data;
+        #[cfg(not(feature = "deprecated-layout"))]
+        let dst: &mut [u8] = &mut data[2..];
+
         let mut cursor = std::io::Cursor::new(dst);
         cursor.write_all(&T::discriminator()).unwrap();
         Ok(())
@@ -261,3 +289,13 @@ impl<'info, T: ZeroCopy + Owner> ToAccountInfos<'info> for AccountLoader<'info,
         vec![self.acc_info.clone()]
     }
 }
+
+#[cfg(not(feature = "deprecated-layout"))]
+impl<'info, T> Bump for T
+where
+    T: AsRef<AccountInfo<'info>>,
+{
+    fn seed(&self) -> u8 {
+        self.as_ref().data.borrow()[1]
+    }
+}

+ 29 - 0
lang/src/accounts/loader.rs

@@ -61,8 +61,13 @@ impl<'info, T: ZeroCopy> Loader<'info, T> {
             return Err(ErrorCode::AccountOwnedByWrongProgram.into());
         }
         let data: &[u8] = &acc_info.try_borrow_data()?;
+
         // Discriminator must match.
+        #[cfg(feature = "deprecated-layout")]
         let disc_bytes = array_ref![data, 0, 8];
+        #[cfg(not(feature = "deprecated-layout"))]
+        let disc_bytes = array_ref![data, 2, 4];
+
         if disc_bytes != &T::discriminator() {
             return Err(ErrorCode::AccountDiscriminatorMismatch.into());
         }
@@ -88,7 +93,11 @@ impl<'info, T: ZeroCopy> Loader<'info, T> {
     pub fn load(&self) -> Result<Ref<T>, ProgramError> {
         let data = self.acc_info.try_borrow_data()?;
 
+        #[cfg(feature = "deprecated-layout")]
         let disc_bytes = array_ref![data, 0, 8];
+        #[cfg(not(feature = "deprecated-layout"))]
+        let disc_bytes = array_ref![data, 2, 4];
+
         if disc_bytes != &T::discriminator() {
             return Err(ErrorCode::AccountDiscriminatorMismatch.into());
         }
@@ -107,7 +116,11 @@ impl<'info, T: ZeroCopy> Loader<'info, T> {
 
         let data = self.acc_info.try_borrow_mut_data()?;
 
+        #[cfg(feature = "deprecated-layout")]
         let disc_bytes = array_ref![data, 0, 8];
+        #[cfg(not(feature = "deprecated-layout"))]
+        let disc_bytes = array_ref![data, 2, 4];
+
         if disc_bytes != &T::discriminator() {
             return Err(ErrorCode::AccountDiscriminatorMismatch.into());
         }
@@ -130,9 +143,20 @@ impl<'info, T: ZeroCopy> Loader<'info, T> {
         let data = self.acc_info.try_borrow_mut_data()?;
 
         // The discriminator should be zero, since we're initializing.
+        #[cfg(feature = "deprecated-layout")]
         let mut disc_bytes = [0u8; 8];
+        #[cfg(feature = "deprecated-layout")]
         disc_bytes.copy_from_slice(&data[..8]);
+        #[cfg(feature = "deprecated-layout")]
         let discriminator = u64::from_le_bytes(disc_bytes);
+
+        #[cfg(not(feature = "deprecated-layout"))]
+        let mut disc_bytes = [0u8; 4];
+        #[cfg(not(feature = "deprecated-layout"))]
+        disc_bytes.copy_from_slice(&data[2..6]);
+        #[cfg(not(feature = "deprecated-layout"))]
+        let discriminator = u32::from_le_bytes(disc_bytes);
+
         if discriminator != 0 {
             return Err(ErrorCode::AccountDiscriminatorAlreadySet.into());
         }
@@ -166,7 +190,12 @@ impl<'info, T: ZeroCopy> AccountsExit<'info> for Loader<'info, T> {
     // The account *cannot* be loaded when this is called.
     fn exit(&self, _program_id: &Pubkey) -> ProgramResult {
         let mut data = self.acc_info.try_borrow_mut_data()?;
+
+        #[cfg(feature = "deprecated-layout")]
         let dst: &mut [u8] = &mut data;
+        #[cfg(not(feature = "deprecated-layout"))]
+        let dst: &mut [u8] = &mut data[2..];
+
         let mut cursor = std::io::Cursor::new(dst);
         cursor.write_all(&T::discriminator()).unwrap();
         Ok(())

+ 3 - 0
lang/src/lib.rs

@@ -198,7 +198,10 @@ pub trait EventData: AnchorSerialize + Discriminator {
 
 /// 8 byte unique identifier for a type.
 pub trait Discriminator {
+    #[cfg(feature = "deprecated-layout")]
     fn discriminator() -> [u8; 8];
+    #[cfg(not(feature = "deprecated-layout"))]
+    fn discriminator() -> [u8; 4];
 }
 
 /// Bump seed for program derived addresses.

+ 1 - 0
lang/syn/src/idl/file.rs

@@ -238,6 +238,7 @@ pub fn parse(
         .collect::<Vec<IdlConst>>();
 
     Ok(Some(Idl {
+        layout_version: "0.1.0".to_string(),
         version,
         name: p.name.to_string(),
         state,

+ 3 - 0
lang/syn/src/idl/mod.rs

@@ -6,6 +6,9 @@ pub mod pda;
 
 #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
 pub struct Idl {
+    // Version of the idl protocol.
+    pub layout_version: String,
+    // Version of the program.
     pub version: String,
     pub name: String,
     #[serde(skip_serializing_if = "Vec::is_empty", default)]

+ 7 - 0
lang/syn/src/lib.rs

@@ -26,6 +26,13 @@ pub(crate) mod hash;
 pub mod idl;
 pub mod parser;
 
+// Layout indices.
+pub const LAYOUT_VERSION: u8 = 0;
+pub const LAYOUT_VERSION_INDEX: u8 = 0;
+pub const LAYOUT_BUMP_INDEX: u8 = 1;
+pub const LAYOUT_DISCRIMINATOR_INDEX: u8 = 2;
+pub const LAYOUT_UNUSED_INDEX: u8 = 6;
+
 #[derive(Debug)]
 pub struct Program {
     pub state: Option<State>,

+ 72 - 13
ts/src/coder/borsh/accounts.ts

@@ -7,11 +7,18 @@ import { Idl, IdlTypeDef } from "../../idl.js";
 import { IdlCoder } from "./idl.js";
 import { AccountsCoder } from "../index.js";
 import { accountSize } from "../common.js";
+import * as features from "../../utils/features";
+
+/**
+ * Number of bytes of the account header.
+ */
+const ACCOUNT_HEADER_SIZE = 8;
 
 /**
  * Number of bytes of the account discriminator.
  */
-export const ACCOUNT_DISCRIMINATOR_SIZE = 8;
+const ACCOUNT_DISCRIMINATOR_SIZE = 4;
+const DEPRECATED_ACCOUNT_DISCRIMINATOR_SIZE = 4;
 
 /**
  * Encodes and decodes account objects.
@@ -49,22 +56,21 @@ export class BorshAccountsCoder<A extends string = string>
     }
     const len = layout.encode(account, buffer);
     let accountData = buffer.slice(0, len);
-    let discriminator = BorshAccountsCoder.accountDiscriminator(accountName);
-    return Buffer.concat([discriminator, accountData]);
+    let header = BorshAccountHeader.encode(accountName);
+    return Buffer.concat([header, accountData]);
   }
 
   public decode<T = any>(accountName: A, data: Buffer): T {
-    // Assert the account discriminator is correct.
-    const discriminator = BorshAccountsCoder.accountDiscriminator(accountName);
-    if (discriminator.compare(data.slice(0, 8))) {
+    const expectedDiscriminator = BorshAccountHeader.discriminator(accountName);
+		const givenDisc = BorshAccountHeader.parseDiscriminator(data);
+    if (expectedDiscriminator.compare(givenDisc)) {
       throw new Error("Invalid account discriminator");
     }
     return this.decodeUnchecked(accountName, data);
   }
 
   public decodeUnchecked<T = any>(accountName: A, ix: Buffer): T {
-    // Chop off the discriminator before decoding.
-    const data = ix.slice(ACCOUNT_DISCRIMINATOR_SIZE);
+    const data = ix.slice(BorshAccountHeader.size());   // Chop off the header.
     const layout = this.accountLayouts.get(accountName);
     if (!layout) {
       throw new Error(`Unknown account: ${accountName}`);
@@ -73,9 +79,9 @@ export class BorshAccountsCoder<A extends string = string>
   }
 
   public memcmp(accountName: A, appendData?: Buffer): any {
-    const discriminator = BorshAccountsCoder.accountDiscriminator(accountName);
+    const discriminator = BorshAccountHeader.discriminator(accountName);
     return {
-      offset: 0,
+      offset: BorshAccountHeader.discriminatorOffset(),
       bytes: bs58.encode(
         appendData ? Buffer.concat([discriminator, appendData]) : discriminator
       ),
@@ -84,18 +90,71 @@ export class BorshAccountsCoder<A extends string = string>
 
   public size(idlAccount: IdlTypeDef): number {
     return (
-      ACCOUNT_DISCRIMINATOR_SIZE + (accountSize(this.idl, idlAccount) ?? 0)
+      BorshAccountHeader.size() + (accountSize(this.idl, idlAccount) ?? 0)
     );
   }
+}
+
+export class BorshAccountHeader {
+	/**
+	 * Returns the default account header for an account with the given name.
+	 */
+	public static encode(accountName: string): Buffer {
+		if (features.isSet('deprecated-layout')) {
+			return BorshAccountHeader.discriminator(accountName);
+		} else {
+			return Buffer.concat([
+				Buffer.from([0]), // Version.
+				Buffer.from([0]), // Bump.
+				BorshAccountHeader.discriminator(accountName), // Disc.
+				Buffer.from([0, 0]), // Unused.
+			]);
+		}
+	}
 
   /**
    * Calculates and returns a unique 8 byte discriminator prepended to all anchor accounts.
    *
    * @param name The name of the account to calculate the discriminator.
    */
-  public static accountDiscriminator(name: string): Buffer {
+  public static discriminator(name: string): Buffer {
+		let size: number;
+		if (features.isSet("deprecated-layout")) {
+			size = DEPRECATED_ACCOUNT_DISCRIMINATOR_SIZE;
+		} else {
+			size = ACCOUNT_DISCRIMINATOR_SIZE;
+		}
     return Buffer.from(
       sha256.digest(`account:${camelcase(name, { pascalCase: true })}`)
-    ).slice(0, ACCOUNT_DISCRIMINATOR_SIZE);
+    ).slice(0, size);
   }
+
+	/**
+	 * Returns the account data index at which the discriminator starts.
+	 */
+	public static discriminatorOffset(): number {
+		if (features.isSet("deprecated-layout")) {
+			return 0;
+		} else {
+			return 2;
+		}
+	}
+
+	/**
+	 * Returns the byte size of the account header.
+	 */
+	public static size(): number {
+		return ACCOUNT_HEADER_SIZE;
+	}
+
+	/**
+	 * Returns the discriminator from the given account data.
+	 */
+	public static parseDiscriminator(data: Buffer): Buffer {
+		if (features.isSet("deprecated-layout")) {
+			return data.slice(0, 8);
+		} else {
+			return data.slice(2, 6);
+		}
+	}
 }

+ 1 - 1
ts/src/coder/borsh/index.ts

@@ -6,7 +6,7 @@ import { BorshStateCoder } from "./state.js";
 import { Coder } from "../index.js";
 
 export { BorshInstructionCoder } from "./instruction.js";
-export { BorshAccountsCoder, ACCOUNT_DISCRIMINATOR_SIZE } from "./accounts.js";
+export { BorshAccountsCoder, BorshAccountHeader } from "./accounts.js";
 export { BorshEventCoder, eventDiscriminator } from "./event.js";
 export { BorshStateCoder, stateDiscriminator } from "./state.js";
 

+ 1 - 0
ts/src/idl.ts

@@ -3,6 +3,7 @@ import { PublicKey } from "@solana/web3.js";
 import * as borsh from "@project-serum/borsh";
 
 export type Idl = {
+	layoutVersion: string;
   version: string;
   name: string;
   instructions: IdlInstruction[];

+ 2 - 0
ts/src/spl/token.ts

@@ -19,6 +19,7 @@ export function coder(): SplTokenCoder {
  * SplToken IDL.
  */
 export type SplToken = {
+	layoutVersion: "custom",
   version: "0.1.0";
   name: "spl_token";
   instructions: [
@@ -624,6 +625,7 @@ export type SplToken = {
 };
 
 export const IDL: SplToken = {
+	layoutVersion: "custom",
   version: "0.1.0",
   name: "spl_token",
   instructions: [

+ 4 - 1
ts/src/utils/features.ts

@@ -1,4 +1,7 @@
-const _AVAILABLE_FEATURES = new Set(["anchor-deprecated-state"]);
+const _AVAILABLE_FEATURES = new Set([
+	"anchor-deprecated-state",
+	"deprecated-layout",
+]);
 
 const _FEATURES = new Map();