@@ -81,14 +81,25 @@ def instantiate_model(choose_model, load_weights):
     return model
 
 
-def calc_loss_batch(input_batch, target_batch, model, device, trainable_token=-1):
+def calc_loss_batch(input_batch, target_batch, model, device,
+                    trainable_token_pos=-1, average_embeddings=False):
     input_batch, target_batch = input_batch.to(device), target_batch.to(device)
-    logits = model(input_batch)[:, trainable_token, :]  # Logits of last output token
+
+    model_output = model(input_batch)
+    if average_embeddings:
+        # Average over the sequence dimension (dim=1)
+        logits = model_output.mean(dim=1)
+    else:
+        # Select embeddings at the specified token position
+        logits = model_output[:, trainable_token_pos, :]
+
     loss = torch.nn.functional.cross_entropy(logits, target_batch)
     return loss
 
 
-def calc_loss_loader(data_loader, model, device, num_batches=None, trainable_token=-1):
+def calc_loss_loader(data_loader, model, device,
+                     num_batches=None, trainable_token_pos=-1,
+                     average_embeddings=False):
     total_loss = 0.
     if len(data_loader) == 0:
         return float("nan")
@@ -100,7 +111,10 @@ def calc_loss_loader(data_loader, model, device, num_batches=None, trainable_tok
         num_batches = min(num_batches, len(data_loader))
     for i, (input_batch, target_batch) in enumerate(data_loader):
         if i < num_batches:
-            loss = calc_loss_batch(input_batch, target_batch, model, device, trainable_token=trainable_token)
+            loss = calc_loss_batch(
+                input_batch, target_batch, model, device,
+                trainable_token_pos=trainable_token_pos, average_embeddings=average_embeddings
+            )
             total_loss += loss.item()
         else:
             break
@@ -108,7 +122,9 @@ def calc_loss_loader(data_loader, model, device, num_batches=None, trainable_tok
 
 
 @torch.no_grad()  # Disable gradient tracking for efficiency
-def calc_accuracy_loader(data_loader, model, device, num_batches=None, trainable_token=-1):
+def calc_accuracy_loader(data_loader, model, device,
+                         num_batches=None, trainable_token_pos=-1,
+                         average_embeddings=False):
     model.eval()
     correct_predictions, num_examples = 0, 0
 
@@ -119,7 +135,15 @@ def calc_accuracy_loader(data_loader, model, device, num_batches=None, trainable
     for i, (input_batch, target_batch) in enumerate(data_loader):
         if i < num_batches:
             input_batch, target_batch = input_batch.to(device), target_batch.to(device)
-            logits = model(input_batch)[:, trainable_token, :]  # Logits of last output token
+
+            model_output = model(input_batch)
+            if average_embeddings:
+                # Average over the sequence dimension (dim=1)
+                logits = model_output.mean(dim=1)
+            else:
+                # Select embeddings at the specified token position
+                logits = model_output[:, trainable_token_pos, :]
+
             predicted_labels = torch.argmax(logits, dim=-1)
 
             num_examples += predicted_labels.shape[0]
@@ -129,17 +153,25 @@ def calc_accuracy_loader(data_loader, model, device, num_batches=None, trainable
     return correct_predictions / num_examples
 
 
-def evaluate_model(model, train_loader, val_loader, device, eval_iter, trainable_token=-1):
+def evaluate_model(model, train_loader, val_loader, device, eval_iter,
+                   trainable_token_pos=-1, average_embeddings=False):
     model.eval()
     with torch.no_grad():
-        train_loss = calc_loss_loader(train_loader, model, device, num_batches=eval_iter, trainable_token=trainable_token)
-        val_loss = calc_loss_loader(val_loader, model, device, num_batches=eval_iter, trainable_token=trainable_token)
+        train_loss = calc_loss_loader(
+            train_loader, model, device, num_batches=eval_iter,
+            trainable_token_pos=trainable_token_pos, average_embeddings=average_embeddings
+        )
+        val_loss = calc_loss_loader(
+            val_loader, model, device, num_batches=eval_iter,
+            trainable_token_pos=trainable_token_pos, average_embeddings=average_embeddings
+        )
     model.train()
     return train_loss, val_loss
 
 
 def train_classifier_simple(model, train_loader, val_loader, optimizer, device, num_epochs,
-                            eval_freq, eval_iter, max_steps=None, trainable_token=-1):
+                            eval_freq, eval_iter, max_steps=None, trainable_token_pos=-1,
+                            average_embeddings=False):
     # Initialize lists to track losses and tokens seen
     train_losses, val_losses, train_accs, val_accs = [], [], [], []
     examples_seen, global_step = 0, -1
@@ -150,7 +182,8 @@ def train_classifier_simple(model, train_loader, val_loader, optimizer, device,
 
         for input_batch, target_batch in train_loader:
             optimizer.zero_grad()  # Reset loss gradients from previous batch iteration
-            loss = calc_loss_batch(input_batch, target_batch, model, device, trainable_token=trainable_token)
+            loss = calc_loss_batch(input_batch, target_batch, model, device,
+                                   trainable_token_pos=trainable_token_pos, average_embeddings=average_embeddings)
             loss.backward()  # Calculate loss gradients
             optimizer.step()  # Update model weights using loss gradients
             examples_seen += input_batch.shape[0]  # New: track examples instead of tokens
@@ -159,7 +192,9 @@ def train_classifier_simple(model, train_loader, val_loader, optimizer, device,
             # Optional evaluation step
             if global_step % eval_freq == 0:
                 train_loss, val_loss = evaluate_model(
-                    model, train_loader, val_loader, device, eval_iter, trainable_token=trainable_token)
+                    model, train_loader, val_loader, device, eval_iter,
+                    trainable_token_pos=trainable_token_pos, average_embeddings=average_embeddings
+                )
                 train_losses.append(train_loss)
                 val_losses.append(val_loss)
                 print(f"Ep {epoch+1} (Step {global_step:06d}): "
@@ -169,8 +204,14 @@ def train_classifier_simple(model, train_loader, val_loader, optimizer, device,
                 break
 
         # New: Calculate accuracy after each epoch
-        train_accuracy = calc_accuracy_loader(train_loader, model, device, num_batches=eval_iter, trainable_token=trainable_token)
-        val_accuracy = calc_accuracy_loader(val_loader, model, device, num_batches=eval_iter, trainable_token=trainable_token)
+        train_accuracy = calc_accuracy_loader(
+            train_loader, model, device, num_batches=eval_iter,
+            trainable_token_pos=trainable_token_pos, average_embeddings=average_embeddings
+        )
+        val_accuracy = calc_accuracy_loader(
+            val_loader, model, device, num_batches=eval_iter,
+            trainable_token_pos=trainable_token_pos, average_embeddings=average_embeddings
+        )
         print(f"Training accuracy: {train_accuracy*100:.2f}% | ", end="")
        print(f"Validation accuracy: {val_accuracy*100:.2f}%")
         train_accs.append(train_accuracy)
@@ -211,13 +252,22 @@ if __name__ == "__main__":
         )
     )
     parser.add_argument(
-        "--trainable_token",
+        "--trainable_token_pos",
         type=str,
         default="last",
         help=(
             "Which token to train. Options: 'first', 'last'."
         )
     )
+    parser.add_argument(
+        "--average_embeddings",
+        action='store_true',
+        default=False,
+        help=(
+            "Average the output embeddings from all tokens instead of using"
+            " only the embedding at the token position specified by `--trainable_token_pos`."
+        )
+    )
     parser.add_argument(
         "--context_length",
         type=str,
@@ -245,12 +295,12 @@ if __name__ == "__main__":
     )
     args = parser.parse_args()
 
-    if args.trainable_token == "first":
-        args.trainable_token = 0
-    elif args.trainable_token == "last":
-        args.trainable_token = -1
+    if args.trainable_token_pos == "first":
+        args.trainable_token_pos = 0
+    elif args.trainable_token_pos == "last":
+        args.trainable_token_pos = -1
     else:
-        raise ValueError("Invalid --trainable_token argument")
+        raise ValueError("Invalid --trainable_token_pos argument")
 
     ###############################
     # Load model
@@ -358,7 +408,8 @@ if __name__ == "__main__":
     train_losses, val_losses, train_accs, val_accs, examples_seen = train_classifier_simple(
         model, train_loader, val_loader, optimizer, device,
         num_epochs=args.num_epochs, eval_freq=50, eval_iter=20,
-        max_steps=None, trainable_token=args.trainable_token
+        max_steps=None, trainable_token_pos=args.trainable_token_pos,
+        average_embeddings=args.average_embeddings
     )
 
     end_time = time.time()
@@ -371,9 +422,18 @@ if __name__ == "__main__":
 
     print("\nEvaluating on the full datasets ...\n")
 
-    train_accuracy = calc_accuracy_loader(train_loader, model, device, trainable_token=args.trainable_token)
-    val_accuracy = calc_accuracy_loader(val_loader, model, device, trainable_token=args.trainable_token)
-    test_accuracy = calc_accuracy_loader(test_loader, model, device, trainable_token=args.trainable_token)
+    train_accuracy = calc_accuracy_loader(
+        train_loader, model, device,
+        trainable_token_pos=args.trainable_token_pos, average_embeddings=args.average_embeddings
+    )
+    val_accuracy = calc_accuracy_loader(
+        val_loader, model, device,
+        trainable_token_pos=args.trainable_token_pos, average_embeddings=args.average_embeddings
+    )
+    test_accuracy = calc_accuracy_loader(
+        test_loader, model, device,
+        trainable_token_pos=args.trainable_token_pos, average_embeddings=args.average_embeddings
+    )
 
     print(f"Training accuracy: {train_accuracy*100:.2f}%")
     print(f"Validation accuracy: {val_accuracy*100:.2f}%")
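
For reference, the substantive change in this patch is the pooling strategy applied to the model output before the loss: either keep the logits at a single token position (trainable_token_pos, "first" or "last") or average the outputs over all token positions (average_embeddings). The following minimal sketch contrasts the two options outside the script; the model call is stubbed with random values, and the shapes (batch_size, num_tokens, num_classes) are assumed toy sizes, not values taken from the diff.

import torch

# Stand-in for model(input_batch): one logit vector per token,
# shape (batch_size, num_tokens, num_classes)
batch_size, num_tokens, num_classes = 4, 8, 2
model_output = torch.randn(batch_size, num_tokens, num_classes)

# Option 1: select a single token position
# (trainable_token_pos=-1 selects the last token)
logits_last = model_output[:, -1, :]   # shape (4, 2)

# Option 2: mean-pool over the sequence dimension
# (average_embeddings=True)
logits_avg = model_output.mean(dim=1)  # shape (4, 2)

# Both variants yield one logit vector per example, so the same
# downstream cross-entropy and argmax calls work unchanged
targets = torch.tensor([0, 1, 0, 1])
loss = torch.nn.functional.cross_entropy(logits_avg, targets)
predicted = torch.argmax(logits_last, dim=-1)
print(logits_last.shape, logits_avg.shape, loss.item(), predicted)

Either result has shape (batch_size, num_classes), which is why calc_loss_batch and calc_accuracy_loader can branch on average_embeddings and still share the same cross_entropy and argmax code paths.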