瀏覽代碼

fixed plot_losses (#677)

Daniel Kleine 5 月之前
父節點
當前提交
15fa6a84f6
共有 1 個文件被更改,包括 1 次插入1 次删除
  1. 1 1
      ch05/03_bonus_pretraining_on_gutenberg/pretraining_simple.py

+ 1 - 1
ch05/03_bonus_pretraining_on_gutenberg/pretraining_simple.py

@@ -228,7 +228,7 @@ if __name__ == "__main__":
     )
 
     epochs_tensor = torch.linspace(0, args.n_epochs, len(train_losses))
-    plot_losses(epochs_tensor, tokens_seen, train_losses, val_losses, output_dir)
+    plot_losses(epochs_tensor, tokens_seen, train_losses, val_losses)
 
     torch.save(model.state_dict(), output_dir / "model_pg_final.pth")
     print(f"Maximum GPU memory allocated: {torch.cuda.max_memory_allocated() / 1e9:.2f} GB")