
minor readability improvements (#668)

casinca 5 months ago
parent
commit
00b8c0a107

+5 -8
ch06/02_bonus_additional-experiments/additional_experiments.py

@@ -271,12 +271,11 @@ def calc_accuracy_loader(data_loader, model, device, num_batches=None,
                 mask = input_batch != pad_token_id
                 last_token_pos = mask.sum(dim=1) - 1  # Get position of last real token
 
-                with torch.no_grad():
-                    logits = model(input_batch)  # Logits of last output token
-                    # Select the logits corresponding to the last real token of each sequence
-                    batch_size = logits.size(0)
-                    selected_logits = logits[torch.arange(batch_size), last_token_pos]
-                    predicted_labels = torch.argmax(selected_logits, dim=-1)
+                logits = model(input_batch)  # Logits of last output token
+                # Select the logits corresponding to the last real token of each sequence
+                batch_size = logits.size(0)
+                selected_logits = logits[torch.arange(batch_size), last_token_pos]
+                predicted_labels = torch.argmax(selected_logits, dim=-1)
 
                 num_examples += predicted_labels.shape[0]
                 correct_predictions += (predicted_labels == target_batch).sum().item()
@@ -643,8 +642,6 @@ if __name__ == "__main__":
     val_dataset = SpamDataset(base_path / "validation.csv", max_length=max_length, tokenizer=tokenizer, no_padding=args.no_padding)
     test_dataset = SpamDataset(base_path / "test.csv", max_length=max_length, tokenizer=tokenizer, no_padding=args.no_padding)
 
-    tokenizer = tiktoken.get_encoding("gpt2")
-
     num_workers = 0
 
     train_loader = DataLoader(
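
For context, the first hunk keeps the logic that picks each sequence's last non-padded token before taking the argmax; only the redundant torch.no_grad() wrapper is dropped. Below is a minimal standalone sketch of that selection logic; the pad token id and toy tensors are placeholders, not values taken from the script, which gets padded batches from SpamDataset and logits from a GPT model.

import torch

pad_token_id = 50256  # assumption: GPT-2's <|endoftext|> token used as padding

input_batch = torch.tensor([
    [11, 22, 33, pad_token_id, pad_token_id],  # 3 real tokens
    [44, 55, 66, 77, 88],                      # 5 real tokens
])
# Stand-in for model(input_batch): logits of shape (batch_size, seq_len, num_classes)
logits = torch.randn(input_batch.shape[0], input_batch.shape[1], 2)

mask = input_batch != pad_token_id
last_token_pos = mask.sum(dim=1) - 1  # index of the last real token per sequence
batch_size = logits.size(0)
# Select the logits at each sequence's last real token, then classify
selected_logits = logits[torch.arange(batch_size), last_token_pos]
predicted_labels = torch.argmax(selected_logits, dim=-1)
print(predicted_labels)  # e.g. tensor([0, 1])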