# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
# Source for "Build a Large Language Model From Scratch"
# - https://www.manning.com/books/build-a-large-language-model-from-scratch
# Code: https://github.com/rasbt/LLMs-from-scratch

from llms_from_scratch.ch04 import GPTModel
from llms_from_scratch.ch06 import (
    download_and_unzip_spam_data, create_balanced_dataset,
    random_split, SpamDataset, train_classifier_simple
)
from llms_from_scratch.appendix_e import replace_linear_with_lora

from pathlib import Path
import urllib.error

import pandas as pd
import tiktoken
import torch
from torch.utils.data import DataLoader, Subset


def test_train_classifier_lora(tmp_path):
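    """End-to-end test for Appendix E LoRA finetuning: downloads the SMS spam
    dataset, injects LoRA adapters into a tiny GPT model, runs a short
    classifier finetuning loop, and checks that the training loss decreases."""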
    ########################################
    # Download and prepare dataset
    ########################################
    url = "https://archive.ics.uci.edu/static/public/228/sms+spam+collection.zip"
    zip_path = tmp_path / "sms_spam_collection.zip"
    extracted_path = tmp_path / "sms_spam_collection"
    data_file_path = Path(extracted_path) / "SMSSpamCollection.tsv"

    try:
        download_and_unzip_spam_data(
            url, zip_path, extracted_path, data_file_path
        )
    except (urllib.error.HTTPError, urllib.error.URLError, TimeoutError) as e:
        print(f"Primary URL failed: {e}. Trying backup URL...")
        backup_url = "https://f001.backblazeb2.com/file/LLMs-from-scratch/sms%2Bspam%2Bcollection.zip"
        download_and_unzip_spam_data(
            backup_url, zip_path, extracted_path, data_file_path
        )

    df = pd.read_csv(data_file_path, sep="\t", header=None, names=["Label", "Text"])
    balanced_df = create_balanced_dataset(df)
    balanced_df["Label"] = balanced_df["Label"].map({"ham": 0, "spam": 1})

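    # Split into 70% train, 10% validation, and the remaining 20% test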
    train_df, validation_df, test_df = random_split(balanced_df, 0.7, 0.1)
    train_df.to_csv(tmp_path / "train.csv", index=None)
    validation_df.to_csv(tmp_path / "validation.csv", index=None)
    test_df.to_csv(tmp_path / "test.csv", index=None)

    ########################################
    # Create data loaders
    ########################################
    tokenizer = tiktoken.get_encoding("gpt2")

    train_dataset = SpamDataset(
        csv_file=tmp_path / "train.csv",
        max_length=None,
        tokenizer=tokenizer
    )

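    # Reuse the training split's max_length so validation sequences are padded
    # or truncated to the same length as the training sequences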
    val_dataset = SpamDataset(
        csv_file=tmp_path / "validation.csv",
        max_length=train_dataset.max_length,
        tokenizer=tokenizer
    )

    num_workers = 0
    batch_size = 8

    torch.manual_seed(123)

    train_loader = DataLoader(
        dataset=train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        drop_last=True,
    )

    val_loader = DataLoader(
        dataset=val_dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        drop_last=False,
    )

    ########################################
    # Load pretrained model
    ########################################

    # Small GPT model for testing purposes
    BASE_CONFIG = {
        "vocab_size": 50257,
        "context_length": 120,
        "drop_rate": 0.0,
        "qkv_bias": False,
        "emb_dim": 12,
        "n_layers": 1,
        "n_heads": 2
    }
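    # (For comparison, GPT-2 small uses emb_dim=768, n_layers=12, n_heads=12,
    # and context_length=1024; the tiny values above just keep the test fast.)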

    model = GPTModel(BASE_CONFIG)
    model.eval()
    device = "cpu"

    ########################################
    # Modify pretrained model
    ########################################

    for param in model.parameters():
        param.requires_grad = False

    torch.manual_seed(123)

    num_classes = 2
    model.out_head = torch.nn.Linear(in_features=BASE_CONFIG["emb_dim"], out_features=num_classes)
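    # replace_linear_with_lora (Appendix E) is expected to walk the model and
    # swap each torch.nn.Linear for a LoRA-wrapped layer that keeps the original
    # weight and adds a trainable low-rank update, roughly
    # y = linear(x) + alpha * (x @ A @ B) with A: (in_dim, rank), B: (rank, out_dim)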
    replace_linear_with_lora(model, rank=16, alpha=16)

    model.to(device)

    for param in model.trf_blocks[-1].parameters():
        param.requires_grad = True

    for param in model.final_norm.parameters():
        param.requires_grad = True
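
    # Net effect of the setup above: only the LoRA adapters, the new
    # classification head, the last transformer block, and the final LayerNorm
    # are trainable; the remaining pretrained weights stay frozen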

    ########################################
    # Finetune modified model
    ########################################

    torch.manual_seed(123)
    optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5, weight_decay=0.1)

    train_subset = Subset(train_loader.dataset, range(5))
    batch_train_loader = DataLoader(train_subset, batch_size=5)
    val_subset = Subset(val_loader.dataset, range(5))
    batch_val_loader = DataLoader(val_subset, batch_size=5)
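    # These tiny 5-example subsets (a single batch each) keep the finetuning loop fast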

    num_epochs = 6
    train_losses, val_losses, train_accs, val_accs, examples_seen = train_classifier_simple(
        model, batch_train_loader, batch_val_loader, optimizer, device,
        num_epochs=num_epochs, eval_freq=1, eval_iter=1,
    )

    assert round(train_losses[0], 1) == 0.8
    assert round(val_losses[0], 1) == 0.8
    assert train_losses[-1] < train_losses[0]