borsh_encoding.rs 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371
  1. // SPDX-License-Identifier: Apache-2.0
  2. use anchor_syn::idl::{IdlType, IdlTypeDefinition, IdlTypeDefinitionTy};
  3. use byte_slice_cast::AsByteSlice;
  4. use num_bigint::{BigInt, Sign};
  5. use num_traits::ToPrimitive;
  6. use std::cmp::Ordering;
  7. /// This is the token that should be used for each function call in Solana runtime tests
  8. #[derive(Debug, PartialEq, Clone)]
  9. pub enum BorshToken {
  10. Address([u8; 32]),
  11. Int { width: u16, value: BigInt },
  12. Uint { width: u16, value: BigInt },
  13. FixedBytes(Vec<u8>),
  14. Bytes(Vec<u8>),
  15. Bool(bool),
  16. String(String),
  17. FixedArray(Vec<BorshToken>),
  18. Array(Vec<BorshToken>),
  19. Tuple(Vec<BorshToken>),
  20. }
  21. impl BorshToken {
  22. /// Encode the parameter into the buffer
  23. pub fn encode(&self, buffer: &mut Vec<u8>) {
  24. match self {
  25. BorshToken::Address(data) => {
  26. buffer.extend_from_slice(data);
  27. }
  28. BorshToken::Uint { width, value } => {
  29. encode_uint(*width, value, buffer);
  30. }
  31. BorshToken::Int { width, value } => {
  32. encode_int(*width, value, buffer);
  33. }
  34. BorshToken::FixedBytes(data) => {
  35. buffer.extend_from_slice(data);
  36. }
  37. BorshToken::Bytes(data) => {
  38. let len = data.len() as u32;
  39. buffer.extend_from_slice(&len.to_le_bytes());
  40. buffer.extend_from_slice(data);
  41. }
  42. BorshToken::Bool(value) => {
  43. buffer.push(*value as u8);
  44. }
  45. BorshToken::String(data) => {
  46. let len = data.len() as u32;
  47. buffer.extend_from_slice(&len.to_le_bytes());
  48. buffer.extend_from_slice(data.as_byte_slice());
  49. }
  50. BorshToken::Tuple(data) | BorshToken::FixedArray(data) => {
  51. for item in data {
  52. item.encode(buffer);
  53. }
  54. }
  55. BorshToken::Array(arr) => {
  56. let len = arr.len() as u32;
  57. buffer.extend_from_slice(&len.to_le_bytes());
  58. for item in arr {
  59. item.encode(buffer);
  60. }
  61. }
  62. }
  63. }
  64. pub fn into_string(self) -> Option<String> {
  65. match self {
  66. BorshToken::String(value) => Some(value),
  67. _ => None,
  68. }
  69. }
  70. pub fn into_array(self) -> Option<Vec<BorshToken>> {
  71. match self {
  72. BorshToken::Array(value) => Some(value),
  73. _ => None,
  74. }
  75. }
  76. pub fn into_fixed_bytes(self) -> Option<Vec<u8>> {
  77. match self {
  78. BorshToken::FixedBytes(value) => Some(value),
  79. BorshToken::FixedArray(vec) => {
  80. let mut response: Vec<u8> = Vec::with_capacity(vec.len());
  81. for elem in vec {
  82. match elem {
  83. BorshToken::Uint { width, value } => {
  84. assert_eq!(width, 8);
  85. response.push(value.to_u8().unwrap());
  86. }
  87. _ => unreachable!("Array cannot be converted to fixed bytes"),
  88. }
  89. }
  90. Some(response)
  91. }
  92. BorshToken::Address(value) => Some(value.to_vec()),
  93. _ => None,
  94. }
  95. }
  96. pub fn into_bytes(self) -> Option<Vec<u8>> {
  97. match self {
  98. BorshToken::Bytes(value) => Some(value),
  99. _ => None,
  100. }
  101. }
  102. pub fn into_bigint(self) -> Option<BigInt> {
  103. match self {
  104. BorshToken::Uint { value, .. } => Some(value),
  105. BorshToken::Int { value, .. } => Some(value),
  106. _ => None,
  107. }
  108. }
  109. pub fn unwrap_tuple(self) -> Vec<BorshToken> {
  110. match self {
  111. BorshToken::Tuple(vec) => vec,
  112. _ => panic!("This is not a tuple"),
  113. }
  114. }
  115. pub fn uint8_fixed_array(vec: Vec<u8>) -> BorshToken {
  116. let mut array: Vec<BorshToken> = Vec::with_capacity(vec.len());
  117. for item in &vec {
  118. array.push(BorshToken::Uint {
  119. width: 8,
  120. value: BigInt::from(*item),
  121. });
  122. }
  123. BorshToken::FixedArray(array)
  124. }
  125. }
  126. /// Encode a signed integer
  127. fn encode_int(width: u16, value: &BigInt, buffer: &mut Vec<u8>) {
  128. match width {
  129. 1..=8 => {
  130. let val = value.to_i8().unwrap();
  131. buffer.extend_from_slice(&val.to_le_bytes());
  132. }
  133. 9..=16 => {
  134. let val = value.to_i16().unwrap();
  135. buffer.extend_from_slice(&val.to_le_bytes());
  136. }
  137. 17..=32 => {
  138. let val = value.to_i32().unwrap();
  139. buffer.extend_from_slice(&val.to_le_bytes());
  140. }
  141. 33..=64 => {
  142. let val = value.to_i64().unwrap();
  143. buffer.extend_from_slice(&val.to_le_bytes());
  144. }
  145. 65..=128 => {
  146. let val = value.to_i128().unwrap();
  147. buffer.extend_from_slice(&val.to_le_bytes());
  148. }
  149. 129..=256 => {
  150. let mut val = value.to_signed_bytes_le();
  151. let byte_width = 32;
  152. match val.len().cmp(&byte_width) {
  153. Ordering::Greater => {
  154. while val.len() > byte_width {
  155. val.pop();
  156. }
  157. }
  158. Ordering::Less => {
  159. if value.sign() == Sign::Minus {
  160. val.extend(vec![255; byte_width - val.len()]);
  161. } else {
  162. val.extend(vec![0; byte_width - val.len()]);
  163. }
  164. }
  165. Ordering::Equal => (),
  166. }
  167. buffer.extend_from_slice(&val);
  168. }
  169. _ => unreachable!("bit width not supported"),
  170. }
  171. }
  172. /// Encode an unsigned integer
  173. fn encode_uint(width: u16, value: &BigInt, buffer: &mut Vec<u8>) {
  174. match width {
  175. 1..=8 => {
  176. let val = value.to_u8().unwrap();
  177. buffer.push(val);
  178. }
  179. 9..=16 => {
  180. let val = value.to_u16().unwrap();
  181. buffer.extend_from_slice(&val.to_le_bytes());
  182. }
  183. 17..=32 => {
  184. let val = value.to_u32().unwrap();
  185. buffer.extend_from_slice(&val.to_le_bytes());
  186. }
  187. 33..=64 => {
  188. let val = value.to_u64().unwrap();
  189. buffer.extend_from_slice(&val.to_le_bytes());
  190. }
  191. 65..=128 => {
  192. let val = value.to_u128().unwrap();
  193. buffer.extend_from_slice(&val.to_le_bytes());
  194. }
  195. 129..=256 => {
  196. let mut val = value.to_signed_bytes_le();
  197. let bytes_width = 32;
  198. match val.len().cmp(&bytes_width) {
  199. Ordering::Greater => {
  200. while val.len() > bytes_width {
  201. val.pop();
  202. }
  203. }
  204. Ordering::Less => {
  205. val.extend(vec![0; bytes_width - val.len()]);
  206. }
  207. Ordering::Equal => (),
  208. }
  209. buffer.extend_from_slice(&val);
  210. }
  211. _ => unreachable!("bit width not supported"),
  212. }
  213. }
  214. /// Encode the arguments of a function
  215. pub fn encode_arguments(args: &[BorshToken]) -> Vec<u8> {
  216. let mut encoded: Vec<u8> = Vec::new();
  217. for item in args {
  218. item.encode(&mut encoded);
  219. }
  220. encoded
  221. }
  222. /// Decode a parameter at a given offset
  223. pub fn decode_at_offset(
  224. data: &[u8],
  225. offset: &mut usize,
  226. ty: &IdlType,
  227. custom_types: &[IdlTypeDefinition],
  228. ) -> BorshToken {
  229. match ty {
  230. IdlType::PublicKey => {
  231. let read = &data[*offset..(*offset + 32)];
  232. (*offset) += 32;
  233. BorshToken::Address(<[u8; 32]>::try_from(read).unwrap())
  234. }
  235. IdlType::U8
  236. | IdlType::U16
  237. | IdlType::U32
  238. | IdlType::U64
  239. | IdlType::U128
  240. | IdlType::U256 => {
  241. let decoding_width = integer_byte_width(ty);
  242. let bigint =
  243. BigInt::from_bytes_le(Sign::Plus, &data[*offset..(*offset + decoding_width)]);
  244. (*offset) += decoding_width;
  245. BorshToken::Uint {
  246. width: (decoding_width * 8) as u16,
  247. value: bigint,
  248. }
  249. }
  250. IdlType::I8
  251. | IdlType::I16
  252. | IdlType::I32
  253. | IdlType::I64
  254. | IdlType::I128
  255. | IdlType::I256 => {
  256. let decoding_width = integer_byte_width(ty);
  257. let bigint = BigInt::from_signed_bytes_le(&data[*offset..(*offset + decoding_width)]);
  258. (*offset) += decoding_width;
  259. BorshToken::Int {
  260. width: (decoding_width * 8) as u16,
  261. value: bigint,
  262. }
  263. }
  264. IdlType::Bool => {
  265. let val = data[*offset] == 1;
  266. (*offset) += 1;
  267. BorshToken::Bool(val)
  268. }
  269. IdlType::String => {
  270. let mut int_data: [u8; 4] = Default::default();
  271. int_data.copy_from_slice(&data[*offset..(*offset + 4)]);
  272. let len = u32::from_le_bytes(int_data) as usize;
  273. (*offset) += 4;
  274. let read_string = std::str::from_utf8(&data[*offset..(*offset + len)]).unwrap();
  275. (*offset) += len;
  276. BorshToken::String(read_string.to_string())
  277. }
  278. IdlType::Array(ty, len) => {
  279. let mut read_items: Vec<BorshToken> = Vec::with_capacity(*len);
  280. for _ in 0..*len {
  281. read_items.push(decode_at_offset(data, offset, ty, custom_types));
  282. }
  283. BorshToken::FixedArray(read_items)
  284. }
  285. IdlType::Vec(ty) => {
  286. let mut int_data: [u8; 4] = Default::default();
  287. int_data.copy_from_slice(&data[*offset..(*offset + 4)]);
  288. let len = u32::from_le_bytes(int_data);
  289. (*offset) += 4;
  290. let mut read_items: Vec<BorshToken> = Vec::with_capacity(len as usize);
  291. for _ in 0..len {
  292. read_items.push(decode_at_offset(data, offset, ty, custom_types));
  293. }
  294. BorshToken::Array(read_items)
  295. }
  296. IdlType::Defined(value) => {
  297. let current_ty = custom_types
  298. .iter()
  299. .find(|item| &item.name == value)
  300. .unwrap();
  301. match &current_ty.ty {
  302. IdlTypeDefinitionTy::Enum { .. } => {
  303. let value = data[*offset];
  304. (*offset) += 1;
  305. BorshToken::Uint {
  306. width: 8,
  307. value: BigInt::from(value),
  308. }
  309. }
  310. IdlTypeDefinitionTy::Struct { fields } => {
  311. let mut read_items: Vec<BorshToken> = Vec::with_capacity(fields.len());
  312. for item in fields {
  313. read_items.push(decode_at_offset(data, offset, &item.ty, custom_types));
  314. }
  315. BorshToken::Tuple(read_items)
  316. }
  317. }
  318. }
  319. IdlType::Bytes => {
  320. let mut int_data: [u8; 4] = Default::default();
  321. int_data.copy_from_slice(&data[*offset..(*offset + 4)]);
  322. let len = u32::from_le_bytes(int_data) as usize;
  323. (*offset) += 4;
  324. let read_data = &data[*offset..(*offset + len)];
  325. (*offset) += len;
  326. BorshToken::Bytes(read_data.to_vec())
  327. }
  328. IdlType::Option(_) | IdlType::F32 | IdlType::F64 => {
  329. unreachable!("Type not available in Solidity")
  330. }
  331. }
  332. }
  333. fn integer_byte_width(ty: &IdlType) -> usize {
  334. match ty {
  335. IdlType::U8 | IdlType::I8 => 1,
  336. IdlType::U16 | IdlType::I16 => 2,
  337. IdlType::U32 | IdlType::I32 => 4,
  338. IdlType::U64 | IdlType::I64 => 8,
  339. IdlType::U128 | IdlType::I128 => 16,
  340. IdlType::U256 | IdlType::I256 => 32,
  341. _ => unreachable!("Not an integer"),
  342. }
  343. }