# additional_experiments.py

# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
# Source for "Build a Large Language Model From Scratch"
#   - https://www.manning.com/books/build-a-large-language-model-from-scratch
# Code: https://github.com/rasbt/LLMs-from-scratch

import argparse
import math
import os
from pathlib import Path
import time
import urllib.request
import zipfile

import pandas as pd
import tiktoken
import torch
from torch.utils.data import DataLoader
from torch.utils.data import Dataset

from gpt_download import download_and_load_gpt2
from previous_chapters import GPTModel, load_weights_into_gpt
# If the `previous_chapters.py` file is not available locally,
# you can import it from the `llms-from-scratch` PyPI package.
# For details, see: https://github.com/rasbt/LLMs-from-scratch/tree/main/pkg
# E.g.,
# from llms_from_scratch.ch04 import GPTModel
# from llms_from_scratch.ch05 import download_and_load_gpt2, load_weights_into_gpt


class LoRALayer(torch.nn.Module):
    def __init__(self, in_dim, out_dim, rank, alpha):
        super().__init__()
        self.A = torch.nn.Parameter(torch.empty(in_dim, rank))
        torch.nn.init.kaiming_uniform_(self.A, a=math.sqrt(5))
        self.B = torch.nn.Parameter(torch.zeros(rank, out_dim))
        self.alpha = alpha

    def forward(self, x):
        x = self.alpha * (x @ self.A @ self.B)
        return x
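
# The LoRA update is the low-rank product alpha * (x @ A @ B): A is initialized like a
# regular torch.nn.Linear weight (kaiming_uniform_ with a=sqrt(5)) and B starts at zero,
# so the update is zero before training begins. A quick shape check (illustrative
# values only; not executed by this script):
#
#   lora = LoRALayer(in_dim=768, out_dim=768, rank=8, alpha=8)
#   x = torch.randn(2, 6, 768)    # (batch_size, num_tokens, emb_dim)
#   print(lora(x).shape)          # torch.Size([2, 6, 768])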


class LinearWithLoRA(torch.nn.Module):
    def __init__(self, linear, rank, alpha):
        super().__init__()
        self.linear = linear
        self.lora = LoRALayer(
            linear.in_features, linear.out_features, rank, alpha
        )

    def forward(self, x):
        return self.linear(x) + self.lora(x)


# This LoRA code is equivalent to LinearWithLoRA
class LinearWithLoRAMerged(torch.nn.Module):
    def __init__(self, linear, rank, alpha):
        super().__init__()
        self.linear = linear
        self.lora = LoRALayer(
            linear.in_features, linear.out_features, rank, alpha
        )

    def forward(self, x):
        lora = self.lora.A @ self.lora.B
        combined_weight = self.linear.weight + self.lora.alpha * lora.T
        return torch.nn.functional.linear(x, combined_weight, self.linear.bias)
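
# Why the two classes are equivalent: LinearWithLoRA computes
#   x @ W.T + alpha * (x @ A @ B)
# while LinearWithLoRAMerged folds the update into the weight first,
#   x @ (W + alpha * (A @ B).T).T,
# which expands to the same expression. The merged variant materializes the full
# (in_dim x out_dim) matrix A @ B in every forward pass instead of two small matmuls.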


class SpamDataset(Dataset):
    def __init__(self, csv_file, tokenizer, max_length=None, pad_token_id=50256, no_padding=False):
        self.data = pd.read_csv(csv_file)
        self.max_length = max_length if max_length is not None else self._longest_encoded_length(tokenizer)

        # Pre-tokenize texts
        self.encoded_texts = [
            tokenizer.encode(text)[:self.max_length]
            for text in self.data["Text"]
        ]

        if not no_padding:
            # Pad sequences to the longest sequence
            self.encoded_texts = [
                et + [pad_token_id] * (self.max_length - len(et))
                for et in self.encoded_texts
            ]

    def __getitem__(self, index):
        encoded = self.encoded_texts[index]
        label = self.data.iloc[index]["Label"]
        return torch.tensor(encoded, dtype=torch.long), torch.tensor(label, dtype=torch.long)

    def __len__(self):
        return len(self.data)

    def _longest_encoded_length(self, tokenizer):
        max_length = 0
        for text in self.data["Text"]:
            encoded_length = len(tokenizer.encode(text))
            if encoded_length > max_length:
                max_length = encoded_length
        return max_length
        # Note: A more pythonic way to implement this method,
        # which is also used in the next chapter, is:
        # return max(len(encoded_text) for encoded_text in self.encoded_texts)
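
# Example usage (illustrative; assumes the train.csv produced by create_dataset_csvs
# below, which has "Text" and "Label" columns):
#
#   tokenizer = tiktoken.get_encoding("gpt2")
#   ds = SpamDataset("train.csv", tokenizer)      # pads every example to the longest one
#   inputs, label = ds[0]                         # inputs: (max_length,) token IDs; label: 0 or 1
#   loader = DataLoader(ds, batch_size=8, shuffle=True)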


def download_and_unzip(url, zip_path, extract_to, new_file_path):
    if new_file_path.exists():
        print(f"{new_file_path} already exists. Skipping download and extraction.")
        return

    # Downloading the file
    with urllib.request.urlopen(url) as response:
        with open(zip_path, "wb") as out_file:
            out_file.write(response.read())

    # Unzipping the file
    with zipfile.ZipFile(zip_path, "r") as zip_ref:
        zip_ref.extractall(extract_to)

    # Renaming the file to indicate its format
    original_file = Path(extract_to) / "SMSSpamCollection"
    os.rename(original_file, new_file_path)
    print(f"File downloaded and saved as {new_file_path}")


def random_split(df, train_frac, validation_frac):
    # Shuffle the entire DataFrame
    df = df.sample(frac=1, random_state=123).reset_index(drop=True)

    # Calculate split indices
    train_end = int(len(df) * train_frac)
    validation_end = train_end + int(len(df) * validation_frac)

    # Split the DataFrame
    train_df = df[:train_end]
    validation_df = df[train_end:validation_end]
    test_df = df[validation_end:]

    return train_df, validation_df, test_df
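
# Note: the remaining fraction (1 - train_frac - validation_frac) becomes the test split;
# e.g., random_split(df, 0.7, 0.1) produces roughly a 70% / 10% / 20%
# train / validation / test split, as used in create_dataset_csvs below.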


def create_dataset_csvs(new_file_path):
    df = pd.read_csv(new_file_path, sep="\t", header=None, names=["Label", "Text"])

    # Create balanced dataset
    n_spam = df[df["Label"] == "spam"].shape[0]
    ham_sampled = df[df["Label"] == "ham"].sample(n_spam, random_state=123)
    balanced_df = pd.concat([ham_sampled, df[df["Label"] == "spam"]])
    balanced_df = balanced_df.sample(frac=1, random_state=123).reset_index(drop=True)
    balanced_df["Label"] = balanced_df["Label"].map({"ham": 0, "spam": 1})

    # Sample and save csv files
    train_df, validation_df, test_df = random_split(balanced_df, 0.7, 0.1)
    train_df.to_csv("train.csv", index=None)
    validation_df.to_csv("validation.csv", index=None)
    test_df.to_csv("test.csv", index=None)


def instantiate_model(choose_model, load_weights):

    BASE_CONFIG = {
        "vocab_size": 50257,     # Vocabulary size
        "context_length": 1024,  # Context length
        "drop_rate": 0.0,        # Dropout rate
        "qkv_bias": True         # Query-key-value bias
    }

    model_configs = {
        "gpt2-small (124M)": {"emb_dim": 768, "n_layers": 12, "n_heads": 12},
        "gpt2-medium (355M)": {"emb_dim": 1024, "n_layers": 24, "n_heads": 16},
        "gpt2-large (774M)": {"emb_dim": 1280, "n_layers": 36, "n_heads": 20},
        "gpt2-xl (1558M)": {"emb_dim": 1600, "n_layers": 48, "n_heads": 25},
    }

    BASE_CONFIG.update(model_configs[choose_model])

    if not load_weights:
        torch.manual_seed(123)

    # The model is created in both cases; only the seeding above depends on
    # whether we train from random weights
    model = GPTModel(BASE_CONFIG, disable_causal_mask=args.disable_causal_mask)

    if load_weights:
        model_size = choose_model.split(" ")[-1].lstrip("(").rstrip(")")
        settings, params = download_and_load_gpt2(model_size=model_size, models_dir="gpt2")
        load_weights_into_gpt(model, params)

    model.eval()
    return model
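
# Worked example of the model-size parsing above (illustrative):
#   "gpt2-small (124M)".split(" ")[-1]   -> "(124M)"
#   "(124M)".lstrip("(").rstrip(")")     -> "124M"
# which is the `model_size` string passed to download_and_load_gpt2.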


def calc_loss_batch(input_batch, target_batch, model, device,
                    trainable_token_pos=-1, ignore_index=-100, average_embeddings=False):
    input_batch, target_batch = input_batch.to(device), target_batch.to(device)

    if trainable_token_pos == "flexible":  # Selects the last tokens before the padding tokens
        # From https://github.com/rasbt/LLMs-from-scratch/discussions/434
        # Find the last non-padding token for each sequence in the batch
        pad_token_id = 50256  # <|endoftext|> token used for padding
        mask = input_batch != pad_token_id
        last_token_pos = mask.sum(dim=1) - 1  # Get position of last real token

        # Get model outputs
        logits = model(input_batch)  # shape: [batch_size, seq_len, num_classes]

        # Select the logits corresponding to the last real token of each sequence
        batch_size = logits.size(0)
        selected_logits = logits[torch.arange(batch_size), last_token_pos]

        loss = torch.nn.functional.cross_entropy(selected_logits, target_batch)
        return loss

    else:
        model_output = model(input_batch)
        if average_embeddings:
            # Average over the sequence dimension (dim=1)
            logits = model_output.mean(dim=1)
        else:
            # Select embeddings at the specified token position
            logits = model_output[:, trainable_token_pos, :]

        loss = torch.nn.functional.cross_entropy(logits, target_batch, ignore_index=ignore_index)
        return loss
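
# The "flexible" branch above selects one logit vector per sequence via advanced
# indexing; a small illustration with made-up values:
#
#   logits = torch.randn(4, 10, 2)                       # (batch_size, seq_len, num_classes)
#   last_token_pos = torch.tensor([9, 4, 7, 2])          # last non-padding position per row
#   selected = logits[torch.arange(4), last_token_pos]   # shape: (4, 2)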


def calc_loss_loader(data_loader, model, device,
                     num_batches=None, trainable_token_pos=-1,
                     ignore_index=-100, average_embeddings=False):
    total_loss = 0.
    if len(data_loader) == 0:
        return float("nan")
    elif num_batches is None:
        num_batches = len(data_loader)
    else:
        # Reduce the number of batches to match the total number of batches in the data loader
        # if num_batches exceeds the number of batches in the data loader
        num_batches = min(num_batches, len(data_loader))
    for i, (input_batch, target_batch) in enumerate(data_loader):
        if i < num_batches:
            loss = calc_loss_batch(
                input_batch, target_batch, model, device,
                trainable_token_pos=trainable_token_pos, ignore_index=ignore_index,
                average_embeddings=average_embeddings
            )
            total_loss += loss.item()
        else:
            break
    return total_loss / num_batches
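
# Note: when `num_batches` is given, the returned value is the mean loss over only the
# first `num_batches` batches, so the periodic evaluation during training (eval_iter=5
# in the __main__ section) is a quick estimate rather than a full pass over the loader.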


@torch.no_grad()  # Disable gradient tracking for efficiency
def calc_accuracy_loader(data_loader, model, device, num_batches=None,
                         trainable_token_pos=-1, average_embeddings=False):
    model.eval()
    correct_predictions, num_examples = 0, 0

    if num_batches is None:
        num_batches = len(data_loader)
    else:
        num_batches = min(num_batches, len(data_loader))

    if trainable_token_pos == "flexible":
        for i, (input_batch, target_batch) in enumerate(data_loader):
            if i < num_batches:
                input_batch, target_batch = input_batch.to(device), target_batch.to(device)

                # Find the last non-padding token for each sequence in the batch
                pad_token_id = 50256  # <|endoftext|> token used for padding
                mask = input_batch != pad_token_id
                last_token_pos = mask.sum(dim=1) - 1  # Get position of last real token

                logits = model(input_batch)  # shape: [batch_size, seq_len, num_classes]

                # Select the logits corresponding to the last real token of each sequence
                batch_size = logits.size(0)
                selected_logits = logits[torch.arange(batch_size), last_token_pos]
                predicted_labels = torch.argmax(selected_logits, dim=-1)

                num_examples += predicted_labels.shape[0]
                correct_predictions += (predicted_labels == target_batch).sum().item()
            else:
                break

    else:
        for i, (input_batch, target_batch) in enumerate(data_loader):
            if i < num_batches:
                input_batch, target_batch = input_batch.to(device), target_batch.to(device)
                model_output = model(input_batch)
                if average_embeddings:
                    # Average over the sequence dimension (dim=1)
                    logits = model_output.mean(dim=1)
                else:
                    # Select embeddings at the specified token position
                    logits = model_output[:, trainable_token_pos, :]

                predicted_labels = torch.argmax(logits, dim=-1)

                num_examples += predicted_labels.shape[0]
                correct_predictions += (predicted_labels == target_batch).sum().item()
            else:
                break
    return correct_predictions / num_examples


def evaluate_model(model, train_loader, val_loader, device,
                   eval_iter, trainable_token_pos=-1,
                   ignore_index=-100, average_embeddings=False):
    model.eval()
    with torch.no_grad():
        train_loss = calc_loss_loader(
            train_loader, model, device, num_batches=eval_iter,
            trainable_token_pos=trainable_token_pos, ignore_index=ignore_index,
            average_embeddings=average_embeddings
        )
        val_loss = calc_loss_loader(
            val_loader, model, device, num_batches=eval_iter,
            trainable_token_pos=trainable_token_pos, ignore_index=ignore_index,
            average_embeddings=average_embeddings
        )
    model.train()
    return train_loss, val_loss


def train_classifier_simple(model, train_loader, val_loader, optimizer, device, num_epochs,
                            eval_freq, eval_iter, max_steps=None, trainable_token_pos=-1,
                            accumulation_steps=1, ignore_index=-100, average_embeddings=False):
    # Initialize lists to track losses and examples seen
    train_losses, val_losses, train_accs, val_accs = [], [], [], []
    examples_seen, global_step = 0, -1

    # Main training loop
    for epoch in range(num_epochs):
        model.train()  # Set model to training mode

        for batch_idx, (input_batch, target_batch) in enumerate(train_loader):
            loss = calc_loss_batch(
                input_batch, target_batch, model, device,
                trainable_token_pos=trainable_token_pos, ignore_index=ignore_index,
                average_embeddings=average_embeddings
            )

            # Use gradient accumulation if accumulation_steps > 1
            # See https://sebastianraschka.com/blog/2023/llm-grad-accumulation.html
            # for an explanation
            loss /= accumulation_steps

            loss.backward()  # Calculate loss gradients

            # Use gradient accumulation if accumulation_steps > 1
            is_update_step = ((batch_idx + 1) % accumulation_steps == 0) or ((batch_idx + 1) == len(train_loader))
            if is_update_step:
                optimizer.step()  # Update model weights using loss gradients
                optimizer.zero_grad()  # Reset loss gradients from previous batch iteration

            examples_seen += input_batch.shape[0]  # New: track examples instead of tokens
            global_step += 1

            # Optional evaluation step
            if global_step % eval_freq == 0:
                train_loss, val_loss = evaluate_model(
                    model, train_loader, val_loader, device, eval_iter,
                    trainable_token_pos=trainable_token_pos, ignore_index=ignore_index,
                    average_embeddings=average_embeddings
                )
                train_losses.append(train_loss)
                val_losses.append(val_loss)
                print(f"Ep {epoch+1} (Step {global_step:06d}): "
                      f"Train loss {train_loss:.3f}, Val loss {val_loss:.3f}")

            if max_steps is not None and global_step > max_steps:
                break

        # New: Calculate accuracy after each epoch
        train_accuracy = calc_accuracy_loader(
            train_loader, model, device, num_batches=eval_iter,
            trainable_token_pos=trainable_token_pos, average_embeddings=average_embeddings
        )
        val_accuracy = calc_accuracy_loader(
            val_loader, model, device, num_batches=eval_iter,
            trainable_token_pos=trainable_token_pos, average_embeddings=average_embeddings
        )
        print(f"Training accuracy: {train_accuracy*100:.2f}% | ", end="")
        print(f"Validation accuracy: {val_accuracy*100:.2f}%")
        train_accs.append(train_accuracy)
        val_accs.append(val_accuracy)

        if max_steps is not None and global_step > max_steps:
            break

    return train_losses, val_losses, train_accs, val_accs, examples_seen
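
# With gradient accumulation, the optimizer only steps every `accumulation_steps`
# batches, so the effective batch size is roughly batch_size * accumulation_steps;
# e.g., batch_size=1 with accumulation_steps=8 approximates batch_size=8 with
# accumulation_steps=1 (see the --accumulation_steps help text below).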


def replace_linear_with_lora(model, rank, alpha, alternative=False):
    for name, module in model.named_children():
        if isinstance(module, torch.nn.Linear):
            # Replace the Linear layer with LinearWithLoRA
            if alternative:
                setattr(model, name, LinearWithLoRAMerged(module, rank, alpha))
            else:
                setattr(model, name, LinearWithLoRA(module, rank, alpha))
        else:
            # Recursively apply the same function to child modules,
            # passing `alternative` through so nested layers use the same LoRA variant
            replace_linear_with_lora(module, rank, alpha, alternative=alternative)
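
# Typical usage (illustrative; mirrors the "lora" branch in the __main__ section below):
#
#   for param in model.parameters():
#       param.requires_grad = False                  # freeze the pretrained weights
#   replace_linear_with_lora(model, rank=8, alpha=8)
#   # only the LoRA A/B matrices (and any freshly added layers, such as a new
#   # classification head) remain trainable afterwards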


if __name__ == "__main__":

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_size",
        type=str,
        default="gpt2-small (124M)",
        help=(
            "Which GPT model to use. Options: 'gpt2-small (124M)', 'gpt2-medium (355M)',"
            " 'gpt2-large (774M)', 'gpt2-xl (1558M)'."
        )
    )
    parser.add_argument(
        "--weights",
        type=str,
        default="pretrained",
        help=(
            "Whether to use 'pretrained' or 'random' weights."
        )
    )
    parser.add_argument(
        "--trainable_layers",
        type=str,
        default="last_block",
        help=(
            "Which layers to train. Options: 'all', 'last_block', 'last_two_blocks', 'last_layer', 'lora', 'lora_alternative'."
        )
    )
    parser.add_argument(
        "--trainable_token_pos",
        type=str,
        default="last",
        help=(
            "Which token position to train. Options: 'first', 'last', 'flexible'."
        )
    )
    parser.add_argument(
        "--average_embeddings",
        action='store_true',
        default=False,
        help=(
            "Average the output embeddings from all tokens instead of using"
            " only the embedding at the token position specified by `--trainable_token_pos`."
        )
    )
    parser.add_argument(
        "--context_length",
        type=str,
        default="longest_training_example",
        help=(
            "The context length of the data inputs."
            " Options: 'longest_training_example', 'model_context_length' or integer value."
        )
    )
    parser.add_argument(
        "--lora_rank",
        type=int,
        default=8,
        help=(
            "The LoRA rank when choosing `--trainable_layers lora`"
        )
    )
    parser.add_argument(
        "--lora_alpha",
        type=int,
        default=8,
        help=(
            "The LoRA alpha value when choosing `--trainable_layers lora`"
        )
    )
    parser.add_argument(
        "--no_padding",
        action='store_true',
        default=False,
        help=(
            "Disable padding, which means each example may have a different length."
            " This requires setting `--batch_size 1`."
        )
    )
    parser.add_argument(
        "--num_epochs",
        type=int,
        default=5,
        help=(
            "Number of training epochs."
        )
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=8,
        help=(
            "The batch size used for training."
        )
    )
    parser.add_argument(
        "--accumulation_steps",
        type=int,
        default=1,
        help=(
            "Accumulation steps to allow for gradient accumulation."
            " See https://sebastianraschka.com/blog/2023/llm-grad-accumulation.html for an explanation."
            " For example, setting `batch_size=8` and `accumulation_steps=1` computes the exact same"
            " loss and weight updates as setting `batch_size=1` and `accumulation_steps=8`; however,"
            " the latter setting uses more iterations."
        )
    )
    parser.add_argument(
        "--disable_causal_mask",
        action='store_true',
        default=False,
        help=(
            "Disables the causal attention mask."
        )
    )
    parser.add_argument(
        "--ignore_index",
        type=int,
        default=-100,
        help=(
            "Sets the `ignore_index` in the cross-entropy loss."
        )
    )
    args = parser.parse_args()

    if args.trainable_token_pos == "first":
        args.trainable_token_pos = 0
    elif args.trainable_token_pos == "last":
        args.trainable_token_pos = -1
    # The "flexible" setting selects the last tokens before the padding tokens
    # See https://github.com/rasbt/LLMs-from-scratch/discussions/434
    elif args.trainable_token_pos == "flexible":
        args.trainable_token_pos = "flexible"
    else:
        raise ValueError("Invalid --trainable_token_pos argument")

    ###############################
    # Load model
    ###############################

    if args.weights == "pretrained":
        load_weights = True
    elif args.weights == "random":
        load_weights = False
    else:
        raise ValueError("Invalid --weights argument.")

    model = instantiate_model(args.model_size, load_weights)
    for param in model.parameters():
        param.requires_grad = False

    if args.model_size == "gpt2-small (124M)":
        in_features = 768
    elif args.model_size == "gpt2-medium (355M)":
        in_features = 1024
    elif args.model_size == "gpt2-large (774M)":
        in_features = 1280
    elif args.model_size == "gpt2-xl (1558M)":
        in_features = 1600
    else:
        raise ValueError("Invalid --model_size argument")

    torch.manual_seed(123)
    model.out_head = torch.nn.Linear(in_features=in_features, out_features=2)

    if args.trainable_layers == "last_layer":
        pass
    elif args.trainable_layers == "last_block" or args.trainable_layers == "last_two_blocks":
        for param in model.trf_blocks[-1].parameters():
            param.requires_grad = True
        for param in model.final_norm.parameters():
            param.requires_grad = True
        if args.trainable_layers == "last_two_blocks":
            for param in model.trf_blocks[-2].parameters():
                param.requires_grad = True
    elif args.trainable_layers == "all":
        for param in model.parameters():
            param.requires_grad = True
    elif args.trainable_layers in ("lora", "lora_alternative"):
        if args.trainable_layers == "lora_alternative":
            alternative = True
        else:
            alternative = False
        replace_linear_with_lora(model, rank=args.lora_rank, alpha=args.lora_alpha, alternative=alternative)
    else:
        raise ValueError("Invalid --trainable_layers argument.")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    ###############################
    # Instantiate dataloaders
    ###############################

    url = "https://archive.ics.uci.edu/static/public/228/sms+spam+collection.zip"
    zip_path = "sms_spam_collection.zip"
    extract_to = "sms_spam_collection"
    new_file_path = Path(extract_to) / "SMSSpamCollection.tsv"

    base_path = Path(".")
    file_names = ["train.csv", "validation.csv", "test.csv"]
    all_exist = all((base_path / file_name).exists() for file_name in file_names)

    if not all_exist:
        try:
            download_and_unzip(url, zip_path, extract_to, new_file_path)
        except (urllib.error.HTTPError, urllib.error.URLError, TimeoutError) as e:
            print(f"Primary URL failed: {e}. Trying backup URL...")
            backup_url = "https://f001.backblazeb2.com/file/LLMs-from-scratch/sms%2Bspam%2Bcollection.zip"
            download_and_unzip(backup_url, zip_path, extract_to, new_file_path)
        create_dataset_csvs(new_file_path)

    tokenizer = tiktoken.get_encoding("gpt2")

    train_dataset = None
    if args.no_padding:
        max_length = None
    else:
        if args.context_length == "model_context_length":
            max_length = model.pos_emb.weight.shape[0]
        elif args.context_length == "longest_training_example":
            train_dataset = SpamDataset(base_path / "train.csv", max_length=None, tokenizer=tokenizer, no_padding=args.no_padding)
            max_length = train_dataset.max_length
        else:
            try:
                max_length = int(args.context_length)
            except ValueError:
                raise ValueError("Invalid --context_length argument")

    if train_dataset is None:
        train_dataset = SpamDataset(base_path / "train.csv", max_length=max_length, tokenizer=tokenizer, no_padding=args.no_padding)
    val_dataset = SpamDataset(base_path / "validation.csv", max_length=max_length, tokenizer=tokenizer, no_padding=args.no_padding)
    test_dataset = SpamDataset(base_path / "test.csv", max_length=max_length, tokenizer=tokenizer, no_padding=args.no_padding)

    num_workers = 0

    train_loader = DataLoader(
        dataset=train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=num_workers,
        drop_last=True,
    )

    val_loader = DataLoader(
        dataset=val_dataset,
        batch_size=args.batch_size,
        num_workers=num_workers,
        drop_last=False,
    )

    test_loader = DataLoader(
        dataset=test_dataset,
        batch_size=args.batch_size,
        num_workers=num_workers,
        drop_last=False,
    )

    assert train_dataset.max_length <= model.pos_emb.weight.shape[0], (
        f"Dataset length {train_dataset.max_length} exceeds model's context "
        f"length {model.pos_emb.weight.shape[0]}. Reinitialize data sets with "
        f"`max_length={model.pos_emb.weight.shape[0]}`"
    )
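
    # This check works because the positional-embedding table has one row per supported
    # position, i.e. model.pos_emb.weight.shape[0] equals the model's context length.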

    ###############################
    # Train model
    ###############################

    start_time = time.time()
    torch.manual_seed(123)

    optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5, weight_decay=0.1)

    train_losses, val_losses, train_accs, val_accs, examples_seen = train_classifier_simple(
        model, train_loader, val_loader, optimizer, device,
        num_epochs=args.num_epochs, eval_freq=50, eval_iter=5,
        max_steps=None, trainable_token_pos=args.trainable_token_pos,
        accumulation_steps=args.accumulation_steps,
        ignore_index=args.ignore_index,  # forward the CLI flag so --ignore_index takes effect
        average_embeddings=args.average_embeddings
    )

    end_time = time.time()
    execution_time_minutes = (end_time - start_time) / 60
    print(f"Training completed in {execution_time_minutes:.2f} minutes.")

    ###############################
    # Evaluate model
    ###############################

    train_accuracy = calc_accuracy_loader(
        train_loader, model, device,
        trainable_token_pos=args.trainable_token_pos, average_embeddings=args.average_embeddings
    )
    val_accuracy = calc_accuracy_loader(
        val_loader, model, device,
        trainable_token_pos=args.trainable_token_pos, average_embeddings=args.average_embeddings
    )
    test_accuracy = calc_accuracy_loader(
        test_loader, model, device,
        trainable_token_pos=args.trainable_token_pos, average_embeddings=args.average_embeddings
    )

    print(f"Training accuracy: {train_accuracy*100:.2f}%")
    print(f"Validation accuracy: {val_accuracy*100:.2f}%")
    print(f"Test accuracy: {test_accuracy*100:.2f}%")