فهرست منبع

Add Llama 3.2 RoPE to CI (#391)

* add Llama 3.2 RoPE to CI

* update
Sebastian Raschka 1 سال پیش
والد
کامیت
ec18b6a8a3
1فایلهای تغییر یافته به همراه116 افزوده شده و 35 حذف شده
  1. 116 35
      ch05/07_gpt_to_llama/tests/tests.py

+ 116 - 35
ch05/07_gpt_to_llama/tests/tests.py

@@ -18,39 +18,45 @@ from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding, apply
 
 @pytest.fixture(scope="module")
 def notebook():
-    def import_definitions_from_notebook(fullname, names):
-        # Get the directory of the current test file
-        current_dir = os.path.dirname(__file__)
-        path = os.path.join(current_dir, "..", fullname + ".ipynb")
-        path = os.path.normpath(path)
+    def import_definitions_from_notebook(notebooks):
+        imported_modules = {}
 
-        # Load the notebook
-        if not os.path.exists(path):
-            raise FileNotFoundError(f"Notebook file not found at: {path}")
+        for fullname, names in notebooks.items():
+            # Get the directory of the current test file
+            current_dir = os.path.dirname(__file__)
+            path = os.path.join(current_dir, "..", fullname + ".ipynb")
+            path = os.path.normpath(path)
 
-        with io.open(path, "r", encoding="utf-8") as f:
-            nb = nbformat.read(f, as_version=4)
+            # Load the notebook
+            if not os.path.exists(path):
+                raise FileNotFoundError(f"Notebook file not found at: {path}")
 
-        # Create a module to store the imported functions and classes
-        mod = types.ModuleType(fullname)
-        sys.modules[fullname] = mod
+            with io.open(path, "r", encoding="utf-8") as f:
+                nb = nbformat.read(f, as_version=4)
 
-        # Go through the notebook cells and only execute function or class definitions
-        for cell in nb.cells:
-            if cell.cell_type == "code":
-                cell_code = cell.source
-                for name in names:
-                    # Check for function or class definitions
-                    if f"def {name}" in cell_code or f"class {name}" in cell_code:
-                        exec(cell_code, mod.__dict__)
-        return mod
+            # Create a module to store the imported functions and classes
+            mod = types.ModuleType(fullname)
+            sys.modules[fullname] = mod
 
-    # Specify the notebook name and functions/classes to import
-    fullname = "converting-gpt-to-llama2"
-    names = ["precompute_rope_params", "compute_rope", "SiLU", "RMSNorm"]
+            # Go through the notebook cells and only execute function or class definitions
+            for cell in nb.cells:
+                if cell.cell_type == "code":
+                    cell_code = cell.source
+                    for name in names:
+                        # Check for function or class definitions
+                        if f"def {name}" in cell_code or f"class {name}" in cell_code:
+                            exec(cell_code, mod.__dict__)
 
-    # Import the required functions and classes from the notebook
-    return import_definitions_from_notebook(fullname, names)
+            imported_modules[fullname] = mod
+
+        return imported_modules
+
+    notebooks = {
+        "converting-gpt-to-llama2": ["SiLU", "RMSNorm", "precompute_rope_params", "compute_rope"],
+        "converting-llama2-to-llama3": ["precompute_rope_params"]
+    }
+
+    return import_definitions_from_notebook(notebooks)
 
 
 @pytest.fixture(autouse=True)
@@ -59,6 +65,9 @@ def set_seed():
 
 
 def test_rope_llama2(notebook):
+
+    this_nb = notebook["converting-gpt-to-llama2"]
+
     # Settings
     batch_size = 1
     context_len = 4096
@@ -66,15 +75,15 @@ def test_rope_llama2(notebook):
     head_dim = 16
 
     # Instantiate RoPE parameters
-    cos, sin = notebook.precompute_rope_params(head_dim=head_dim, context_length=context_len)
+    cos, sin = this_nb.precompute_rope_params(head_dim=head_dim, context_length=context_len)
 
     # Dummy query and key tensors
     queries = torch.randn(batch_size, num_heads, context_len, head_dim)
     keys = torch.randn(batch_size, num_heads, context_len, head_dim)
 
     # Apply rotary position embeddings
-    queries_rot = notebook.compute_rope(queries, cos, sin)
-    keys_rot = notebook.compute_rope(keys, cos, sin)
+    queries_rot = this_nb.compute_rope(queries, cos, sin)
+    keys_rot = this_nb.compute_rope(keys, cos, sin)
 
     rot_emb = LlamaRotaryEmbedding(
         dim=head_dim,
@@ -93,6 +102,10 @@ def test_rope_llama2(notebook):
 
 
 def test_rope_llama3(notebook):
+
+    nb1 = notebook["converting-gpt-to-llama2"]
+    nb2 = notebook["converting-llama2-to-llama3"]
+
     # Settings
     batch_size = 1
     context_len = 8192
@@ -101,19 +114,20 @@ def test_rope_llama3(notebook):
     theta_base = 50_000
 
     # Instantiate RoPE parameters
-    cos, sin = notebook.precompute_rope_params(
+    cos, sin = nb2.precompute_rope_params(
         head_dim=head_dim,
         context_length=context_len,
         theta_base=theta_base
     )
 
     # Dummy query and key tensors
+    torch.manual_seed(123)
     queries = torch.randn(batch_size, num_heads, context_len, head_dim)
     keys = torch.randn(batch_size, num_heads, context_len, head_dim)
 
     # Apply rotary position embeddings
-    queries_rot = notebook.compute_rope(queries, cos, sin)
-    keys_rot = notebook.compute_rope(keys, cos, sin)
+    queries_rot = nb1.compute_rope(queries, cos, sin)
+    keys_rot = nb1.compute_rope(keys, cos, sin)
 
     rot_emb = LlamaRotaryEmbedding(
         dim=head_dim,
@@ -131,16 +145,83 @@ def test_rope_llama3(notebook):
     torch.testing.assert_close(queries_rot, ref_queries_rot)
 
 
+def test_rope_llama3_12(notebook):
+
+    nb1 = notebook["converting-gpt-to-llama2"]
+    nb2 = notebook["converting-llama2-to-llama3"]
+
+    # Settings
+    batch_size = 1
+    context_len = 8192
+    num_heads = 4
+    head_dim = 16
+    rope_theta = 50_000
+
+    rope_config = {
+        "factor": 8.0,
+        "low_freq_factor": 1.0,
+        "high_freq_factor": 4.0,
+        "original_context_length": 8192,
+    }
+
+    # Instantiate RoPE parameters
+    cos, sin = nb2.precompute_rope_params(
+        head_dim=head_dim,
+        theta_base=rope_theta,
+        context_length=context_len,
+        freq_config=rope_config,
+    )
+
+    # Dummy query and key tensors
+    torch.manual_seed(123)
+    queries = torch.randn(batch_size, num_heads, context_len, head_dim)
+    keys = torch.randn(batch_size, num_heads, context_len, head_dim)
+
+    # Apply rotary position embeddings
+    queries_rot = nb1.compute_rope(queries, cos, sin)
+    keys_rot = nb1.compute_rope(keys, cos, sin)
+
+    hf_rope_params = {
+        "factor": 8.0,
+        "low_freq_factor": 1.0,
+        "high_freq_factor": 4.0,
+        "original_max_position_embeddings": 8192,
+        "rope_type": "llama3"
+    }
+
+    class RoPEConfig:
+        rope_type = "llama3"
+        rope_scaling = hf_rope_params
+        factor = 1.0
+        dim: int = head_dim
+        rope_theta = 50_000
+        max_position_embeddings: int = 8192
+        hidden_size = head_dim * num_heads
+        num_attention_heads = num_heads
+
+    config = RoPEConfig()
+
+    rot_emb = LlamaRotaryEmbedding(config=config)
+    position_ids = torch.arange(context_len, dtype=torch.long).unsqueeze(0)
+    ref_cos, ref_sin = rot_emb(queries, position_ids)
+    ref_queries_rot, ref_keys_rot = apply_rotary_pos_emb(queries, keys, ref_cos, ref_sin)
+
+    torch.testing.assert_close(sin, ref_sin.squeeze(0))
+    torch.testing.assert_close(cos, ref_cos.squeeze(0))
+    torch.testing.assert_close(keys_rot, ref_keys_rot)
+    torch.testing.assert_close(queries_rot, ref_queries_rot)
+
+
 def test_silu(notebook):
     example_batch = torch.randn(2, 3, 4)
-    silu = notebook.SiLU()
+    silu = notebook["converting-gpt-to-llama2"].SiLU()
     assert torch.allclose(silu(example_batch), torch.nn.functional.silu(example_batch))
 
 
 @pytest.mark.skipif(torch.__version__ < "2.4", reason="Requires PyTorch 2.4 or newer")
 def test_rmsnorm(notebook):
     example_batch = torch.randn(2, 3, 4)
-    rms_norm = notebook.RMSNorm(emb_dim=example_batch.shape[-1], eps=1e-5)
+    rms_norm = notebook["converting-gpt-to-llama2"].RMSNorm(emb_dim=example_batch.shape[-1], eps=1e-5)
     rmsnorm_pytorch = torch.nn.RMSNorm(example_batch.shape[-1], eps=1e-5)
 
     assert torch.allclose(rms_norm(example_batch), rmsnorm_pytorch(example_batch))