appendix_e.py 1.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142
  1. # Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
  2. # Source for "Build a Large Language Model From Scratch"
  3. # - https://www.manning.com/books/build-a-large-language-model-from-scratch
  4. # Code: https://github.com/rasbt/LLMs-from-scratch
  5. import torch
  6. import math
  7. class LoRALayer(torch.nn.Module):
  8. def __init__(self, in_dim, out_dim, rank, alpha):
  9. super().__init__()
  10. self.A = torch.nn.Parameter(torch.empty(in_dim, rank))
  11. torch.nn.init.kaiming_uniform_(self.A, a=math.sqrt(5)) # similar to standard weight initialization
  12. self.B = torch.nn.Parameter(torch.zeros(rank, out_dim))
  13. self.alpha = alpha
  14. def forward(self, x):
  15. x = self.alpha * (x @ self.A @ self.B)
  16. return x
  17. class LinearWithLoRA(torch.nn.Module):
  18. def __init__(self, linear, rank, alpha):
  19. super().__init__()
  20. self.linear = linear
  21. self.lora = LoRALayer(
  22. linear.in_features, linear.out_features, rank, alpha
  23. )
  24. def forward(self, x):
  25. return self.linear(x) + self.lora(x)
  26. def replace_linear_with_lora(model, rank, alpha):
  27. for name, module in model.named_children():
  28. if isinstance(module, torch.nn.Linear):
  29. # Replace the Linear layer with LinearWithLoRA
  30. setattr(model, name, LinearWithLoRA(module, rank, alpha))
  31. else:
  32. # Recursively apply the same function to child modules
  33. replace_linear_with_lora(module, rank, alpha)