Pārlūkot izejas kodu

Improve rope settings for llama3 (#380)

Sebastian Raschka 1 gadu atpakaļ
vecāks
revīzija
b993c2b25b

+ 3 - 2
.gitignore

@@ -35,8 +35,9 @@ ch05/01_main-chapter-code/model.pth
 ch05/01_main-chapter-code/model_and_optimizer.pth
 ch05/03_bonus_pretraining_on_gutenberg/model_checkpoints
 ch05/06_user_interface/gpt2
-ch05/07_gpt_to_llama/models--meta-llama--Llama-2-7b
-ch05/07_gpt_to_llama/models--meta-llama--Llama-2-7b-chat
+ch05/07_gpt_to_llama/Llama-2-7b
+ch05/07_gpt_to_llama/Llama-2-7b-chat
+ch05/07_gpt_to_llama/.cache
 
 ch06/01_main-chapter-code/gpt2
 ch06/02_bonus_additional-experiments/gpt2

+ 7 - 7
ch05/07_gpt_to_llama/converting-gpt-to-llama2.ipynb

@@ -180,7 +180,7 @@
     "\n",
     "\n",
     "class RMSNorm(nn.Module):\n",
-    "    def __init__(self, emb_dim, eps=1e-6):\n",
+    "    def __init__(self, emb_dim, eps=1e-5):\n",
     "        super().__init__()\n",
     "        self.eps = eps\n",
     "        self.emb_dim = emb_dim\n",
@@ -216,7 +216,7 @@
     "example_batch = torch.randn(2, 3, 4)\n",
     "\n",
     "rms_norm = RMSNorm(emb_dim=example_batch.shape[-1])\n",
-    "rmsnorm_pytorch = torch.nn.RMSNorm(example_batch.shape[-1], eps=1e-6)\n",
+    "rmsnorm_pytorch = torch.nn.RMSNorm(example_batch.shape[-1], eps=1e-5)\n",
     "\n",
     "assert torch.allclose(rms_norm(example_batch), rmsnorm_pytorch(example_batch))"
    ]
@@ -417,11 +417,11 @@
    },
    "outputs": [],
    "source": [
-    "def precompute_rope_params(head_dim, context_length=4096):\n",
+    "def precompute_rope_params(head_dim, theta_base=10_000, context_length=4096):\n",
     "    assert head_dim % 2 == 0, \"Embedding dimension must be even\"\n",
     "\n",
     "    # Compute the inverse frequencies\n",
-    "    inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim // 2) / (head_dim // 2)))\n",
+    "    inv_freq = 1.0 / (theta_base ** (torch.arange(0, head_dim // 2) / (head_dim // 2)))\n",
     "\n",
     "    # Generate position indices\n",
     "    positions = torch.arange(context_length)\n",
@@ -1151,7 +1151,7 @@
     "tokenizer_file = hf_hub_download(\n",
     "    repo_id=\"meta-llama/Llama-2-7b\",\n",
     "    filename=\"tokenizer.model\",\n",
-    "    cache_dir=\".\")"
+    "    local_dir=\"Llama-2-7B\")"
    ]
   },
   {
@@ -1285,7 +1285,7 @@
     "weights_file = hf_hub_download(\n",
     "   repo_id=\"meta-llama/Llama-2-7b\",\n",
     "   filename=\"consolidated.00.pth\",\n",
-    "   cache_dir=\".\"\n",
+    "   local_dir=\"Llama-2-7b\"\n",
     ")"
    ]
   },
@@ -1520,7 +1520,7 @@
     "weights_file = hf_hub_download(\n",
     "   repo_id=\"meta-llama/Llama-2-7b-chat\",\n",
     "   filename=\"consolidated.00.pth\",\n",
-    "   cache_dir=\".\"\n",
+    "   lcoal_dir=\"Llama-2-7b-chat\n",
     ")\n",
     "\n",
     "model = Llama2Model(LLAMA2_CONFIG_7B)\n",

