borsh_encoding.rs 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435
  1. // SPDX-License-Identifier: Apache-2.0
  2. use anchor_syn::idl::types::{IdlType, IdlTypeDefinition, IdlTypeDefinitionTy};
  3. use byte_slice_cast::AsByteSlice;
  4. use num_bigint::{BigInt, Sign};
  5. use num_traits::ToPrimitive;
  6. use serde::Deserialize;
  7. use std::cmp::Ordering;
/// This is the token that should be used for each function call in Solana runtime tests
///
/// Each variant models one Borsh value; [`BorshToken::encode`] defines the exact
/// byte layout written for it.
#[derive(Debug, PartialEq, Clone, Deserialize)]
pub enum BorshToken {
    /// 32 raw bytes, written with no length prefix.
    Address([u8; 32]),
    /// Signed integer. `width` is a bit width and selects the encoded size
    /// (1, 2, 4, 8, 16 or 32 bytes); decoding produces widths of `bytes * 8`.
    Int { width: u16, value: BigInt },
    /// Unsigned integer. `width` is a bit width and selects the encoded size
    /// (1, 2, 4, 8, 16 or 32 bytes); decoding produces widths of `bytes * 8`.
    Uint { width: u16, value: BigInt },
    /// Fixed-size byte sequence, written as raw bytes with no length prefix.
    FixedBytes(Vec<u8>),
    /// Dynamic byte sequence, written as a little-endian `u32` length followed by the bytes.
    Bytes(Vec<u8>),
    /// Written as a single byte, 0 or 1.
    Bool(bool),
    /// UTF-8 string, written as a little-endian `u32` byte length followed by the bytes.
    String(String),
    /// Fixed-length array: elements written back to back, no length prefix.
    FixedArray(Vec<BorshToken>),
    /// Dynamic array: a little-endian `u32` element count, then the elements.
    Array(Vec<BorshToken>),
    /// Tuple or struct: fields written back to back in order (IDL structs decode to this).
    Tuple(Vec<BorshToken>),
}
  22. impl BorshToken {
  23. /// Encode the parameter into the buffer
  24. pub fn encode(&self, buffer: &mut Vec<u8>) {
  25. match self {
  26. BorshToken::Address(data) => {
  27. buffer.extend_from_slice(data);
  28. }
  29. BorshToken::Uint { width, value } => {
  30. encode_uint(*width, value, buffer);
  31. }
  32. BorshToken::Int { width, value } => {
  33. encode_int(*width, value, buffer);
  34. }
  35. BorshToken::FixedBytes(data) => {
  36. buffer.extend_from_slice(data);
  37. }
  38. BorshToken::Bytes(data) => {
  39. let len = data.len() as u32;
  40. buffer.extend_from_slice(&len.to_le_bytes());
  41. buffer.extend_from_slice(data);
  42. }
  43. BorshToken::Bool(value) => {
  44. buffer.push(*value as u8);
  45. }
  46. BorshToken::String(data) => {
  47. let len = data.len() as u32;
  48. buffer.extend_from_slice(&len.to_le_bytes());
  49. buffer.extend_from_slice(data.as_byte_slice());
  50. }
  51. BorshToken::Tuple(data) | BorshToken::FixedArray(data) => {
  52. for item in data {
  53. item.encode(buffer);
  54. }
  55. }
  56. BorshToken::Array(arr) => {
  57. let len = arr.len() as u32;
  58. buffer.extend_from_slice(&len.to_le_bytes());
  59. for item in arr {
  60. item.encode(buffer);
  61. }
  62. }
  63. }
  64. }
  65. pub fn into_string(self) -> Option<String> {
  66. match self {
  67. BorshToken::String(value) => Some(value),
  68. _ => None,
  69. }
  70. }
  71. pub fn into_array(self) -> Option<Vec<BorshToken>> {
  72. match self {
  73. BorshToken::Array(value) => Some(value),
  74. _ => None,
  75. }
  76. }
  77. pub fn into_fixed_bytes(self) -> Option<Vec<u8>> {
  78. match self {
  79. BorshToken::FixedBytes(value) => Some(value),
  80. BorshToken::FixedArray(vec) => {
  81. let mut response: Vec<u8> = Vec::with_capacity(vec.len());
  82. for elem in vec {
  83. match elem {
  84. BorshToken::Uint { width, value } => {
  85. assert_eq!(width, 8);
  86. response.push(value.to_u8().unwrap());
  87. }
  88. _ => unreachable!("Array cannot be converted to fixed bytes"),
  89. }
  90. }
  91. Some(response)
  92. }
  93. BorshToken::Address(value) => Some(value.to_vec()),
  94. _ => None,
  95. }
  96. }
  97. pub fn into_bytes(self) -> Option<Vec<u8>> {
  98. match self {
  99. BorshToken::Bytes(value) => Some(value),
  100. _ => None,
  101. }
  102. }
  103. pub fn into_bigint(self) -> Option<BigInt> {
  104. match self {
  105. BorshToken::Uint { value, .. } => Some(value),
  106. BorshToken::Int { value, .. } => Some(value),
  107. _ => None,
  108. }
  109. }
  110. pub fn unwrap_tuple(self) -> Vec<BorshToken> {
  111. match self {
  112. BorshToken::Tuple(vec) => vec,
  113. _ => panic!("This is not a tuple"),
  114. }
  115. }
  116. pub fn uint8_fixed_array(vec: Vec<u8>) -> BorshToken {
  117. let mut array: Vec<BorshToken> = Vec::with_capacity(vec.len());
  118. for item in &vec {
  119. array.push(BorshToken::Uint {
  120. width: 8,
  121. value: BigInt::from(*item),
  122. });
  123. }
  124. BorshToken::FixedArray(array)
  125. }
  126. }
  127. /// Encode a signed integer
  128. fn encode_int(width: u16, value: &BigInt, buffer: &mut Vec<u8>) {
  129. match width {
  130. 1..=8 => {
  131. let val = value.to_i8().unwrap();
  132. buffer.extend_from_slice(&val.to_le_bytes());
  133. }
  134. 9..=16 => {
  135. let val = value.to_i16().unwrap();
  136. buffer.extend_from_slice(&val.to_le_bytes());
  137. }
  138. 17..=32 => {
  139. let val = value.to_i32().unwrap();
  140. buffer.extend_from_slice(&val.to_le_bytes());
  141. }
  142. 33..=64 => {
  143. let val = value.to_i64().unwrap();
  144. buffer.extend_from_slice(&val.to_le_bytes());
  145. }
  146. 65..=128 => {
  147. let val = value.to_i128().unwrap();
  148. buffer.extend_from_slice(&val.to_le_bytes());
  149. }
  150. 129..=256 => {
  151. let mut val = value.to_signed_bytes_le();
  152. let byte_width = 32;
  153. match val.len().cmp(&byte_width) {
  154. Ordering::Greater => {
  155. while val.len() > byte_width {
  156. val.pop();
  157. }
  158. }
  159. Ordering::Less => {
  160. if value.sign() == Sign::Minus {
  161. val.extend(vec![255; byte_width - val.len()]);
  162. } else {
  163. val.extend(vec![0; byte_width - val.len()]);
  164. }
  165. }
  166. Ordering::Equal => (),
  167. }
  168. buffer.extend_from_slice(&val);
  169. }
  170. _ => unreachable!("bit width not supported"),
  171. }
  172. }
  173. /// Encode an unsigned integer
  174. fn encode_uint(width: u16, value: &BigInt, buffer: &mut Vec<u8>) {
  175. match width {
  176. 1..=8 => {
  177. let val = value.to_u8().unwrap();
  178. buffer.push(val);
  179. }
  180. 9..=16 => {
  181. let val = value.to_u16().unwrap();
  182. buffer.extend_from_slice(&val.to_le_bytes());
  183. }
  184. 17..=32 => {
  185. let val = value.to_u32().unwrap();
  186. buffer.extend_from_slice(&val.to_le_bytes());
  187. }
  188. 33..=64 => {
  189. let val = value.to_u64().unwrap();
  190. buffer.extend_from_slice(&val.to_le_bytes());
  191. }
  192. 65..=128 => {
  193. let val = value.to_u128().unwrap();
  194. buffer.extend_from_slice(&val.to_le_bytes());
  195. }
  196. 129..=256 => {
  197. let mut val = value.to_signed_bytes_le();
  198. let bytes_width = 32;
  199. match val.len().cmp(&bytes_width) {
  200. Ordering::Greater => {
  201. while val.len() > bytes_width {
  202. val.pop();
  203. }
  204. }
  205. Ordering::Less => {
  206. val.extend(vec![0; bytes_width - val.len()]);
  207. }
  208. Ordering::Equal => (),
  209. }
  210. buffer.extend_from_slice(&val);
  211. }
  212. _ => unreachable!("bit width not supported"),
  213. }
  214. }
  215. /// Encode the arguments of a function
  216. pub fn encode_arguments(args: &[BorshToken]) -> Vec<u8> {
  217. let mut encoded: Vec<u8> = Vec::new();
  218. for item in args {
  219. item.encode(&mut encoded);
  220. }
  221. encoded
  222. }
  223. /// Decode a parameter at a given offset
  224. pub fn decode_at_offset(
  225. data: &[u8],
  226. offset: &mut usize,
  227. ty: &IdlType,
  228. custom_types: &[IdlTypeDefinition],
  229. ) -> BorshToken {
  230. match ty {
  231. IdlType::PublicKey => {
  232. let read = &data[*offset..(*offset + 32)];
  233. (*offset) += 32;
  234. BorshToken::Address(<[u8; 32]>::try_from(read).unwrap())
  235. }
  236. IdlType::U8
  237. | IdlType::U16
  238. | IdlType::U32
  239. | IdlType::U64
  240. | IdlType::U128
  241. | IdlType::U256 => {
  242. let decoding_width = integer_byte_width(ty);
  243. let bigint =
  244. BigInt::from_bytes_le(Sign::Plus, &data[*offset..(*offset + decoding_width)]);
  245. (*offset) += decoding_width;
  246. BorshToken::Uint {
  247. width: (decoding_width * 8) as u16,
  248. value: bigint,
  249. }
  250. }
  251. IdlType::I8
  252. | IdlType::I16
  253. | IdlType::I32
  254. | IdlType::I64
  255. | IdlType::I128
  256. | IdlType::I256 => {
  257. let decoding_width = integer_byte_width(ty);
  258. let bigint = BigInt::from_signed_bytes_le(&data[*offset..(*offset + decoding_width)]);
  259. (*offset) += decoding_width;
  260. BorshToken::Int {
  261. width: (decoding_width * 8) as u16,
  262. value: bigint,
  263. }
  264. }
  265. IdlType::Bool => {
  266. let val = data[*offset] == 1;
  267. (*offset) += 1;
  268. BorshToken::Bool(val)
  269. }
  270. IdlType::String => {
  271. let mut int_data: [u8; 4] = Default::default();
  272. int_data.copy_from_slice(&data[*offset..(*offset + 4)]);
  273. let len = u32::from_le_bytes(int_data) as usize;
  274. (*offset) += 4;
  275. let read_string = std::str::from_utf8(&data[*offset..(*offset + len)]).unwrap();
  276. (*offset) += len;
  277. BorshToken::String(read_string.to_string())
  278. }
  279. IdlType::Array(ty, len) => {
  280. let mut read_items: Vec<BorshToken> = Vec::with_capacity(*len);
  281. for _ in 0..*len {
  282. read_items.push(decode_at_offset(data, offset, ty, custom_types));
  283. }
  284. BorshToken::FixedArray(read_items)
  285. }
  286. IdlType::Vec(ty) => {
  287. let mut int_data: [u8; 4] = Default::default();
  288. int_data.copy_from_slice(&data[*offset..(*offset + 4)]);
  289. let len = u32::from_le_bytes(int_data);
  290. (*offset) += 4;
  291. let mut read_items: Vec<BorshToken> = Vec::with_capacity(len as usize);
  292. for _ in 0..len {
  293. read_items.push(decode_at_offset(data, offset, ty, custom_types));
  294. }
  295. BorshToken::Array(read_items)
  296. }
  297. IdlType::Defined(value) => {
  298. let current_ty = custom_types
  299. .iter()
  300. .find(|item| &item.name == value)
  301. .unwrap();
  302. match &current_ty.ty {
  303. IdlTypeDefinitionTy::Enum { .. } => {
  304. let value = data[*offset];
  305. (*offset) += 1;
  306. BorshToken::Uint {
  307. width: 8,
  308. value: BigInt::from(value),
  309. }
  310. }
  311. IdlTypeDefinitionTy::Struct { fields } => {
  312. let mut read_items: Vec<BorshToken> = Vec::with_capacity(fields.len());
  313. for item in fields {
  314. read_items.push(decode_at_offset(data, offset, &item.ty, custom_types));
  315. }
  316. BorshToken::Tuple(read_items)
  317. }
  318. IdlTypeDefinitionTy::Alias { value } => {
  319. decode_at_offset(data, offset, value, custom_types)
  320. }
  321. }
  322. }
  323. IdlType::Bytes => {
  324. let mut int_data: [u8; 4] = Default::default();
  325. int_data.copy_from_slice(&data[*offset..(*offset + 4)]);
  326. let len = u32::from_le_bytes(int_data) as usize;
  327. (*offset) += 4;
  328. let read_data = &data[*offset..(*offset + len)];
  329. (*offset) += len;
  330. BorshToken::Bytes(read_data.to_vec())
  331. }
  332. IdlType::Option(_)
  333. | IdlType::F32
  334. | IdlType::F64
  335. | IdlType::Generic(..)
  336. | IdlType::DefinedWithTypeArgs { .. }
  337. | IdlType::GenericLenArray(..) => {
  338. unreachable!("Type not available in Solidity")
  339. }
  340. }
  341. }
  342. fn integer_byte_width(ty: &IdlType) -> usize {
  343. match ty {
  344. IdlType::U8 | IdlType::I8 => 1,
  345. IdlType::U16 | IdlType::I16 => 2,
  346. IdlType::U32 | IdlType::I32 => 4,
  347. IdlType::U64 | IdlType::I64 => 8,
  348. IdlType::U128 | IdlType::I128 => 16,
  349. IdlType::U256 | IdlType::I256 => 32,
  350. _ => unreachable!("Not an integer"),
  351. }
  352. }
  353. pub trait VisitorMut {
  354. fn visit_address(&mut self, _a: &mut [u8; 32]) {}
  355. fn visit_int(&mut self, _width: &mut u16, _value: &mut BigInt) {}
  356. fn visit_uint(&mut self, _width: &mut u16, _value: &mut BigInt) {}
  357. fn visit_fixed_bytes(&mut self, _v: &mut Vec<u8>) {}
  358. fn visit_bytes(&mut self, _v: &mut Vec<u8>) {}
  359. fn visit_bool(&mut self, _b: &mut bool) {}
  360. fn visit_string(&mut self, _s: &mut String) {}
  361. fn visit_fixed_array(&mut self, v: &mut Vec<BorshToken>) {
  362. visit_fixed_array(self, v)
  363. }
  364. fn visit_array(&mut self, v: &mut Vec<BorshToken>) {
  365. visit_array(self, v)
  366. }
  367. fn visit_tuple(&mut self, v: &mut Vec<BorshToken>) {
  368. visit_tuple(self, v)
  369. }
  370. }
  371. pub fn visit_mut<T: VisitorMut + ?Sized>(visitor: &mut T, token: &mut BorshToken) {
  372. match token {
  373. BorshToken::Address(a) => visitor.visit_address(a),
  374. BorshToken::Int { width, value } => visitor.visit_int(width, value),
  375. BorshToken::Uint { width, value } => visitor.visit_uint(width, value),
  376. BorshToken::FixedBytes(v) => visitor.visit_fixed_bytes(v),
  377. BorshToken::Bytes(v) => visitor.visit_bytes(v),
  378. BorshToken::Bool(b) => visitor.visit_bool(b),
  379. BorshToken::String(s) => visitor.visit_string(s),
  380. BorshToken::FixedArray(v) => visitor.visit_fixed_array(v),
  381. BorshToken::Array(v) => visitor.visit_array(v),
  382. BorshToken::Tuple(v) => visitor.visit_tuple(v),
  383. }
  384. }
  385. #[allow(clippy::ptr_arg)]
  386. pub fn visit_fixed_array<T: VisitorMut + ?Sized>(visitor: &mut T, v: &mut Vec<BorshToken>) {
  387. for token in v.iter_mut() {
  388. visit_mut(visitor, token);
  389. }
  390. }
  391. #[allow(clippy::ptr_arg)]
  392. pub fn visit_array<T: VisitorMut + ?Sized>(visitor: &mut T, v: &mut Vec<BorshToken>) {
  393. for token in v.iter_mut() {
  394. visit_mut(visitor, token);
  395. }
  396. }
  397. #[allow(clippy::ptr_arg)]
  398. pub fn visit_tuple<T: VisitorMut + ?Sized>(visitor: &mut T, v: &mut Vec<BorshToken>) {
  399. for token in v.iter_mut() {
  400. visit_mut(visitor, token);
  401. }
  402. }