소스 검색

fix plotting

rasbt 1 년 전
부모
커밋
ee8efcbcf6
2개의 변경된 파일13개의 추가작업 그리고 10개의 파일을 삭제
  1. 10 8
      ch05/03_bonus_pretraining_on_gutenberg/pretraining_simple.py
  2. 3 2
      ch05/03_bonus_pretraining_on_gutenberg/previous_chapters.py

+ 10 - 8
ch05/03_bonus_pretraining_on_gutenberg/pretraining_simple.py

@@ -119,11 +119,11 @@ def train_model_simple(model, optimizer, device, n_epochs,
                         print(f"Ep {epoch+1} (Step {global_step}): "
                               f"Train loss {train_loss:.3f}, Val loss {val_loss:.3f}")
 
-                # Generate text passage
-                if index % print_sample_iter == 0:
-                    generate_and_print_sample(
-                        model, train_loader.dataset.tokenizer, device, start_context
-                    )
+                    # Generate text passage
+                    if global_step % print_sample_iter == 0:
+                        generate_and_print_sample(
+                            model, train_loader.dataset.tokenizer, device, start_context
+                        )
 
                 if global_step % save_ckpt_freq:
                     file_name = output_dir / f"model_pg_{global_step}.pth"
@@ -137,7 +137,7 @@ def train_model_simple(model, optimizer, device, n_epochs,
         torch.save(model.state_dict(), file_name)
         print(f"Saved {file_name}")
 
-    return train_losses, val_losses, tokens_seen
+    return train_losses, val_losses, track_tokens_seen
 
 
 if __name__ == "__main__":
@@ -150,7 +150,7 @@ if __name__ == "__main__":
                         help='Directory where the model checkpoints will be saved')
     parser.add_argument('--n_epochs', type=int, default=1,
                         help='Number of epochs to train the model')
-    parser.add_argument('--print_sample_iter', type=int, default=500,
+    parser.add_argument('--print_sample_iter', type=int, default=1000,
                         help='Iterations between printing sample outputs')
     parser.add_argument('--eval_freq', type=int, default=100,
                         help='Frequency of evaluations during training')
@@ -205,7 +205,9 @@ if __name__ == "__main__":
         start_context="Every effort moves you",
     )
 
-    epochs_tensor = torch.linspace(1, args.n_epochs, len(train_losses))
+    epochs_tensor = torch.linspace(0, args.n_epochs, len(train_losses))
+
+    print("debug", epochs_tensor, tokens_seen, train_losses, val_losses, output_dir)
     plot_losses(epochs_tensor, tokens_seen, train_losses, val_losses, output_dir)
 
     torch.save(model.state_dict(), output_dir / "model_pg_final.pth")

+ 3 - 2
ch05/03_bonus_pretraining_on_gutenberg/previous_chapters.py

@@ -274,8 +274,9 @@ def generate_and_print_sample(model, tokenizer, device, start_context):
     context_size = model.pos_emb.weight.shape[0]
     encoded = text_to_token_ids(start_context, tokenizer).to(device)
     with torch.no_grad():
-        token_ids = generate_text_simple(model=model, idx=encoded,
-                                   max_new_tokens=50, context_size=context_size)
+        token_ids = generate_text_simple(
+            model=model, idx=encoded,
+            max_new_tokens=50, context_size=context_size)
         decoded_text = token_ids_to_text(token_ids, tokenizer)
         print(decoded_text.replace("\n", " "))  # Compact print format
     model.train()