@@ -8,6 +8,7 @@
 from functools import partial
 from importlib.metadata import version
 import json
+import math
 import os
 import re
 import time
@@ -107,6 +108,41 @@ class InstructionDatasetPhi(Dataset):
         return len(self.data)
 
 
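+# LinearWithLoRA wraps an existing nn.Linear and adds a trainable low-rank
+# update (LoRALayer below) to its output; the original weights are frozen
+# separately before injection (see main()).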
+class LinearWithLoRA(torch.nn.Module):
+    def __init__(self, linear, rank, alpha):
+        super().__init__()
+        self.linear = linear
+        self.lora = LoRALayer(
+            linear.in_features, linear.out_features, rank, alpha
+        )
+
+    def forward(self, x):
+        return self.linear(x) + self.lora(x)
+
+
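+# LoRALayer computes alpha * (x @ A @ B), a rank-`rank` update to the frozen
+# layer's output. A gets a standard weight init while B starts at zero, so
+# training begins exactly at the pretrained model. (The original LoRA paper
+# scales the update by alpha/rank; this variant folds the scaling into alpha.)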
+class LoRALayer(torch.nn.Module):
+    def __init__(self, in_dim, out_dim, rank, alpha):
+        super().__init__()
+        self.A = torch.nn.Parameter(torch.empty(in_dim, rank))
+        torch.nn.init.kaiming_uniform_(self.A, a=math.sqrt(5))  # similar to standard weight initialization
+        self.B = torch.nn.Parameter(torch.zeros(rank, out_dim))
+        self.alpha = alpha
+
+    def forward(self, x):
+        x = self.alpha * (x @ self.A @ self.B)
+        return x
+
+
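+# Note: this replaces every nn.Linear in the model, including the output
+# head; named_children() only yields direct submodules, hence the recursion.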
+def replace_linear_with_lora(model, rank, alpha):
+    for name, module in model.named_children():
+        if isinstance(module, torch.nn.Linear):
+            # Replace the Linear layer with LinearWithLoRA
+            setattr(model, name, LinearWithLoRA(module, rank, alpha))
+        else:
+            # Recursively apply the same function to child modules
+            replace_linear_with_lora(module, rank, alpha)
+
+
 def custom_collate_fn(
     batch,
     pad_token_id=50256,
@@ -256,7 +292,7 @@ def plot_losses(epochs_seen, tokens_seen, train_losses, val_losses, plot_name):
     # plt.show()
 
 
-def main(mask_instructions=False, alpaca52k=False, phi3_prompt=False):
+def main(mask_instructions=False, alpaca52k=False, phi3_prompt=False, lora=False):
     #######################################
     # Print package versions
     #######################################
@@ -379,6 +415,21 @@ def main(mask_instructions=False, alpaca52k=False, phi3_prompt=False):
     print("Loaded model:", CHOOSE_MODEL)
     print(50*"-")
 
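+    # Freezing happens before replace_linear_with_lora() so that the LoRA
+    # parameters, created afterwards, are the only ones left trainable.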
+    if lora:
+        total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+        print(f"Total trainable parameters before: {total_params:,}")
+
+        for param in model.parameters():
+            param.requires_grad = False
+
+        total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+        print(f"Total trainable parameters after: {total_params:,}")
+        replace_linear_with_lora(model, rank=16, alpha=16)
+
+        total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+        print(f"Total trainable LoRA parameters: {total_params:,}")
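+        # The new LoRA parameters are created on the default device, so make
+        # sure the whole model ends up on the target device.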
+        model.to(device)
+
     #######################################
     # Finetuning the model
     #######################################
@@ -418,7 +469,9 @@ def main(mask_instructions=False, alpaca52k=False, phi3_prompt=False):
         plot_name = plot_name.replace(".pdf", "-alpaca52k.pdf")
     if phi3_prompt:
         plot_name = plot_name.replace(".pdf", "-phi3-prompt.pdf")
-    if not any([mask_instructions, alpaca52k, phi3_prompt]):
+    if lora:
+        plot_name = plot_name.replace(".pdf", "-lora.pdf")
+    if not any([mask_instructions, alpaca52k, phi3_prompt, lora]):
         plot_name = plot_name.replace(".pdf", "-baseline.pdf")
 
     plot_losses(epochs_tensor, tokens_seen, train_losses, val_losses, plot_name)
@@ -460,7 +513,10 @@ def main(mask_instructions=False, alpaca52k=False, phi3_prompt=False):
     if phi3_prompt:
         test_data_path = test_data_path.replace(".json", "-phi3-prompt.json")
         file_name = file_name.replace(".pth", "-phi3-prompt.pth")
-    if not any([mask_instructions, alpaca52k, phi3_prompt]):
+    if lora:
+        test_data_path = test_data_path.replace(".json", "-lora.json")
+        file_name = file_name.replace(".pth", "-lora.pth")
+    if not any([mask_instructions, alpaca52k, phi3_prompt, lora]):
         test_data_path = test_data_path.replace(".json", "-baseline.json")
         file_name = file_name.replace(".pth", "-baseline.pth")
 
@@ -479,7 +535,7 @@ if __name__ == "__main__":
     parser = argparse.ArgumentParser(
         description="Instruction finetune a GPT model"
     )
-    options = {"baseline", "mask_instructions", "alpaca_52k", "phi3_prompt"}
+    options = {"baseline", "mask_instructions", "alpaca_52k", "phi3_prompt", "lora"}
     parser.add_argument(
         "--exercise_solution",
         type=str,
@@ -498,5 +554,7 @@ if __name__ == "__main__":
         main(alpaca52k=True)
     elif args.exercise_solution == "phi3_prompt":
         main(phi3_prompt=True)
+    elif args.exercise_solution == "lora":
+        main(lora=True)
     else:
         raise ValueError(f"{args.exercise_solution} is not a valid --args.exercise_solution option. Options: {options}")