|
|
@@ -430,6 +430,14 @@
|
|
|
"- In short, the main change in GQA is that each query group needs to be repeated to match the number of heads it is associated with, as implemented below"
|
|
|
]
|
|
|
},
|
|
|
+ {
|
|
|
+ "cell_type": "markdown",
|
|
|
+ "id": "842aa71a-4659-424e-8830-392bd6ae86af",
|
|
|
+ "metadata": {},
|
|
|
+ "source": [
|
|
|
+ "- In addition, we also introduce a `SharedBuffers` class that will allow us to reuse the `mask`, `cos`, and `sin` tensors in the transformer blocks to improve efficiency (this will be crucial when working with models such as Llama 3.1 and 3.2 later, which support up to 131k input tokens)"
|
|
|
+ ]
|
|
|
+ },
|
|
|
{
|
|
|
"cell_type": "code",
|
|
|
"execution_count": 8,
|
|
|
@@ -441,6 +449,28 @@
|
|
|
"source": [
|
|
|
"import torch.nn as nn\n",
|
|
|
"\n",
|
|
|
+ "\n",
|
|
|
+ "############################# NEW #############################\n",
|
|
|
+ "class SharedBuffers:\n",
|
|
|
+ " _buffers = {}\n",
|
|
|
+ "\n",
|
|
|
+ " @staticmethod\n",
|
|
|
+ " def get_buffers(context_length, head_dim, rope_base, freq_config, dtype=torch.float32):\n",
|
|
|
+ " key = (context_length, head_dim, rope_base, tuple(freq_config.values()) if freq_config else freq_config, dtype)\n",
|
|
|
+ "\n",
|
|
|
+ " if key not in SharedBuffers._buffers:\n",
|
|
|
+ " # Create or fetch the buffers\n",
|
|
|
+ " mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)\n",
|
|
|
+ " cos, sin = precompute_rope_params(head_dim, rope_base, context_length, freq_config)\n",
|
|
|
+ " if dtype is not None:\n",
|
|
|
+ " cos = cos.to(dtype)\n",
|
|
|
+ " sin = sin.to(dtype)\n",
|
|
|
+ " SharedBuffers._buffers[key] = (mask, cos, sin)\n",
|
|
|
+ "\n",
|
|
|
+ " return SharedBuffers._buffers[key]\n",
|
|
|
+ "############################# NEW #############################\n",
|
|
|
+ "\n",
|
|
|
+ "\n",
|
|
|
"class GroupedQueryAttention(nn.Module):\n",
|
|
|
" def __init__(\n",
|
|
|
" self, d_in, d_out, context_length, num_heads,\n",
|
|
|
@@ -469,13 +499,12 @@
|
|
|
" self.W_query = nn.Linear(d_in, d_out, bias=False, dtype=dtype)\n",
|
|
|
" self.out_proj = nn.Linear(d_out, d_out, bias=False, dtype=dtype)\n",
|
|
|
"\n",
|
|
|
- " self.register_buffer(\"mask\", torch.triu(torch.ones(context_length, context_length), diagonal=1))\n",
|
|
|
- " cos, sin = precompute_rope_params(\n",
|
|
|
- " head_dim=self.head_dim,\n",
|
|
|
- " theta_base=rope_base, # NEW\n",
|
|
|
- " freq_config=rope_config, # NEW\n",
|
|
|
- " context_length=8192\n",
|
|
|
- " )\n",
|
|
|
+ " ############################# NEW #############################\n",
|
|
|
+ " # Fetch buffers using SharedBuffers\n",
|
|
|
+ " mask, cos, sin = SharedBuffers.get_buffers(context_length, self.head_dim, rope_base, rope_config, dtype)\n",
|
|
|
+ " ############################# NEW #############################\n",
|
|
|
+ " \n",
|
|
|
+ " self.register_buffer(\"mask\", mask)\n",
|
|
|
" self.register_buffer(\"cos\", cos)\n",
|
|
|
" self.register_buffer(\"sin\", sin)\n",
|
|
|
"\n",
|
|
|
@@ -907,6 +936,35 @@
|
|
|
"model = Llama3Model(LLAMA3_CONFIG_8B)"
|
|
|
]
|
|
|
},
|
|
|
+ {
|
|
|
+ "cell_type": "markdown",
|
|
|
+ "id": "edea6334-d1fc-427d-9cf2-4af963ff4bfc",
|
|
|
+ "metadata": {},
|
|
|
+ "source": [
|
|
|
+ "- The following is expected to print True to confirm buffers are reused instead of being (wastefully) recreated:"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": null,
|
|
|
+ "id": "ee9625cc-9afa-4b11-8aab-d536fd170761",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "# Check buffers\n",
|
|
|
+ "print(model.trf_blocks[0].att.mask is model.trf_blocks[-1].att.mask)\n",
|
|
|
+ "print(model.trf_blocks[0].att.cos is model.trf_blocks[-1].att.cos)\n",
|
|
|
+ "print(model.trf_blocks[0].att.sin is model.trf_blocks[-1].att.sin) "
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "markdown",
|
|
|
+ "id": "8056a521-91a6-440f-8473-591409c3177b",
|
|
|
+ "metadata": {},
|
|
|
+ "source": [
|
|
|
+ "- Let's now also compute the number of trainable parameters:"
|
|
|
+ ]
|
|
|
+ },
|
|
|
{
|
|
|
"cell_type": "code",
|
|
|
"execution_count": 18,
|
|
|
@@ -2008,16 +2066,16 @@
|
|
|
"}\n",
|
|
|
"\n",
|
|
|
"LLAMA31_CONFIG_8B = {\n",
|
|
|
- " \"vocab_size\": 128_256, # Vocabulary size\n",
|
|
|
- " \"context_length\": 8192, # Context length\n",
|
|
|
- " \"emb_dim\": 4096, # Embedding dimension\n",
|
|
|
- " \"n_heads\": 32, # Number of attention heads\n",
|
|
|
- " \"n_layers\": 32, # Number of layers\n",
|
|
|
- " \"hidden_dim\": 14_336, # Size of the intermediate dimension in FeedForward\n",
|
|
|
- " \"n_kv_groups\": 8, # Key-Value groups for grouped-query attention\n",
|
|
|
- " \"rope_base\": 50_000, # The base in RoPE's \"theta\"\n",
|
|
|
- " \"dtype\": torch.bfloat16, # Lower-precision dtype to save memory\n",
|
|
|
- " \"rope_freq\": { # NEW: RoPE frequency scaling\n",
|
|
|
+ " \"vocab_size\": 128_256, # Vocabulary size\n",
|
|
|
+ " \"context_length\": 131_072, # NEW: Larger supported context length\n",
|
|
|
+ " \"emb_dim\": 4096, # Embedding dimension\n",
|
|
|
+ " \"n_heads\": 32, # Number of attention heads\n",
|
|
|
+ " \"n_layers\": 32, # Number of layers\n",
|
|
|
+ " \"hidden_dim\": 14_336, # Size of the intermediate dimension in FeedForward\n",
|
|
|
+ " \"n_kv_groups\": 8, # Key-Value groups for grouped-query attention\n",
|
|
|
+ " \"rope_base\": 50_000, # The base in RoPE's \"theta\"\n",
|
|
|
+ " \"dtype\": torch.bfloat16, # Lower-precision dtype to save memory\n",
|
|
|
+ " \"rope_freq\": { # NEW: RoPE frequency scaling\n",
|
|
|
" \"factor\": 8.0,\n",
|
|
|
" \"low_freq_factor\": 1.0,\n",
|
|
|
" \"high_freq_factor\": 4.0,\n",
|
|
|
@@ -2026,6 +2084,24 @@
|
|
|
"}"
|
|
|
]
|
|
|
},
|
|
|
+ {
|
|
|
+ "cell_type": "markdown",
|
|
|
+ "id": "d81ee464-c112-43b0-9ee8-70df6ac942d0",
|
|
|
+ "metadata": {},
|
|
|
+ "source": [
|
|
|
+ "- Reduce the context length so the model would work fine on a MacBook Air (if you have more RAM, feel free to comment out the lines below):"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": null,
|
|
|
+ "id": "9bdbe32f-4c96-4e60-8bf4-52b5217df1e6",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "LLAMA32_CONFIG[\"context_length\"] = 8192"
|
|
|
+ ]
|
|
|
+ },
|
|
|
{
|
|
|
"cell_type": "markdown",
|
|
|
"id": "xa3bpMDtTdBs",
|
|
|
@@ -2338,16 +2414,16 @@
|
|
|
"outputs": [],
|
|
|
"source": [
|
|
|
"LLAMA31_CONFIG_8B = {\n",
|
|
|
- " \"vocab_size\": 128_256, # Vocabulary size\n",
|
|
|
- " \"context_length\": 8192, # Context length\n",
|
|
|
- " \"emb_dim\": 4096, # Embedding dimension\n",
|
|
|
- " \"n_heads\": 32, # Number of attention heads\n",
|
|
|
- " \"n_layers\": 32, # Number of layers\n",
|
|
|
- " \"hidden_dim\": 14_336, # Size of the intermediate dimension in FeedForward\n",
|
|
|
- " \"n_kv_groups\": 8, # Key-Value groups for grouped-query attention\n",
|
|
|
- " \"rope_base\": 50_000, # The base in RoPE's \"theta\"\n",
|
|
|
- " \"dtype\": torch.bfloat16, # Lower-precision dtype to save memory\n",
|
|
|
- " \"rope_freq\": { # RoPE frequency scaling\n",
|
|
|
+ " \"vocab_size\": 128_256, # Vocabulary size\n",
|
|
|
+ " \"context_length\": 131_072, # NEW: Larger supported context length\n",
|
|
|
+ " \"emb_dim\": 4096, # Embedding dimension\n",
|
|
|
+ " \"n_heads\": 32, # Number of attention heads\n",
|
|
|
+ " \"n_layers\": 32, # Number of layers\n",
|
|
|
+ " \"hidden_dim\": 14_336, # Size of the intermediate dimension in FeedForward\n",
|
|
|
+ " \"n_kv_groups\": 8, # Key-Value groups for grouped-query attention\n",
|
|
|
+ " \"rope_base\": 50_000, # The base in RoPE's \"theta\"\n",
|
|
|
+ " \"dtype\": torch.bfloat16, # Lower-precision dtype to save memory\n",
|
|
|
+ " \"rope_freq\": { # NEW: RoPE frequency scaling\n",
|
|
|
" \"factor\": 8.0,\n",
|
|
|
" \"low_freq_factor\": 1.0,\n",
|
|
|
" \"high_freq_factor\": 4.0,\n",
|
|
|
@@ -2357,17 +2433,17 @@
|
|
|
"\n",
|
|
|
"\n",
|
|
|
"LLAMA32_CONFIG_1B = {\n",
|
|
|
- " \"vocab_size\": 128_256, # Vocabulary size\n",
|
|
|
- " \"context_length\": 8192, # Context length\n",
|
|
|
- " \"emb_dim\": 2048, # NEW: Half the embedding dimension\n",
|
|
|
- " \"n_heads\": 32, # Number of attention heads\n",
|
|
|
- " \"n_layers\": 16, # NEW: Half the number of layers\n",
|
|
|
- " \"hidden_dim\": 8192, # NEW: Almost half the size of the intermediate dimension in FeedForward\n",
|
|
|
- " \"n_kv_groups\": 8, # Key-Value groups for grouped-query attention\n",
|
|
|
- " \"rope_base\": 50_000, # The base in RoPE's \"theta\"\n",
|
|
|
- " \"dtype\": torch.bfloat16, # Lower-precision dtype to save memory\n",
|
|
|
- " \"rope_freq\": { # RoPE frequency scaling\n",
|
|
|
- " \"factor\": 32.0, # NEW: Adjustment of the rescaling factor\n",
|
|
|
+ " \"vocab_size\": 128_256, # Vocabulary size\n",
|
|
|
+ " \"context_length\": 131_072, # Context length\n",
|
|
|
+ " \"emb_dim\": 2048, # NEW: Half the embedding dimension\n",
|
|
|
+ " \"n_heads\": 32, # Number of attention heads\n",
|
|
|
+ " \"n_layers\": 16, # NEW: Half the number of layers\n",
|
|
|
+ " \"hidden_dim\": 8192, # NEW: Almost half the size of the intermediate dimension in FeedForward\n",
|
|
|
+ " \"n_kv_groups\": 8, # Key-Value groups for grouped-query attention\n",
|
|
|
+ " \"rope_base\": 50_000, # The base in RoPE's \"theta\"\n",
|
|
|
+ " \"dtype\": torch.bfloat16, # Lower-precision dtype to save memory\n",
|
|
|
+ " \"rope_freq\": { # RoPE frequency scaling\n",
|
|
|
+ " \"factor\": 32.0, # NEW: Adjustment of the rescaling factor\n",
|
|
|
" \"low_freq_factor\": 1.0,\n",
|
|
|
" \"high_freq_factor\": 4.0,\n",
|
|
|
" \"original_context_length\": 8192,\n",
|
|
|
@@ -2375,6 +2451,24 @@
|
|
|
"}"
|
|
|
]
|
|
|
},
|
|
|
+ {
|
|
|
+ "cell_type": "markdown",
|
|
|
+ "id": "b5cd351b-d883-460d-9cdc-47e15ddb884a",
|
|
|
+ "metadata": {},
|
|
|
+ "source": [
|
|
|
+ "- Reduce the context length so the model would work fine on a MacBook Air (if you have more RAM, feel free to comment out the lines below):"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": null,
|
|
|
+ "id": "387456c3-c6a1-46fe-8830-6e00eb46ac13",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "LLAMA32_CONFIG[\"context_length\"] = 8192"
|
|
|
+ ]
|
|
|
+ },
|
|
|
{
|
|
|
"cell_type": "markdown",
|
|
|
"id": "Dl4_0EoJKKYv",
|
|
|
@@ -2593,7 +2687,7 @@
|
|
|
"provenance": []
|
|
|
},
|
|
|
"kernelspec": {
|
|
|
- "display_name": "base",
|
|
|
+ "display_name": "Python 3 (ipykernel)",
|
|
|
"language": "python",
|
|
|
"name": "python3"
|
|
|
},
|
|
|
@@ -2607,7 +2701,7 @@
|
|
|
"name": "python",
|
|
|
"nbconvert_exporter": "python",
|
|
|
"pygments_lexer": "ipython3",
|
|
|
- "version": "3.10.11"
|
|
|
+ "version": "3.11.4"
|
|
|
},
|
|
|
"widgets": {
|
|
|
"application/vnd.jupyter.widget-state+json": {
|