@@ -10,11 +10,82 @@ import os
import sys
import types
import nbformat
+from typing import Optional, Tuple
import torch
import pytest
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding, apply_rotary_pos_emb


+# LitGPT code from https://github.com/Lightning-AI/litgpt/blob/main/litgpt/model.py
+# LitGPT is licensed under Apache v2: https://github.com/Lightning-AI/litgpt/blob/main/LICENSE
+def litgpt_build_rope_cache(
+    seq_len: int,
+    n_elem: int,
+    device: Optional[torch.device] = None,
+    base: int = 10000,
+    condense_ratio: int = 1,
+    extra_config: Optional[dict] = None,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Enhanced Transformer with Rotary Position Embedding.
+
+    Args:
+        seq_len (int): Sequence length.
+        n_elem (int): Number of elements (head dimension).
+        device (torch.device, optional): Device for tensor allocations.
+        base (int, optional): Base for computing inverse frequencies.
+        condense_ratio (int, optional): Ratio to condense the position indices.
+        extra_config (dict, optional): Configuration parameters for frequency adjustments (used by Llama 3.1 and 3.2)
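+            Expected keys: "original_max_seq_len", "factor", "low_freq_factor", "high_freq_factor".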
+
+    Returns:
+        Tuple[torch.Tensor, torch.Tensor]: Cosine and sine caches for RoPE.
+    """
+
+    # Compute the inverse frequencies theta
+    theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, device=device).float() / n_elem))
+
+    if extra_config is not None:
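+        # Llama 3.1/3.2 frequency scaling: low-frequency (long-wavelength) components are
+        # divided by `factor`, high-frequency components are kept unchanged, and the band
+        # in between is smoothly interpolated via `smooth_factor` below.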
+        orig_context_len = extra_config["original_max_seq_len"]
+        factor = extra_config["factor"]
+        low_freq_factor = extra_config["low_freq_factor"]
+        high_freq_factor = extra_config["high_freq_factor"]
+
+        wavelen = 2 * torch.pi / theta
+        ratio = orig_context_len / wavelen
+        smooth_factor = (ratio - low_freq_factor) / (high_freq_factor - low_freq_factor)
+        smooth_factor = torch.clamp(smooth_factor, min=0.0, max=1.0)
+
+        # Compute adjusted_theta without masked indexing
+        adjusted_theta = (1 - smooth_factor) * (theta / factor) + smooth_factor * theta
+        theta = adjusted_theta
+
+    # Create position indices `[0, 1, ..., seq_len - 1]`
+    seq_idx = torch.arange(seq_len, device=device) / condense_ratio
+
+    # Calculate the product of position index and $\theta_i$
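+    # (angles are repeated along the last dim so the same frequencies cover both halves of the head dim)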
+    idx_theta = torch.outer(seq_idx, theta).repeat(1, 2)
+
+    return torch.cos(idx_theta), torch.sin(idx_theta)
+
+
+# LitGPT code from https://github.com/Lightning-AI/litgpt/blob/main/litgpt/model.py
+# LitGPT is licensed under Apache v2: https://github.com/Lightning-AI/litgpt/blob/main/LICENSE
+def litgpt_apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
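+    # Rotate-half RoPE: negate the second half of the head dimension and swap it with
+    # the first half (same non-interleaved convention as HF's apply_rotary_pos_emb).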
+    head_size = x.size(-1)
+    x1 = x[..., : head_size // 2]  # (B, nh, T, hs/2)
+    x2 = x[..., head_size // 2:]  # (B, nh, T, hs/2)
+    rotated = torch.cat((-x2, x1), dim=-1)  # (B, nh, T, hs)
+    if cos.dim() > 1:
+        # batch dimensions must align
+        # sin/cos are (B, T, hs), so we unsqueeze at -3 to add the nh dimension
+        # we index from the back because the rest of apply_rope does
+        cos = cos.unsqueeze(-3)
+        sin = sin.unsqueeze(-3)
+
+    roped = (x * cos) + (rotated * sin)
+    return roped.to(dtype=x.dtype)
+
+
@pytest.fixture(scope="module")
def notebook():
    def import_definitions_from_notebook(notebooks):
@@ -84,21 +155,30 @@ def test_rope_llama2(notebook):
    queries_rot = this_nb.compute_rope(queries, cos, sin)
    keys_rot = this_nb.compute_rope(keys, cos, sin)

+    # Generate reference RoPE via HF
    rot_emb = LlamaRotaryEmbedding(
        dim=head_dim,
        max_position_embeddings=context_len,
        base=10_000
    )
-
    position_ids = torch.arange(context_len, dtype=torch.long).unsqueeze(0)
    ref_cos, ref_sin = rot_emb(queries, position_ids)
    ref_queries_rot, ref_keys_rot = apply_rotary_pos_emb(queries, keys, ref_cos, ref_sin)
-
    torch.testing.assert_close(sin, ref_sin.squeeze(0))
    torch.testing.assert_close(cos, ref_cos.squeeze(0))
    torch.testing.assert_close(keys_rot, ref_keys_rot)
    torch.testing.assert_close(queries_rot, ref_queries_rot)

+    # Generate reference RoPE via LitGPT
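+    # (LitGPT returns 2-D (seq_len, head_dim) cos/sin caches, so no squeeze is needed below)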
+    litgpt_cos, litgpt_sin = litgpt_build_rope_cache(context_len, n_elem=head_dim, base=10_000)
+    litgpt_queries_rot = litgpt_apply_rope(queries, litgpt_cos, litgpt_sin)
+    litgpt_keys_rot = litgpt_apply_rope(keys, litgpt_cos, litgpt_sin)
+
+    torch.testing.assert_close(sin, litgpt_sin)
+    torch.testing.assert_close(cos, litgpt_cos)
+    torch.testing.assert_close(keys_rot, litgpt_keys_rot)
+    torch.testing.assert_close(queries_rot, litgpt_queries_rot)
+


def test_rope_llama3(notebook):

@@ -128,6 +208,7 @@ def test_rope_llama3(notebook):
    queries_rot = nb1.compute_rope(queries, cos, sin)
    keys_rot = nb1.compute_rope(keys, cos, sin)

+    # Generate reference RoPE via HF
    rot_emb = LlamaRotaryEmbedding(
        dim=head_dim,
        max_position_embeddings=context_len,
@@ -143,6 +224,16 @@ def test_rope_llama3(notebook):
    torch.testing.assert_close(keys_rot, ref_keys_rot)
    torch.testing.assert_close(queries_rot, ref_queries_rot)

+    # Generate reference RoPE via LitGPT
+    litgpt_cos, litgpt_sin = litgpt_build_rope_cache(context_len, n_elem=head_dim, base=theta_base)
+    litgpt_queries_rot = litgpt_apply_rope(queries, litgpt_cos, litgpt_sin)
+    litgpt_keys_rot = litgpt_apply_rope(keys, litgpt_cos, litgpt_sin)
+
+    torch.testing.assert_close(sin, litgpt_sin)
+    torch.testing.assert_close(cos, litgpt_cos)
+    torch.testing.assert_close(keys_rot, litgpt_keys_rot)
+    torch.testing.assert_close(queries_rot, litgpt_queries_rot)
+


def test_rope_llama3_12(notebook):

@@ -180,6 +271,7 @@ def test_rope_llama3_12(notebook):
    queries_rot = nb1.compute_rope(queries, cos, sin)
    keys_rot = nb1.compute_rope(keys, cos, sin)

+    # Generate reference RoPE via HF
    hf_rope_params = {
        "factor": 8.0,
        "low_freq_factor": 1.0,
@@ -210,6 +302,28 @@ def test_rope_llama3_12(notebook):
    torch.testing.assert_close(keys_rot, ref_keys_rot)
    torch.testing.assert_close(queries_rot, ref_queries_rot)

+    # Generate reference RoPE via LitGPT
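+    # (same scaling values as hf_rope_params above, expressed with LitGPT's key names)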
+    litgpt_rope_config = {
+        "factor": 8.0,
+        "low_freq_factor": 1.0,
+        "high_freq_factor": 4.0,
+        "original_max_seq_len": 8192
+    }
+
+    litgpt_cos, litgpt_sin = litgpt_build_rope_cache(
+        context_len,
+        n_elem=head_dim,
+        base=rope_theta,
+        extra_config=litgpt_rope_config
+    )
+    litgpt_queries_rot = litgpt_apply_rope(queries, litgpt_cos, litgpt_sin)
+    litgpt_keys_rot = litgpt_apply_rope(keys, litgpt_cos, litgpt_sin)
+
+    torch.testing.assert_close(sin, litgpt_sin)
+    torch.testing.assert_close(cos, litgpt_cos)
+    torch.testing.assert_close(keys_rot, litgpt_keys_rot)
+    torch.testing.assert_close(queries_rot, litgpt_queries_rot)
+


def test_silu(notebook):
    example_batch = torch.randn(2, 3, 4)