@@ -58,10 +58,10 @@ def set_seed():
     torch.manual_seed(123)
 
 
-def test_rope(notebook):
+def test_rope_llama2(notebook):
     # Settings
     batch_size = 1
-    context_len = 5
+    context_len = 4096
     num_heads = 4
     head_dim = 16
 
@@ -76,19 +76,51 @@ def test_rope(notebook):
     queries_rot = notebook.compute_rope(queries, cos, sin)
     keys_rot = notebook.compute_rope(keys, cos, sin)
 
-    class RoPEConfig:
-        rope_type = "default"
-        rope_scaling = None
-        factor = 1.0
-        dim: int = head_dim
-        rope_theta = 10000
-        max_position_embeddings: int = 4096
-        hidden_size = head_dim * num_heads
-        num_attention_heads = num_heads
+    rot_emb = LlamaRotaryEmbedding(
+        dim=head_dim,
+        max_position_embeddings=context_len,
+        base=10_000
+    )
 
-    config = RoPEConfig()
+    position_ids = torch.arange(context_len, dtype=torch.long).unsqueeze(0)
+    ref_cos, ref_sin = rot_emb(queries, position_ids)
+    ref_queries_rot, ref_keys_rot = apply_rotary_pos_emb(queries, keys, ref_cos, ref_sin)
+
+    torch.testing.assert_close(sin, ref_sin.squeeze(0))
+    torch.testing.assert_close(cos, ref_cos.squeeze(0))
+    torch.testing.assert_close(keys_rot, ref_keys_rot)
+    torch.testing.assert_close(queries_rot, ref_queries_rot)
+
+
+def test_rope_llama3(notebook):
+    # Settings
+    batch_size = 1
+    context_len = 8192
+    num_heads = 4
+    head_dim = 16
+    theta_base = 50_000
+
+    # Instantiate RoPE parameters
+    cos, sin = notebook.precompute_rope_params(
+        head_dim=head_dim,
+        context_length=context_len,
+        theta_base=theta_base
+    )
+
+    # Dummy query and key tensors
+    queries = torch.randn(batch_size, num_heads, context_len, head_dim)
+    keys = torch.randn(batch_size, num_heads, context_len, head_dim)
+
+    # Apply rotary position embeddings
+    queries_rot = notebook.compute_rope(queries, cos, sin)
+    keys_rot = notebook.compute_rope(keys, cos, sin)
+
+    rot_emb = LlamaRotaryEmbedding(
+        dim=head_dim,
+        max_position_embeddings=context_len,
+        base=theta_base
+    )
 
-    rot_emb = LlamaRotaryEmbedding(config=config)
     position_ids = torch.arange(context_len, dtype=torch.long).unsqueeze(0)
     ref_cos, ref_sin = rot_emb(queries, position_ids)
     ref_queries_rot, ref_keys_rot = apply_rotary_pos_emb(queries, keys, ref_cos, ref_sin)
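
Both RoPE tests compare the notebook's precompute_rope_params and compute_rope helpers against the Hugging Face transformers reference (LlamaRotaryEmbedding, apply_rotary_pos_emb). For readers without the notebook at hand, a minimal sketch of what such helpers typically look like is given below; the function bodies are an illustrative assumption, not the notebook's verbatim code.

import torch

def precompute_rope_params(head_dim, theta_base=10_000, context_length=4096):
    # Assumed sketch, not the notebook's actual implementation.
    # One rotation frequency per pair of head dimensions.
    assert head_dim % 2 == 0, "head_dim must be even"
    inv_freq = 1.0 / (theta_base ** (torch.arange(0, head_dim, 2).float() / head_dim))

    # Angle for every (position, frequency) pair: (context_length, head_dim // 2)
    positions = torch.arange(context_length)
    angles = positions[:, None].float() * inv_freq[None, :]

    # Duplicate to cover the full head dimension (matches the HF "rotate_half" layout)
    angles = torch.cat([angles, angles], dim=1)  # (context_length, head_dim)
    return torch.cos(angles), torch.sin(angles)


def compute_rope(x, cos, sin):
    # x has shape (batch_size, num_heads, seq_len, head_dim)
    _, _, seq_len, head_dim = x.shape
    x1 = x[..., : head_dim // 2]
    x2 = x[..., head_dim // 2:]

    # Broadcast the precomputed tables over the batch and head dimensions
    cos = cos[:seq_len, :].unsqueeze(0).unsqueeze(0)
    sin = sin[:seq_len, :].unsqueeze(0).unsqueeze(0)

    # Rotate pairs: first half becomes x1*cos - x2*sin, second half x2*cos + x1*sin
    rotated = torch.cat((-x2, x1), dim=-1)
    return (x * cos) + (rotated * sin)
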
@@ -108,7 +140,7 @@ def test_silu(notebook):
 @pytest.mark.skipif(torch.__version__ < "2.4", reason="Requires PyTorch 2.4 or newer")
 def test_rmsnorm(notebook):
     example_batch = torch.randn(2, 3, 4)
-    rms_norm = notebook.RMSNorm(emb_dim=example_batch.shape[-1])
-    rmsnorm_pytorch = torch.nn.RMSNorm(example_batch.shape[-1], eps=1e-6)
+    rms_norm = notebook.RMSNorm(emb_dim=example_batch.shape[-1], eps=1e-5)
+    rmsnorm_pytorch = torch.nn.RMSNorm(example_batch.shape[-1], eps=1e-5)
 
     assert torch.allclose(rms_norm(example_batch), rmsnorm_pytorch(example_batch))
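
The eps change above passes the same epsilon (1e-5) to both implementations so their outputs stay numerically comparable. As a rough reference, an RMSNorm module matching torch.nn.RMSNorm's semantics can be sketched as follows (an illustrative assumption, not the notebook's actual class):

import torch

class RMSNorm(torch.nn.Module):
    # Assumed sketch of an RMSNorm compatible with torch.nn.RMSNorm(dim, eps=1e-5)
    def __init__(self, emb_dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.weight = torch.nn.Parameter(torch.ones(emb_dim))

    def forward(self, x):
        # Normalize by the root mean square over the last dimension, then scale
        mean_sq = x.pow(2).mean(dim=-1, keepdim=True)
        return x * torch.rsqrt(mean_sq + self.eps) * self.weight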