@@ -312,7 +312,7 @@ def generate_and_print_sample(model, device, start_context):
 
 
 def train_model_simple_with_timing(model, train_loader, val_loader, optimizer, device,
-                                   num_epochs, eval_freq, eval_iter, start_context, tokenizer):
+                                   num_epochs, eval_freq, eval_iter, start_context):
     train_losses, val_losses, track_tokens = [], [], []
     total_tokens, global_step, last_tokens = 0, -1, 0
 
@@ -524,8 +524,6 @@ def main(gpt_config, settings, rank, world_size):
     # Train model
     ##############################
 
-    tokenizer = tiktoken.get_encoding("gpt2")
-
     train_losses, val_losses, tokens_seen = train_model_simple_with_timing(
         model=model,
         train_loader=train_loader,
@@ -536,7 +534,6 @@ def main(gpt_config, settings, rank, world_size):
         eval_freq=5,
         eval_iter=1,
         start_context="Every effort moves you",
-        tokenizer=tokenizer
     )
 
     # NEW: Clean up distributed processes
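The patch drops a tokenizer argument that was never used inside the trainer: per the first hunk header, generate_and_print_sample(model, device, start_context) takes no tokenizer, so the tiktoken instance created in main() and threaded through train_model_simple_with_timing was dead code. For reference, a minimal sketch of the call site in main() after the change; it is assembled from the hunks above, and the keyword arguments between train_loader and eval_freq (val_loader, optimizer, device, num_epochs) are assumptions filled in from the function signature rather than lines shown in this diff:

    # Sketch of the post-change invocation inside main(); no tokenizer is created or passed.
    train_losses, val_losses, tokens_seen = train_model_simple_with_timing(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,                  # assumed from the signature; not shown in the hunks
        optimizer=optimizer,                    # assumed from the signature; not shown in the hunks
        device=device,                          # assumed from the signature; not shown in the hunks
        num_epochs=settings["num_epochs"],      # assumed; where this value comes from is not shown
        eval_freq=5,
        eval_iter=1,
        start_context="Every effort moves you",
    )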