
removed old args in GQA class (#674)

casinca committed 5 months ago
commit 58b8672452
1 changed file with 1 addition and 8 deletions
  1. ch05/07_gpt_to_llama/converting-llama2-to-llama3.ipynb  +1 −8

+1 −8
ch05/07_gpt_to_llama/converting-llama2-to-llama3.ipynb

@@ -452,10 +452,8 @@
     "\n",
     "class GroupedQueryAttention(nn.Module):\n",
     "    def __init__(\n",
-    "            self, d_in, d_out, context_length, num_heads,\n",
+    "            self, d_in, d_out, num_heads,\n",
     "            num_kv_groups,       # NEW\n",
-    "            rope_base=10_000,    # NEW\n",
-    "            rope_config=None,    # NEW\n",
     "            dtype=None\n",
     "        ):\n",
     "        super().__init__()\n",
@@ -645,10 +643,8 @@
     "gqa = GroupedQueryAttention(\n",
     "    d_in=embed_dim,\n",
     "    d_out=embed_dim,\n",
-    "    context_length=max_context_len,\n",
     "    num_heads=num_heads,\n",
     "    num_kv_groups=8,\n",
-    "    rope_base=llama_3_theta_base\n",
     ")\n",
     "\n",
     "gqa(example_batch)\n",
@@ -753,11 +749,8 @@
     "        self.att =  GroupedQueryAttention(  # MultiHeadAttention(\n",
     "            d_in=cfg[\"emb_dim\"],\n",
     "            d_out=cfg[\"emb_dim\"],\n",
-    "            context_length=cfg[\"context_length\"],\n",
     "            num_heads=cfg[\"n_heads\"],\n",
     "            num_kv_groups=cfg[\"n_kv_groups\"],  # NEW\n",
-    "            rope_base=cfg[\"rope_base\"],        # NEW\n",
-    "            rope_config=cfg[\"rope_freq\"],      # NEW\n",
     "            dtype=cfg[\"dtype\"]\n",
     "        )\n",
     "        self.ff = FeedForward(cfg)\n",