test_ch03.py 603 B

12345678910111213141516171819202122
  1. # Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
  2. # Source for "Build a Large Language Model From Scratch"
  3. # - https://www.manning.com/books/build-a-large-language-model-from-scratch
  4. # Code: https://github.com/rasbt/LLMs-from-scratch
  5. from llms_from_scratch.ch03 import MultiHeadAttention
  6. import torch
  7. def test_mha():
  8. context_length = 100
  9. d_in = 256
  10. d_out = 16
  11. mha = MultiHeadAttention(d_in, d_out, context_length, 0.0, num_heads=2)
  12. batch = torch.rand(8, 6, d_in)
  13. context_vecs = mha(batch)
  14. context_vecs.shape == torch.Size([8, 6, d_out])