@@ -452,10 +452,8 @@
 "\n",
 "class GroupedQueryAttention(nn.Module):\n",
 "    def __init__(\n",
-"            self, d_in, d_out, context_length, num_heads,\n",
+"            self, d_in, d_out, num_heads,\n",
 "            num_kv_groups,       # NEW\n",
-"            rope_base=10_000,    # NEW\n",
-"            rope_config=None,    # NEW\n",
 "            dtype=None\n",
 "        ):\n",
 "        super().__init__()\n",
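
The removals above all point the same way: the causal mask and the RoPE parameters (rope_base, rope_config) no longer live inside the attention module. Below is a minimal sketch of what the slimmed-down class could look like, assuming the mask and the RoPE cos/sin tables are now built once outside the module and handed to forward; the names mask, cos, sin, apply_rope, and compute_rope_params are illustrative assumptions, not taken from this diff.

import torch
import torch.nn as nn


def compute_rope_params(head_dim, theta_base=10_000, context_length=4096):
    # Hypothetical helper: precompute the RoPE cos/sin tables once, outside the module
    inv_freq = 1.0 / theta_base ** (torch.arange(0, head_dim, 2).float() / head_dim)
    angles = torch.arange(context_length).float()[:, None] * inv_freq[None, :]
    angles = torch.cat([angles, angles], dim=-1)  # (context_length, head_dim)
    return torch.cos(angles), torch.sin(angles)


def apply_rope(x, cos, sin):
    # x: (batch, num_heads, seq_len, head_dim); rotate-half formulation
    seq_len, head_dim = x.shape[2], x.shape[3]
    x1, x2 = x[..., : head_dim // 2], x[..., head_dim // 2:]
    rotated = torch.cat((-x2, x1), dim=-1)
    return x * cos[:seq_len] + rotated * sin[:seq_len]


class GroupedQueryAttention(nn.Module):
    # Sketch only: mirrors the post-diff signature; RoPE/mask state lives outside
    def __init__(self, d_in, d_out, num_heads, num_kv_groups, dtype=None):
        super().__init__()
        assert d_out % num_heads == 0, "d_out must be divisible by num_heads"
        assert num_heads % num_kv_groups == 0, "num_heads must be divisible by num_kv_groups"
        self.num_heads = num_heads
        self.num_kv_groups = num_kv_groups
        self.group_size = num_heads // num_kv_groups
        self.head_dim = d_out // num_heads
        self.W_query = nn.Linear(d_in, d_out, bias=False, dtype=dtype)
        # K/V are projected once per KV group, not once per head
        self.W_key = nn.Linear(d_in, num_kv_groups * self.head_dim, bias=False, dtype=dtype)
        self.W_value = nn.Linear(d_in, num_kv_groups * self.head_dim, bias=False, dtype=dtype)
        self.out_proj = nn.Linear(d_out, d_out, bias=False, dtype=dtype)

    def forward(self, x, mask, cos, sin):
        b, num_tokens, _ = x.shape
        q = self.W_query(x).view(b, num_tokens, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.W_key(x).view(b, num_tokens, self.num_kv_groups, self.head_dim).transpose(1, 2)
        v = self.W_value(x).view(b, num_tokens, self.num_kv_groups, self.head_dim).transpose(1, 2)
        # RoPE tensors are supplied by the caller rather than stored as buffers
        q, k = apply_rope(q, cos, sin), apply_rope(k, cos, sin)
        # Each group's K/V is shared by group_size query heads
        k = k.repeat_interleave(self.group_size, dim=1)
        v = v.repeat_interleave(self.group_size, dim=1)
        scores = q @ k.transpose(2, 3) / self.head_dim ** 0.5
        scores = scores.masked_fill(mask[:num_tokens, :num_tokens], -torch.inf)
        context = (torch.softmax(scores, dim=-1) @ v).transpose(1, 2).reshape(b, num_tokens, -1)
        return self.out_proj(context)

Computing cos/sin once per model instead of once per attention layer avoids keeping n_layers duplicate buffers in memory, which is the usual motivation for this kind of refactor.
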
@@ -645,10 +643,8 @@
 "gqa = GroupedQueryAttention(\n",
 "    d_in=embed_dim,\n",
 "    d_out=embed_dim,\n",
-"    context_length=max_context_len,\n",
 "    num_heads=num_heads,\n",
 "    num_kv_groups=8,\n",
-"    rope_base=llama_3_theta_base\n",
 ")\n",
 "\n",
 "gqa(example_batch)\n",
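
Under the same assumptions, the call site changes accordingly: the shared tensors are built once and passed per call. The sketch below reuses the class and helpers from the previous sketch; the dimensions are scaled-down stand-ins rather than the notebook's actual settings, and llama_3_theta_base = 500_000 reflects Llama 3's published RoPE base.

import torch

# Scaled-down stand-ins for the notebook's embed_dim / num_heads / batch
embed_dim, num_heads, max_context_len = 256, 16, 512
llama_3_theta_base = 500_000
example_batch = torch.randn(1, 100, embed_dim)

gqa = GroupedQueryAttention(  # from the sketch above
    d_in=embed_dim,
    d_out=embed_dim,
    num_heads=num_heads,
    num_kv_groups=8,
)

# Built once and shared by every layer; no longer stored inside the module
cos, sin = compute_rope_params(embed_dim // num_heads, llama_3_theta_base, max_context_len)
mask = torch.triu(torch.ones(max_context_len, max_context_len, dtype=torch.bool), diagonal=1)

out = gqa(example_batch, mask, cos, sin)
print(out.shape)  # torch.Size([1, 100, 256])
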
@@ -753,11 +749,8 @@
 "        self.att = GroupedQueryAttention(  # MultiHeadAttention(\n",
 "            d_in=cfg[\"emb_dim\"],\n",
 "            d_out=cfg[\"emb_dim\"],\n",
-"            context_length=cfg[\"context_length\"],\n",
 "            num_heads=cfg[\"n_heads\"],\n",
 "            num_kv_groups=cfg[\"n_kv_groups\"],  # NEW\n",
-"            rope_base=cfg[\"rope_base\"],        # NEW\n",
-"            rope_config=cfg[\"rope_freq\"],      # NEW\n",
 "            dtype=cfg[\"dtype\"]\n",
 "        )\n",
 "        self.ff = FeedForward(cfg)\n",
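
The TransformerBlock change follows mechanically: since the attention layer no longer owns its buffers, the block's forward has to accept them and thread them through. A sketch under the same assumptions; FeedForward is the notebook's own module, the pre-norm layout mirrors the surrounding code, and nn.RMSNorm (PyTorch >= 2.4) stands in for the notebook's hand-rolled RMSNorm.

import torch.nn as nn

class TransformerBlock(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.att = GroupedQueryAttention(  # from the sketch above
            d_in=cfg["emb_dim"],
            d_out=cfg["emb_dim"],
            num_heads=cfg["n_heads"],
            num_kv_groups=cfg["n_kv_groups"],
            dtype=cfg["dtype"],
        )
        self.ff = FeedForward(cfg)  # as defined in the notebook
        self.norm1 = nn.RMSNorm(cfg["emb_dim"], eps=1e-5)
        self.norm2 = nn.RMSNorm(cfg["emb_dim"], eps=1e-5)

    def forward(self, x, mask, cos, sin):
        # Pre-norm residual blocks; mask/cos/sin are passed through, not stored
        x = x + self.att(self.norm1(x), mask, cos, sin)
        x = x + self.ff(self.norm2(x))
        return x

The parent model would then create mask, cos, and sin once in its own __init__ (or forward) and hand them to every block, which is what makes removing them from the per-layer constructors pay off.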