Browse Source

lang: Hash at compile time (#63)

Armani Ferrante 4 years ago
parent
commit
170e6f18d4

+ 1 - 1
lang/attribute/account/Cargo.toml

@@ -15,4 +15,4 @@ proc-macro2 = "1.0"
 quote = "1.0"
 syn = { version = "=1.0.57", features = ["full"] }
 anyhow = "1.0.32"
-anchor-syn = { path = "../../syn", version = "0.1.0" }
+anchor-syn = { path = "../../syn", version = "0.1.0", features = ["hash"] }

+ 17 - 27
lang/attribute/account/src/lib.rs

@@ -27,29 +27,26 @@ pub fn account(
     let account_strct = parse_macro_input!(input as syn::ItemStruct);
     let account_name = &account_strct.ident;
 
-    // Namespace the discriminator to prevent collisions.
-    let discriminator_preimage = {
-        if namespace == "" {
-            format!("account:{}", account_name.to_string())
-        } else {
-            format!("{}:{}", namespace, account_name.to_string())
-        }
+    let discriminator: proc_macro2::TokenStream = {
+        // Namespace the discriminator to prevent collisions.
+        let discriminator_preimage = {
+            if namespace == "" {
+                format!("account:{}", account_name.to_string())
+            } else {
+                format!("{}:{}", namespace, account_name.to_string())
+            }
+        };
+        let mut discriminator = [0u8; 8];
+        discriminator.copy_from_slice(
+            &anchor_syn::hash::hash(discriminator_preimage.as_bytes()).to_bytes()[..8],
+        );
+        format!("{:?}", discriminator).parse().unwrap()
     };
 
     let coder = quote! {
         impl anchor_lang::AccountSerialize for #account_name {
             fn try_serialize<W: std::io::Write>(&self, writer: &mut W) -> std::result::Result<(), ProgramError> {
-                // TODO: we shouldn't have to hash at runtime. However, rust
-                //       is not happy when trying to include solana-sdk from
-                //       the proc-macro crate.
-                let mut discriminator = [0u8; 8];
-                discriminator.copy_from_slice(
-                    &anchor_lang::solana_program::hash::hash(
-                        #discriminator_preimage.as_bytes(),
-                    ).to_bytes()[..8],
-                );
-
-                writer.write_all(&discriminator).map_err(|_| ProgramError::InvalidAccountData)?;
+                writer.write_all(&#discriminator).map_err(|_| ProgramError::InvalidAccountData)?;
                 AnchorSerialize::serialize(
                     self,
                     writer
@@ -62,18 +59,11 @@ pub fn account(
         impl anchor_lang::AccountDeserialize for #account_name {
 
             fn try_deserialize(buf: &mut &[u8]) -> std::result::Result<Self, ProgramError> {
-                let mut discriminator = [0u8; 8];
-                discriminator.copy_from_slice(
-                    &anchor_lang::solana_program::hash::hash(
-                        #discriminator_preimage.as_bytes(),
-                    ).to_bytes()[..8],
-                );
-
-                if buf.len() < discriminator.len() {
+                 if buf.len() < #discriminator.len() {
                     return Err(ProgramError::AccountDataTooSmall);
                 }
                 let given_disc = &buf[..8];
-                if &discriminator != given_disc {
+                if &#discriminator != given_disc {
                     return Err(ProgramError::InvalidInstructionData);
                 }
                 Self::try_deserialize_unchecked(buf)

+ 4 - 0
lang/syn/Cargo.toml

@@ -9,6 +9,7 @@ edition = "2018"
 
 [features]
 idl = []
+hash = []
 default = []
 
 [dependencies]
@@ -19,3 +20,6 @@ anyhow = "1.0.32"
 heck = "0.3.1"
 serde = { version = "1.0.118", features = ["derive"] }
 serde_json = "1.0"
+sha2 = "0.9.2"
+thiserror = "1.0"
+bs58 = "0.3.1"

+ 139 - 0
lang/syn/src/hash.rs

@@ -0,0 +1,139 @@
+// Utility hashing module copied from `solana_program::program::hash`, since we
+// can't import solana_program for compile time hashing for some reason.
+
+use serde::{Deserialize, Serialize};
+use sha2::{Digest, Sha256};
+use std::{convert::TryFrom, fmt, mem, str::FromStr};
+use thiserror::Error;
+
+pub const HASH_BYTES: usize = 32;
+#[derive(Serialize, Deserialize, Clone, Copy, Default, Eq, PartialEq, Ord, PartialOrd, Hash)]
+#[repr(transparent)]
+pub struct Hash(pub [u8; HASH_BYTES]);
+
+#[derive(Clone, Default)]
+pub struct Hasher {
+    hasher: Sha256,
+}
+
+impl Hasher {
+    pub fn hash(&mut self, val: &[u8]) {
+        self.hasher.update(val);
+    }
+    pub fn hashv(&mut self, vals: &[&[u8]]) {
+        for val in vals {
+            self.hash(val);
+        }
+    }
+    pub fn result(self) -> Hash {
+        // At the time of this writing, the sha2 library is stuck on an old version
+        // of generic_array (0.9.0). Decouple ourselves with a clone to our version.
+        Hash(<[u8; HASH_BYTES]>::try_from(self.hasher.finalize().as_slice()).unwrap())
+    }
+}
+
+impl AsRef<[u8]> for Hash {
+    fn as_ref(&self) -> &[u8] {
+        &self.0[..]
+    }
+}
+
+impl fmt::Debug for Hash {
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        write!(f, "{}", bs58::encode(self.0).into_string())
+    }
+}
+
+impl fmt::Display for Hash {
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        write!(f, "{}", bs58::encode(self.0).into_string())
+    }
+}
+
+#[derive(Debug, Clone, PartialEq, Eq, Error)]
+pub enum ParseHashError {
+    #[error("string decoded to wrong size for hash")]
+    WrongSize,
+    #[error("failed to decoded string to hash")]
+    Invalid,
+}
+
+impl FromStr for Hash {
+    type Err = ParseHashError;
+
+    fn from_str(s: &str) -> Result<Self, Self::Err> {
+        let bytes = bs58::decode(s)
+            .into_vec()
+            .map_err(|_| ParseHashError::Invalid)?;
+        if bytes.len() != mem::size_of::<Hash>() {
+            Err(ParseHashError::WrongSize)
+        } else {
+            Ok(Hash::new(&bytes))
+        }
+    }
+}
+
+impl Hash {
+    pub fn new(hash_slice: &[u8]) -> Self {
+        Hash(<[u8; HASH_BYTES]>::try_from(hash_slice).unwrap())
+    }
+
+    pub const fn new_from_array(hash_array: [u8; HASH_BYTES]) -> Self {
+        Self(hash_array)
+    }
+
+    /// unique Hash for tests and benchmarks.
+    pub fn new_unique() -> Self {
+        use std::sync::atomic::{AtomicU64, Ordering};
+        static I: AtomicU64 = AtomicU64::new(1);
+
+        let mut b = [0u8; HASH_BYTES];
+        let i = I.fetch_add(1, Ordering::Relaxed);
+        b[0..8].copy_from_slice(&i.to_le_bytes());
+        Self::new(&b)
+    }
+
+    pub fn to_bytes(self) -> [u8; HASH_BYTES] {
+        self.0
+    }
+}
+
+/// Return a Sha256 hash for the given data.
+pub fn hashv(vals: &[&[u8]]) -> Hash {
+    // Perform the calculation inline, calling this from within a program is
+    // not supported
+    #[cfg(not(target_arch = "bpf"))]
+    {
+        let mut hasher = Hasher::default();
+        hasher.hashv(vals);
+        hasher.result()
+    }
+    // Call via a system call to perform the calculation
+    #[cfg(target_arch = "bpf")]
+    {
+        extern "C" {
+            fn sol_sha256(vals: *const u8, val_len: u64, hash_result: *mut u8) -> u64;
+        };
+        let mut hash_result = [0; HASH_BYTES];
+        unsafe {
+            sol_sha256(
+                vals as *const _ as *const u8,
+                vals.len() as u64,
+                &mut hash_result as *mut _ as *mut u8,
+            );
+        }
+        Hash::new_from_array(hash_result)
+    }
+}
+
+/// Return a Sha256 hash for the given data.
+pub fn hash(val: &[u8]) -> Hash {
+    hashv(&[val])
+}
+
+/// Return the hash of the given hash extended with the given value.
+pub fn extend_and_hash(id: &Hash, val: &[u8]) -> Hash {
+    let mut hash_data = id.as_ref().to_vec();
+    hash_data.extend_from_slice(val);
+    hash(&hash_data)
+}

+ 2 - 0
lang/syn/src/lib.rs

@@ -9,6 +9,8 @@ use quote::quote;
 use std::collections::HashMap;
 
 pub mod codegen;
+#[cfg(feature = "hash")]
+pub mod hash;
 #[cfg(feature = "idl")]
 pub mod idl;
 pub mod parser;