+ 47 - 15
ch05/07_gpt_to_llama/tests/tests.py

@@ -58,10 +58,10 @@ def set_seed():
     torch.manual_seed(123)
 
 
-def test_rope(notebook):
+def test_rope_llama2(notebook):
     # Settings
     batch_size = 1
-    context_len = 5
+    context_len = 4096
     num_heads = 4
     head_dim = 16
 
@@ -76,19 +76,51 @@ def test_rope(notebook):
     queries_rot = notebook.compute_rope(queries, cos, sin)
     keys_rot = notebook.compute_rope(keys, cos, sin)
 
-    class RoPEConfig:
-        rope_type = "default"
-        rope_scaling = None
-        factor = 1.0
-        dim: int = head_dim
-        rope_theta = 10000
-        max_position_embeddings: int = 4096
-        hidden_size = head_dim * num_heads
-        num_attention_heads = num_heads
+    rot_emb = LlamaRotaryEmbedding(
+        dim=head_dim,
+        max_position_embeddings=context_len,
+        base=10_000
+    )
 
-    config = RoPEConfig()
+    position_ids = torch.arange(context_len, dtype=torch.long).unsqueeze(0)
+    ref_cos, ref_sin = rot_emb(queries, position_ids)
+    ref_queries_rot, ref_keys_rot = apply_rotary_pos_emb(queries, keys, ref_cos, ref_sin)
+
+    torch.testing.assert_close(sin, ref_sin.squeeze(0))
+    torch.testing.assert_close(cos, ref_cos.squeeze(0))
+    torch.testing.assert_close(keys_rot, ref_keys_rot)
+    torch.testing.assert_close(queries_rot, ref_queries_rot)
+
+
+def test_rope_llama3(notebook):
+    # Settings
+    batch_size = 1
+    context_len = 8192
+    num_heads = 4
+    head_dim = 16
+    theta_base = 50_000
+
+    # Instantiate RoPE parameters
+    cos, sin = notebook.precompute_rope_params(
+        head_dim=head_dim,
+        context_length=context_len,
+        theta_base=theta_base
+    )
+
+    # Dummy query and key tensors
+    queries = torch.randn(batch_size, num_heads, context_len, head_dim)
+    keys = torch.randn(batch_size, num_heads, context_len, head_dim)
+
+    # Apply rotary position embeddings
+    queries_rot = notebook.compute_rope(queries, cos, sin)
+    keys_rot = notebook.compute_rope(keys, cos, sin)
+
+    rot_emb = LlamaRotaryEmbedding(
+        dim=head_dim,
+        max_position_embeddings=context_len,
+        base=theta_base
+    )
 
-    rot_emb = LlamaRotaryEmbedding(config=config)
     position_ids = torch.arange(context_len, dtype=torch.long).unsqueeze(0)
     ref_cos, ref_sin = rot_emb(queries, position_ids)
     ref_queries_rot, ref_keys_rot = apply_rotary_pos_emb(queries, keys, ref_cos, ref_sin)
@@ -108,7 +140,7 @@ def test_silu(notebook):
 @pytest.mark.skipif(torch.__version__ < "2.4", reason="Requires PyTorch 2.4 or newer")
 def test_rmsnorm(notebook):
     example_batch = torch.randn(2, 3, 4)
-    rms_norm = notebook.RMSNorm(emb_dim=example_batch.shape[-1])
-    rmsnorm_pytorch = torch.nn.RMSNorm(example_batch.shape[-1], eps=1e-6)
+    rms_norm = notebook.RMSNorm(emb_dim=example_batch.shape[-1], eps=1e-5)
+    rmsnorm_pytorch = torch.nn.RMSNorm(example_batch.shape[-1], eps=1e-5)
 
     assert torch.allclose(rms_norm(example_batch), rmsnorm_pytorch(example_batch))