|
@@ -271,12 +271,11 @@ def calc_accuracy_loader(data_loader, model, device, num_batches=None,
|
|
|
mask = input_batch != pad_token_id
|
|
mask = input_batch != pad_token_id
|
|
|
last_token_pos = mask.sum(dim=1) - 1 # Get position of last real token
|
|
last_token_pos = mask.sum(dim=1) - 1 # Get position of last real token
|
|
|
|
|
|
|
|
- with torch.no_grad():
|
|
|
|
|
- logits = model(input_batch) # Logits of last output token
|
|
|
|
|
- # Select the logits corresponding to the last real token of each sequence
|
|
|
|
|
- batch_size = logits.size(0)
|
|
|
|
|
- selected_logits = logits[torch.arange(batch_size), last_token_pos]
|
|
|
|
|
- predicted_labels = torch.argmax(selected_logits, dim=-1)
|
|
|
|
|
|
|
+ logits = model(input_batch) # Logits of last output token
|
|
|
|
|
+ # Select the logits corresponding to the last real token of each sequence
|
|
|
|
|
+ batch_size = logits.size(0)
|
|
|
|
|
+ selected_logits = logits[torch.arange(batch_size), last_token_pos]
|
|
|
|
|
+ predicted_labels = torch.argmax(selected_logits, dim=-1)
|
|
|
|
|
|
|
|
num_examples += predicted_labels.shape[0]
|
|
num_examples += predicted_labels.shape[0]
|
|
|
correct_predictions += (predicted_labels == target_batch).sum().item()
|
|
correct_predictions += (predicted_labels == target_batch).sum().item()
|
|
@@ -643,8 +642,6 @@ if __name__ == "__main__":
|
|
|
val_dataset = SpamDataset(base_path / "validation.csv", max_length=max_length, tokenizer=tokenizer, no_padding=args.no_padding)
|
|
val_dataset = SpamDataset(base_path / "validation.csv", max_length=max_length, tokenizer=tokenizer, no_padding=args.no_padding)
|
|
|
test_dataset = SpamDataset(base_path / "test.csv", max_length=max_length, tokenizer=tokenizer, no_padding=args.no_padding)
|
|
test_dataset = SpamDataset(base_path / "test.csv", max_length=max_length, tokenizer=tokenizer, no_padding=args.no_padding)
|
|
|
|
|
|
|
|
- tokenizer = tiktoken.get_encoding("gpt2")
|
|
|
|
|
-
|
|
|
|
|
num_workers = 0
|
|
num_workers = 0
|
|
|
|
|
|
|
|
train_loader = DataLoader(
|
|
train_loader = DataLoader(
|