|
|
@@ -147,7 +147,7 @@ def plot_losses(epochs_seen, tokens_seen, train_losses, val_losses):
|
|
|
# plt.show()
|
|
|
|
|
|
|
|
|
-def main():
|
|
|
+def main(test_mode=False):
|
|
|
#######################################
|
|
|
# Print package versions
|
|
|
#######################################
|
|
|
@@ -177,6 +177,12 @@ def main():
|
|
|
test_data = data[train_portion:train_portion + test_portion]
|
|
|
val_data = data[train_portion + test_portion:]
|
|
|
|
|
|
+ # Use very small subset for testing purposes
|
|
|
+ if args.test_mode:
|
|
|
+ train_data = train_data[:10]
|
|
|
+ val_data = val_data[:10]
|
|
|
+ test_data = test_data[:10]
|
|
|
+
|
|
|
print("Training set length:", len(train_data))
|
|
|
print("Validation set length:", len(val_data))
|
|
|
print("Test set length:", len(test_data))
|
|
|
@@ -217,31 +223,50 @@ def main():
|
|
|
#######################################
|
|
|
# Load pretrained model
|
|
|
#######################################
|
|
|
- BASE_CONFIG = {
|
|
|
- "vocab_size": 50257, # Vocabulary size
|
|
|
- "context_length": 1024, # Context length
|
|
|
- "drop_rate": 0.0, # Dropout rate
|
|
|
- "qkv_bias": True # Query-key-value bias
|
|
|
- }
|
|
|
|
|
|
- model_configs = {
|
|
|
- "gpt2-small (124M)": {"emb_dim": 768, "n_layers": 12, "n_heads": 12},
|
|
|
- "gpt2-medium (355M)": {"emb_dim": 1024, "n_layers": 24, "n_heads": 16},
|
|
|
- "gpt2-large (774M)": {"emb_dim": 1280, "n_layers": 36, "n_heads": 20},
|
|
|
- "gpt2-xl (1558M)": {"emb_dim": 1600, "n_layers": 48, "n_heads": 25},
|
|
|
- }
|
|
|
+ # Small GPT model for testing purposes
|
|
|
+ if args.test_mode:
|
|
|
+ BASE_CONFIG = {
|
|
|
+ "vocab_size": 50257,
|
|
|
+ "context_length": 120,
|
|
|
+ "drop_rate": 0.0,
|
|
|
+ "qkv_bias": False,
|
|
|
+ "emb_dim": 12,
|
|
|
+ "n_layers": 1,
|
|
|
+ "n_heads": 2
|
|
|
+ }
|
|
|
+ model = GPTModel(BASE_CONFIG)
|
|
|
+ model.eval()
|
|
|
+ device = "cpu"
|
|
|
+ CHOOSE_MODEL = "Small test model"
|
|
|
+
|
|
|
+ # Code as it is used in the main chapter
|
|
|
+ else:
|
|
|
+ BASE_CONFIG = {
|
|
|
+ "vocab_size": 50257, # Vocabulary size
|
|
|
+ "context_length": 1024, # Context length
|
|
|
+ "drop_rate": 0.0, # Dropout rate
|
|
|
+ "qkv_bias": True # Query-key-value bias
|
|
|
+ }
|
|
|
+
|
|
|
+ model_configs = {
|
|
|
+ "gpt2-small (124M)": {"emb_dim": 768, "n_layers": 12, "n_heads": 12},
|
|
|
+ "gpt2-medium (355M)": {"emb_dim": 1024, "n_layers": 24, "n_heads": 16},
|
|
|
+ "gpt2-large (774M)": {"emb_dim": 1280, "n_layers": 36, "n_heads": 20},
|
|
|
+ "gpt2-xl (1558M)": {"emb_dim": 1600, "n_layers": 48, "n_heads": 25},
|
|
|
+ }
|
|
|
|
|
|
- CHOOSE_MODEL = "gpt2-medium (355M)"
|
|
|
+ CHOOSE_MODEL = "gpt2-medium (355M)"
|
|
|
|
|
|
- BASE_CONFIG.update(model_configs[CHOOSE_MODEL])
|
|
|
+ BASE_CONFIG.update(model_configs[CHOOSE_MODEL])
|
|
|
|
|
|
- model_size = CHOOSE_MODEL.split(" ")[-1].lstrip("(").rstrip(")")
|
|
|
- settings, params = download_and_load_gpt2(model_size=model_size, models_dir="gpt2")
|
|
|
+ model_size = CHOOSE_MODEL.split(" ")[-1].lstrip("(").rstrip(")")
|
|
|
+ settings, params = download_and_load_gpt2(model_size=model_size, models_dir="gpt2")
|
|
|
|
|
|
- model = GPTModel(BASE_CONFIG)
|
|
|
- load_weights_into_gpt(model, params)
|
|
|
- model.eval()
|
|
|
- model.to(device)
|
|
|
+ model = GPTModel(BASE_CONFIG)
|
|
|
+ load_weights_into_gpt(model, params)
|
|
|
+ model.eval()
|
|
|
+ model.to(device)
|
|
|
|
|
|
print("Loaded model:", CHOOSE_MODEL)
|
|
|
print(50*"-")
|
|
|
@@ -259,6 +284,7 @@ def main():
|
|
|
|
|
|
start_time = time.time()
|
|
|
optimizer = torch.optim.AdamW(model.parameters(), lr=0.00005, weight_decay=0.1)
|
|
|
+
|
|
|
num_epochs = 2
|
|
|
|
|
|
torch.manual_seed(123)
|
|
|
@@ -307,4 +333,19 @@ def main():
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
- main()
|
|
|
+
|
|
|
+ import argparse
|
|
|
+
|
|
|
+ parser = argparse.ArgumentParser(
|
|
|
+ description="Finetune a GPT model for classification"
|
|
|
+ )
|
|
|
+ parser.add_argument(
|
|
|
+ "--test_mode",
|
|
|
+ default=False,
|
|
|
+ action="store_true",
|
|
|
+ help=("This flag runs the model in test mode for internal testing purposes. "
|
|
|
+ "Otherwise, it runs the model as it is used in the chapter (recommended).")
|
|
|
+ )
|
|
|
+ args = parser.parse_args()
|
|
|
+
|
|
|
+ main(args.test_mode)
|