|
|
@@ -293,7 +293,7 @@ def generate_and_print_sample(model, tokenizer, device, start_context):
|
|
|
|
|
|
|
|
|
def plot_losses(epochs_seen, tokens_seen, train_losses, val_losses):
|
|
|
- fig, ax1 = plt.subplots()
|
|
|
+ fig, ax1 = plt.subplots(figsize=(5, 3))
|
|
|
|
|
|
# Plot training and validation loss against epochs
|
|
|
ax1.plot(epochs_seen, train_losses, label="Training loss")
|