|
|
@@ -119,11 +119,11 @@ def train_model_simple(model, optimizer, device, n_epochs,
|
|
|
print(f"Ep {epoch+1} (Step {global_step}): "
|
|
|
f"Train loss {train_loss:.3f}, Val loss {val_loss:.3f}")
|
|
|
|
|
|
- # Generate text passage
|
|
|
- if index % print_sample_iter == 0:
|
|
|
- generate_and_print_sample(
|
|
|
- model, train_loader.dataset.tokenizer, device, start_context
|
|
|
- )
|
|
|
+ # Generate text passage
|
|
|
+ if global_step % print_sample_iter == 0:
|
|
|
+ generate_and_print_sample(
|
|
|
+ model, train_loader.dataset.tokenizer, device, start_context
|
|
|
+ )
|
|
|
|
|
|
if global_step % save_ckpt_freq:
|
|
|
file_name = output_dir / f"model_pg_{global_step}.pth"
|
|
|
@@ -137,7 +137,7 @@ def train_model_simple(model, optimizer, device, n_epochs,
|
|
|
torch.save(model.state_dict(), file_name)
|
|
|
print(f"Saved {file_name}")
|
|
|
|
|
|
- return train_losses, val_losses, tokens_seen
|
|
|
+ return train_losses, val_losses, track_tokens_seen
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
@@ -150,7 +150,7 @@ if __name__ == "__main__":
|
|
|
help='Directory where the model checkpoints will be saved')
|
|
|
parser.add_argument('--n_epochs', type=int, default=1,
|
|
|
help='Number of epochs to train the model')
|
|
|
- parser.add_argument('--print_sample_iter', type=int, default=500,
|
|
|
+ parser.add_argument('--print_sample_iter', type=int, default=1000,
|
|
|
help='Iterations between printing sample outputs')
|
|
|
parser.add_argument('--eval_freq', type=int, default=100,
|
|
|
help='Frequency of evaluations during training')
|
|
|
@@ -205,7 +205,9 @@ if __name__ == "__main__":
|
|
|
start_context="Every effort moves you",
|
|
|
)
|
|
|
|
|
|
- epochs_tensor = torch.linspace(1, args.n_epochs, len(train_losses))
|
|
|
+ epochs_tensor = torch.linspace(0, args.n_epochs, len(train_losses))
|
|
|
+
|
|
|
+ print("debug", epochs_tensor, tokens_seen, train_losses, val_losses, output_dir)
|
|
|
plot_losses(epochs_tensor, tokens_seen, train_losses, val_losses, output_dir)
|
|
|
|
|
|
torch.save(model.state_dict(), output_dir / "model_pg_final.pth")
|