Jelajahi Sumber

Reduce Llama 3 RoPE memory requirements (#658)

* Llama3 from scratch improvements

* Fix Llama 3 expensive RoPE memory issue

* updates

* update package

* benchmark

* remove unused rescale_theta
Sebastian Raschka 5 bulan lalu
induk
melakukan
c4cde1c21b

+ 3 - 0
.gitignore

@@ -51,6 +51,9 @@ ch05/07_gpt_to_llama/Llama-3.2-3B-Instruct
 ch05/10_llm-training-speed/middlemarch.txt
 ch05/10_llm-training-speed/loss.pdf
 ch05/10_llm-training-speed/model.pth
+ch05/07_gpt_to_llama/Untitled.ipynb
+ch05/07_gpt_to_llama/llama3.2-1B-instruct.pth
+ch05/07_gpt_to_llama/tokenizer.model
 
 ch06/01_main-chapter-code/gpt2
 ch06/02_bonus_additional-experiments/gpt2

+ 5 - 9
ch05/07_gpt_to_llama/README.md

@@ -40,8 +40,6 @@ MODEL_FILE = "llama3.2-1B-instruct.pth"
 Basic text generation settings that can be defined by the user. Note that the recommended 8192-token context size requires approximately 3 GB of VRAM for the text generation example.
 
 ```python
-MODEL_CONTEXT_LENGTH = 8192  # Supports up to 131_072
-
 # Text generation settings
 if "instruct" in MODEL_FILE:
     PROMPT = "What do llamas eat?"
@@ -82,8 +80,6 @@ elif "3B" in MODEL_FILE:
 else:
     raise ValueError("Incorrect model file name")
 
-LLAMA32_CONFIG["context_length"] = MODEL_CONTEXT_LENGTH
-
 model = Llama3Model(LLAMA32_CONFIG)
 model.load_state_dict(torch.load(MODEL_FILE, weights_only=True, map_location="cpu"))
 
@@ -125,7 +121,7 @@ Lastly, we can generate text via the following code:
 ```python
 import time
 
-from ch05 import (
+from llms_from_scratch.ch05 import (
     generate,
     text_to_token_ids,
     token_ids_to_text
@@ -192,8 +188,8 @@ The following table shows a performance comparison on an A100:
 
 |                 | Tokens/sec | Memory  |
 | --------------- | ---------- | ------- |
-| Llama3Model     | 50         | 2.91 GB |
-| Llama3ModelFast | 58         | 2.85 GB |
+| Llama3Model     | 42         | 2.91 GB |
+| Llama3ModelFast | 54         | 2.91 GB |
 
  
 #### Pro tip 2: speed up inference with compilation
@@ -218,5 +214,5 @@ The following table shows a performance comparison on an A100 for consequent `ge
 
 |                 | Tokens/sec | Memory  |
 | --------------- | ---------- | ------- |
-| Llama3Model     | 156        | 3.12 GB |
-| Llama3ModelFast | 159        | 2.84 GB |
+| Llama3Model     | 170        | 3.12 GB |
+| Llama3ModelFast | 177        | 3.61 GB |

+ 136 - 358
ch05/07_gpt_to_llama/converting-llama2-to-llama3.ipynb

@@ -95,9 +95,9 @@
      "output_type": "stream",
      "text": [
       "blobfile version: 3.0.0\n",
-      "huggingface_hub version: 0.24.7\n",
-      "tiktoken version: 0.8.0\n",
-      "torch version: 2.4.1+cu121\n"
+      "huggingface_hub version: 0.30.1\n",
+      "tiktoken version: 0.9.0\n",
+      "torch version: 2.6.0\n"
      ]
     }
    ],
@@ -435,7 +435,7 @@
    "id": "842aa71a-4659-424e-8830-392bd6ae86af",
    "metadata": {},
    "source": [
-    "- In addition, we also introduce a `SharedBuffers` class that will allow us to reuse the `mask`, `cos`, and `sin` tensors in the transformer blocks to improve efficiency (this will be crucial when working with models such as Llama 3.1 and 3.2 later, which support up to 131k input tokens)"
+    "- **We also redesign the attention class a bit so it receives the mask through its forward method instead of storing and accessing it as `self.mask`. This lets us build the mask on the fly to reduce memory usage. To foreshadow why: Llama 3.1 can handle sequences of up to 128 k tokens, and precomputing a 128 k × 128 k causal mask would be extremely memory‑intensive, so we avoid it unless absolutely necessary.**"
    ]
   },
   {
@@ -450,27 +450,6 @@
     "import torch.nn as nn\n",
     "\n",
     "\n",
-    "############################# NEW  #############################\n",
-    "class SharedBuffers:\n",
-    "    _buffers = {}\n",
-    "\n",
-    "    @staticmethod\n",
-    "    def get_buffers(context_length, head_dim, rope_base, freq_config, dtype=torch.float32):\n",
-    "        key = (context_length, head_dim, rope_base, tuple(freq_config.values()) if freq_config else freq_config, dtype)\n",
-    "\n",
-    "        if key not in SharedBuffers._buffers:\n",
-    "            # Create or fetch the buffers\n",
-    "            mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)\n",
-    "            cos, sin = precompute_rope_params(head_dim, rope_base, context_length, freq_config)\n",
-    "            if dtype is not None:\n",
-    "                cos = cos.to(dtype)\n",
-    "                sin = sin.to(dtype)\n",
-    "            SharedBuffers._buffers[key] = (mask, cos, sin)\n",
-    "\n",
-    "        return SharedBuffers._buffers[key]\n",
-    "############################# NEW  #############################\n",
-    "\n",
-    "\n",
     "class GroupedQueryAttention(nn.Module):\n",
     "    def __init__(\n",
     "            self, d_in, d_out, context_length, num_heads,\n",
@@ -499,16 +478,12 @@
     "        self.W_query = nn.Linear(d_in, d_out, bias=False, dtype=dtype)\n",
     "        self.out_proj = nn.Linear(d_out, d_out, bias=False, dtype=dtype)\n",
     "\n",
-    "        ############################# NEW  #############################\n",
-    "        # Fetch buffers using SharedBuffers\n",
-    "        mask, cos, sin = SharedBuffers.get_buffers(context_length, self.head_dim, rope_base, rope_config, dtype)\n",
-    "        ############################# NEW  #############################\n",
-    "        \n",
-    "        self.register_buffer(\"mask\", mask)\n",
-    "        self.register_buffer(\"cos\", cos)\n",
-    "        self.register_buffer(\"sin\", sin)\n",
     "\n",
-    "    def forward(self, x):\n",
+    "    def forward(self, x, mask=None, cos=None, sin=None):\n",
+    "        ##################### NEW  #####################\n",
+    "        # The forward method now accepts `mask` instead of accessing it via self.mask.\n",
+    "        # Also, we now have cos and sin as input for RoPE\n",
+    "        ################################################    \n",
     "        b, num_tokens, d_in = x.shape\n",
     "\n",
     "        queries = self.W_query(x)  # Shape: (b, num_tokens, d_out)\n",
@@ -530,9 +505,12 @@
     "        values = values.transpose(1, 2)  # Shape: (b, num_heads, num_tokens, head_dim)\n",
     "        queries = queries.transpose(1, 2)  # Shape: (b, num_query_groups, num_tokens, head_dim)\n",
     "\n",
+    "        ##################### NEW #####################\n",
     "        # Apply RoPE\n",
-    "        keys = compute_rope(keys, self.cos, self.sin)\n",
-    "        queries = compute_rope(queries, self.cos, self.sin)\n",
+    "        if cos is not None:\n",
+    "            keys = compute_rope(keys, cos, sin)\n",
+    "            queries = compute_rope(queries, cos, sin)\n",
+    "        ################################################\n",
     "\n",
     "        ##################### NEW  #####################\n",
     "        # Expand keys and values to match the number of heads\n",
@@ -552,11 +530,14 @@
     "        # Shape: (b, num_heads, num_tokens, num_tokens)\n",
     "        attn_scores = queries @ keys.transpose(2, 3)  # Dot product for each head\n",
     "\n",
-    "        # Original mask truncated to the number of tokens and converted to boolean\n",
-    "        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]\n",
-    "\n",
+    "        ##################### NEW #####################\n",
+    "        # Create mask on the fly\n",
+    "        if mask is None:\n",
+    "            mask = torch.triu(torch.ones(num_tokens, num_tokens, device=x.device, dtype=torch.bool), diagonal=1)\n",
+    "        ################################################\n",
+    "    \n",
     "        # Use the mask to fill attention scores\n",
-    "        attn_scores.masked_fill_(mask_bool, -torch.inf)\n",
+    "        attn_scores.masked_fill_(mask, -torch.inf)\n",
     "\n",
     "        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)\n",
     "        assert keys.shape[-1] == self.head_dim\n",
@@ -578,7 +559,7 @@
     "id": "roAXSwJs9hR8"
    },
    "source": [
-    "- To illustrate the parameter savings, consider the following multi-head attention example from the GPT and Llama 2 code:"
+    "- To illustrate the parameter savings in GQA over MHA, consider the following multi-head attention example from the GPT and Llama 2 code:"
    ]
   },
   {
@@ -753,7 +734,8 @@
    },
    "source": [
     "- Next, we update the `TransformerBlock`\n",
-    "- Here, we simply swap `MultiHeadAttention` with `GroupedQueryAttention` and add the new RoPE settings"
+    "- Here, we simply swap `MultiHeadAttention` with `GroupedQueryAttention` and add the new RoPE settings\n",
+    "- In addition, we also modify the `forward` method so that it receives `mask`, `cos`, and `sin`; since the values for those are the same for each transformer block, we only have to compute them once and then can reuse them"
    ]
   },
   {
@@ -782,11 +764,15 @@
     "        self.norm1 = RMSNorm(cfg[\"emb_dim\"], eps=1e-5)\n",
     "        self.norm2 = RMSNorm(cfg[\"emb_dim\"], eps=1e-5)\n",
     "\n",
-    "    def forward(self, x):\n",
+    "    def forward(self, x, mask=None, cos=None, sin=None):\n",
+    "        ##################### NEW  #####################\n",
+    "        # The forward method now accepts `mask` instead of accessing it via self.mask.\n",
+    "        # Also, we now have cos and sin as input for RoPE\n",
+    "        ################################################\n",
     "        # Shortcut connection for attention block\n",
     "        shortcut = x\n",
     "        x = self.norm1(x)\n",
-    "        x = self.att(x.to(torch.bfloat16))   # Shape [batch_size, num_tokens, emb_size]\n",
+    "        x = self.att(x.to(torch.bfloat16), mask, cos, sin)   # Shape [batch_size, num_tokens, emb_size]\n",
     "        x = x + shortcut  # Add the original input back\n",
     "\n",
     "        # Shortcut connection for feed-forward block\n",
@@ -816,7 +802,8 @@
     "id": "M_tLAq_r_llN"
    },
    "source": [
-    "- When setting up the model class, we fortunately don't have to do much; we just update the name to `Llama3Model`"
+    "- When setting up the model class, we technically don't have to do much; we just update the name to `Llama3Model`\n",
+    "- However, since we now pass the `mask`, `cos`, and `sin` to the transformer blocks, we also have to add them here"
    ]
   },
   {
@@ -840,12 +827,33 @@
     "        self.final_norm = RMSNorm(cfg[\"emb_dim\"], eps=1e-5)\n",
     "        self.out_head = nn.Linear(cfg[\"emb_dim\"], cfg[\"vocab_size\"], bias=False, dtype=cfg[\"dtype\"])\n",
     "\n",
+    "        #################### NEW #####################\n",
+    "        cos, sin = precompute_rope_params(\n",
+    "            head_dim=cfg[\"emb_dim\"] // cfg[\"n_heads\"],\n",
+    "            theta_base=cfg[\"rope_base\"],\n",
+    "            context_length=cfg[\"context_length\"],\n",
+    "            freq_config=cfg[\"rope_freq\"]\n",
+    "        )\n",
+    "        \n",
+    "        self.register_buffer(\"cos\", cos, persistent=False)\n",
+    "        self.register_buffer(\"sin\", sin, persistent=False)\n",
+    "        ##############################################\n",
+    "\n",
+    "        self.cfg = cfg\n",
+    "\n",
     "    def forward(self, in_idx):\n",
     "        tok_embeds = self.tok_emb(in_idx)\n",
     "        x = tok_embeds\n",
-    "        x = self.trf_blocks(x)\n",
+    "\n",
+    "        #################### NEW #####################\n",
+    "        num_tokens = x.shape[1]\n",
+    "        mask = torch.triu(torch.ones(num_tokens, num_tokens, device=x.device, dtype=torch.bool), diagonal=1)\n",
+    "        ##############################################\n",
+    "        \n",
+    "        for block in self.trf_blocks:\n",
+    "            x = block(x, mask, self.cos, self.sin)\n",
     "        x = self.final_norm(x)\n",
-    "        logits = self.out_head(x.to(torch.bfloat16))\n",
+    "        logits = self.out_head(x.to(self.cfg[\"dtype\"]))\n",
     "        return logits"
    ]
   },
@@ -936,33 +944,12 @@
     "model = Llama3Model(LLAMA3_CONFIG_8B)"
    ]
   },
-  {
-   "cell_type": "markdown",
-   "id": "edea6334-d1fc-427d-9cf2-4af963ff4bfc",
-   "metadata": {},
-   "source": [
-    "- The following is expected to print True to confirm buffers are reused instead of being (wastefully) recreated:"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "id": "ee9625cc-9afa-4b11-8aab-d536fd170761",
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "# Check buffers\n",
-    "print(model.trf_blocks[0].att.mask is model.trf_blocks[-1].att.mask)\n",
-    "print(model.trf_blocks[0].att.cos is model.trf_blocks[-1].att.cos)\n",
-    "print(model.trf_blocks[0].att.sin is model.trf_blocks[-1].att.sin) "
-   ]
-  },
   {
    "cell_type": "markdown",
    "id": "8056a521-91a6-440f-8473-591409c3177b",
    "metadata": {},
    "source": [
-    "- Let's now also compute the number of trainable parameters:"
+    "- Let's now compute the number of trainable parameters:"
    ]
   },
   {
@@ -1017,8 +1004,8 @@
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "float32 (PyTorch default): 68.08 GB\n",
-      "bfloat16: 34.04 GB\n"
+      "float32 (PyTorch default): 59.84 GB\n",
+      "bfloat16: 29.92 GB\n"
      ]
     }
    ],
@@ -1121,43 +1108,47 @@
     "\n",
     "\n",
     "class Tokenizer:\n",
+    "    \"\"\"Thin wrapper around tiktoken that keeps track of Llama-3 special IDs.\"\"\"\n",
     "    def __init__(self, model_path):\n",
-    "        assert os.path.isfile(model_path), f\"Model file {model_path} not found\"\n",
-    "        mergeable_ranks = load_tiktoken_bpe(model_path)\n",
+    "        if not os.path.isfile(model_path):\n",
+    "            raise FileNotFoundError(model_path)\n",
+    "\n",
+    "        mergeable = load_tiktoken_bpe(model_path)\n",
     "\n",
-    "        self.special_tokens = {\n",
+    "        # hard-coded from Meta's tokenizer.json\n",
+    "        self.special = {\n",
     "            \"<|begin_of_text|>\": 128000,\n",
     "            \"<|end_of_text|>\": 128001,\n",
     "            \"<|start_header_id|>\": 128006,\n",
     "            \"<|end_header_id|>\": 128007,\n",
     "            \"<|eot_id|>\": 128009,\n",
     "        }\n",
-    "        self.special_tokens.update({\n",
-    "            f\"<|reserved_{i}|>\": 128002 + i for i in range(256) if (128002 + i) not in self.special_tokens.values()\n",
-    "        })\n",
+    "        self.special.update({f\"<|reserved_{i}|>\": 128002 + i\n",
+    "                             for i in range(256)\n",
+    "                             if 128002 + i not in self.special.values()})\n",
     "\n",
     "        self.model = tiktoken.Encoding(\n",
     "            name=Path(model_path).name,\n",
-    "            pat_str=r\"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+\",\n",
-    "            mergeable_ranks=mergeable_ranks,\n",
-    "            special_tokens=self.special_tokens\n",
+    "            pat_str=r\"(?i:'s|'t|'re|'ve|'m|'ll|'d)\"\n",
+    "                    r\"|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+\"\n",
+    "                    r\"|\\p{N}{1,3}\"\n",
+    "                    r\"| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*\"\n",
+    "                    r\"|\\s*[\\r\\n]+\"\n",
+    "                    r\"|\\s+(?!\\S)\"\n",
+    "                    r\"|\\s+\",\n",
+    "            mergeable_ranks=mergeable,\n",
+    "            special_tokens=self.special,\n",
     "        )\n",
     "\n",
-    "\n",
-    "    def encode(self, text, bos=False, eos=False, allowed_special=set(), disallowed_special=()):\n",
-    "        if bos:\n",
-    "            tokens = [self.special_tokens[\"<|begin_of_text|>\"]]\n",
-    "        else:\n",
-    "            tokens = []\n",
-    "\n",
-    "        tokens += self.model.encode(text, allowed_special=allowed_special, disallowed_special=disallowed_special)\n",
-    "\n",
+    "    def encode(self, text, bos=False, eos=False):\n",
+    "        ids = ([self.special[\"<|begin_of_text|>\"]] if bos else []) \\\n",
+    "              + self.model.encode(text)\n",
     "        if eos:\n",
-    "            tokens.append(self.special_tokens[\"<|end_of_text|>\"])\n",
-    "        return tokens\n",
+    "            ids.append(self.special[\"<|end_of_text|>\"])\n",
+    "        return ids\n",
     "\n",
-    "    def decode(self, tokens):\n",
-    "        return self.model.decode(tokens)"
+    "    def decode(self, ids):\n",
+    "        return self.model.decode(ids)"
    ]
   },
   {
@@ -1202,13 +1193,11 @@
    },
    "outputs": [
     {
-     "name": "stdout",
+     "name": "stderr",
      "output_type": "stream",
      "text": [
-      "The token has not been saved to the git credentials helper. Pass `add_to_git_credential=True` in this function directly or `--add-to-git-credential` if using via `huggingface-cli` if you want to set the git credential as well.\n",
-      "Token is valid (permission: read).\n",
-      "Your token has been saved to /root/.cache/huggingface/token\n",
-      "Login successful\n"
+      "/Users/sebastian/Developer/LLMs-from-scratch/.venv/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
+      "  from .autonotebook import tqdm as notebook_tqdm\n"
      ]
     }
    ],
@@ -1309,7 +1298,8 @@
      "base_uri": "https://localhost:8080/"
     },
     "id": "e0a2b5cd-6cba-4d72-b8ff-04d8315d483e",
-    "outputId": "990d7b74-cb35-476b-d8bd-d544006e00f4"
+    "outputId": "990d7b74-cb35-476b-d8bd-d544006e00f4",
+    "scrolled": true
    },
    "outputs": [
     {
@@ -1318,7 +1308,9 @@
      "text": [
       "Output text:\n",
       " Every effort_dead aeros Ingredients başında.extensionégor clangmissions güc như submodule.and report官方%,.Reader(\",\");\n",
-      "ामल ندار Parliamentary !!! HigginsDynamicZhgmt writeln Globalsletion 사진------\n"
+      "ामल ندار Parliamentary !!! HigginsDynamicZhamincus_beam cyc......\n",
+      "\n",
+      " haciendo\n"
      ]
     }
    ],
@@ -1437,22 +1429,7 @@
     "id": "5fa9c06c-7a53-4b4d-9ce4-acc027322ee4",
     "outputId": "c05118ce-9f81-41c8-a1f2-72caa932ae86"
    },
-   "outputs": [
-    {
-     "data": {
-      "application/vnd.jupyter.widget-view+json": {
-       "model_id": "245443330e4d40c887a5649cc1663e98",
-       "version_major": 2,
-       "version_minor": 0
-      },
-      "text/plain": [
-       "model-00001-of-00004.safetensors:   0%|          | 0.00/4.98G [00:00<?, ?B/s]"
-      ]
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    }
-   ],
+   "outputs": [],
    "source": [
     "from safetensors.torch import load_file\n",
     "\n",
@@ -1763,64 +1740,7 @@
     "id": "nbvAV7vaz6yc",
     "outputId": "9e1badc9-a6c4-48b7-9125-e0810655528b"
    },
-   "outputs": [
-    {
-     "data": {
-      "application/vnd.jupyter.widget-view+json": {
-       "model_id": "f7df6bbf8e63448c8a6cb5d2f6208403",
-       "version_major": 2,
-       "version_minor": 0
-      },
-      "text/plain": [
-       "model-00001-of-00004.safetensors:  36%|###6      | 1.81G/4.98G [00:00<?, ?B/s]"
-      ]
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    },
-    {
-     "data": {
-      "application/vnd.jupyter.widget-view+json": {
-       "model_id": "4772f31a1c5b4c168c9aabe7a1d2bacc",
-       "version_major": 2,
-       "version_minor": 0
-      },
-      "text/plain": [
-       "model-00002-of-00004.safetensors:   0%|          | 0.00/5.00G [00:00<?, ?B/s]"
-      ]
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    },
-    {
-     "data": {
-      "application/vnd.jupyter.widget-view+json": {
-       "model_id": "ad49eeb9e1204ea2bd2e371df8ccdea2",
-       "version_major": 2,
-       "version_minor": 0
-      },
-      "text/plain": [
-       "model-00003-of-00004.safetensors:   0%|          | 0.00/4.92G [00:00<?, ?B/s]"
-      ]
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    },
-    {
-     "data": {
-      "application/vnd.jupyter.widget-view+json": {
-       "model_id": "951b9e81613a40a2a503f61e69677f0a",
-       "version_major": 2,
-       "version_minor": 0
-      },
-      "text/plain": [
-       "model-00004-of-00004.safetensors:   0%|          | 0.00/1.17G [00:00<?, ?B/s]"
-      ]
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    }
-   ],
+   "outputs": [],
    "source": [
     "combined_weights = {}\n",
     "\n",
@@ -1861,35 +1781,40 @@
    "outputs": [],
    "source": [
     "class ChatFormat:\n",
-    "    def __init__(self, tokenizer):\n",
-    "        self.tokenizer = tokenizer\n",
-    "\n",
-    "    def encode_header(self, message):\n",
-    "        tokens = []\n",
-    "        tokens.append(self.tokenizer.special_tokens[\"<|start_header_id|>\"])\n",
-    "        tokens.extend(self.tokenizer.encode(message[\"role\"], bos=False, eos=False))\n",
-    "        tokens.append(self.tokenizer.special_tokens[\"<|end_header_id|>\"])\n",
-    "        tokens.extend(self.tokenizer.encode(\"\\n\\n\", bos=False, eos=False))\n",
-    "        return tokens\n",
-    "\n",
-    "    def encode(self, text):\n",
-    "        message = {\n",
-    "            \"role\": \"user\",\n",
-    "            \"content\": text\n",
-    "        }\n",
     "\n",
-    "        tokens = self.encode_header(message)\n",
-    "        tokens.extend(\n",
-    "            self.tokenizer.encode(message[\"content\"].strip(), bos=False, eos=False)\n",
+    "    def __init__(self, tokenizer: Tokenizer, *,\n",
+    "                 default_system=\"You are a helpful assistant.\"):\n",
+    "        self.tok = tokenizer\n",
+    "        self.default_system = default_system\n",
+    "\n",
+    "    def _header(self, role):\n",
+    "        \"\"\"Encode <|start_header_id|>role<|end_header_id|>\\n\\n\"\"\"\n",
+    "        return (\n",
+    "            [self.tok.special[\"<|start_header_id|>\"]]\n",
+    "            + self.tok.encode(role)\n",
+    "            + [self.tok.special[\"<|end_header_id|>\"]]\n",
+    "            + self.tok.encode(\"\\n\\n\")\n",
     "        )\n",
-    "        tokens.append(self.tokenizer.special_tokens[\"<|eot_id|>\"])\n",
-    "        return tokens\n",
     "\n",
-    "    def decode(self, token_ids):\n",
-    "        return self.tokenizer.decode(token_ids)\n",
+    "    def encode(self, user_message, system_message=None):\n",
+    "        sys_msg = system_message if system_message is not None else self.default_system\n",
+    "\n",
+    "        ids = [self.tok.special[\"<|begin_of_text|>\"]]\n",
     "\n",
+    "        # system\n",
+    "        ids += self._header(\"system\")\n",
+    "        ids += self.tok.encode(sys_msg)\n",
+    "        ids += [self.tok.special[\"<|eot_id|>\"]]\n",
     "\n",
-    "chat_tokenizer = ChatFormat(tokenizer)"
+    "        # user\n",
+    "        ids += self._header(\"user\")\n",
+    "        ids += self.tok.encode(user_message)\n",
+    "        ids += [self.tok.special[\"<|eot_id|>\"]]\n",
+    "\n",
+    "        # assistant header (no content yet)\n",
+    "        ids += self._header(\"assistant\")\n",
+    "\n",
+    "        return ids"
    ]
   },
   {
@@ -1918,11 +1843,14 @@
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "[128006, 882, 128007, 271, 9906, 4435, 0, 128009]\n"
+      "[128000, 128006, 9125, 128007, 271, 2675, 527, 264, 11190, 18328, 13, 128009, 128006, 882, 128007, 271, 9906, 4435, 0, 128009, 128006, 78191, 128007, 271]\n"
      ]
     }
    ],
    "source": [
+    "tokenizer = Tokenizer(tokenizer_file_path)\n",
+    "chat_tokenizer = ChatFormat(tokenizer)\n",
+    "\n",
     "token_ids = chat_tokenizer.encode(\"Hello World!\")\n",
     "print(token_ids)"
    ]
@@ -1943,7 +1871,7 @@
     {
      "data": {
       "text/plain": [
-       "'<|start_header_id|>user<|end_header_id|>\\n\\nHello World!<|eot_id|>'"
+       "'<|begin_of_text|><|start_header_id|>system<|end_header_id|>\\n\\nYou are a helpful assistant.<|eot_id|><|start_header_id|>user<|end_header_id|>\\n\\nHello World!<|eot_id|><|start_header_id|>assistant<|end_header_id|>\\n\\n'"
       ]
      },
      "execution_count": 35,
@@ -1982,12 +1910,13 @@
      "output_type": "stream",
      "text": [
       "Output text:\n",
-      " Llamas are herbivores, which means they primarily eat plants and plant-based foods. Here are some of the things llamas like to eat:\n",
+      " Llamas are herbivores, which means they primarily eat plants and plant-based foods. Their diet typically consists of:\n",
       "\n",
-      "1. Grass: Llamas love to graze on grass, especially in the spring and summer months.\n",
-      "2. Hay: Hay is a staple in a llama's diet. They like to eat timothy hay, alfalfa hay, and other types of hay.\n",
-      "3. Grains: Llamas may also be fed grains like oats, barley, and corn. However, grains should not make up more than 10-15% of a llama's diet.\n",
-      "4. Fruits and vegetables: Llamas may enjoy fruits and vegetables as treats, such as\n"
+      "1. Grasses: Llamas love to graze on grasses, including tall grasses, short grasses, and even weeds.\n",
+      "2. Hay: Hay is a staple in a llama's diet. They enjoy a variety of hays, such as timothy hay, alfalfa hay, and oat hay.\n",
+      "3. Grains: Llamas may be fed grains like oats, corn, and barley as a supplement to their diet.\n",
+      "4. Fruits and vegetables: Llamas enjoy fruits and vegetables like apples, carrots, and sweet potatoes as treats or additions to their diet.\n",
+      "5. Minerals:\n"
      ]
     }
    ],
@@ -2088,49 +2017,6 @@
     "}"
    ]
   },
-  {
-   "cell_type": "markdown",
-   "id": "d81ee464-c112-43b0-9ee8-70df6ac942d0",
-   "metadata": {},
-   "source": [
-    "- Reduce the context length so the model would work fine on a MacBook Air (if you have more RAM, feel free to comment out the lines below):"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 10,
-   "id": "a55a8769-1a03-4265-8fd0-15f1c423da53",
-   "metadata": {
-    "id": "a8bc2370-39d2-4bfe-b4c1-6bdd75fe101c"
-   },
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "New RoPE theta: 31250.0\n"
-     ]
-    }
-   ],
-   "source": [
-    "old_context_length = LLAMA31_CONFIG_8B[\"context_length\"]\n",
-    "LLAMA31_CONFIG_8B[\"context_length\"] = 8192\n",
-    "\n",
-    "\n",
-    "def rescale_theta(theta_old, context_length_old, context_length_new):\n",
-    "    scaling_factor = context_length_new / context_length_old\n",
-    "    theta_new = theta_old * scaling_factor\n",
-    "    return theta_new\n",
-    "\n",
-    "LLAMA31_CONFIG_8B[\"rope_base\"] = rescale_theta(\n",
-    "    LLAMA31_CONFIG_8B[\"rope_base\"],\n",
-    "    old_context_length,\n",
-    "    LLAMA31_CONFIG_8B[\"context_length\"]\n",
-    ")\n",
-    "\n",
-    "print(\"New RoPE theta:\", LLAMA31_CONFIG_8B[\"rope_base\"])"
-   ]
-  },
   {
    "cell_type": "markdown",
    "id": "xa3bpMDtTdBs",
@@ -2277,64 +2163,7 @@
     "id": "u4J7IxOvOyPM",
     "outputId": "925348d7-fc69-4d1b-90f1-7029426bcfcf"
    },
-   "outputs": [
-    {
-     "data": {
-      "application/vnd.jupyter.widget-view+json": {
-       "model_id": "eabfde3ef38b436ea750e6fb50a02b5c",
-       "version_major": 2,
-       "version_minor": 0
-      },
-      "text/plain": [
-       "model-00001-of-00004.safetensors:   0%|          | 0.00/4.98G [00:00<?, ?B/s]"
-      ]
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    },
-    {
-     "data": {
-      "application/vnd.jupyter.widget-view+json": {
-       "model_id": "e117ad45771747ae95c16f9876e6dc19",
-       "version_major": 2,
-       "version_minor": 0
-      },
-      "text/plain": [
-       "model-00002-of-00004.safetensors:   0%|          | 0.00/5.00G [00:00<?, ?B/s]"
-      ]
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    },
-    {
-     "data": {
-      "application/vnd.jupyter.widget-view+json": {
-       "model_id": "170185f2f046437dab57c2ad23163c5c",
-       "version_major": 2,
-       "version_minor": 0
-      },
-      "text/plain": [
-       "model-00003-of-00004.safetensors:   0%|          | 0.00/4.92G [00:00<?, ?B/s]"
-      ]
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    },
-    {
-     "data": {
-      "application/vnd.jupyter.widget-view+json": {
-       "model_id": "6e65f5d6c5af4ab78bc7b3778b98ef86",
-       "version_major": 2,
-       "version_minor": 0
-      },
-      "text/plain": [
-       "model-00004-of-00004.safetensors:   0%|          | 0.00/1.17G [00:00<?, ?B/s]"
-      ]
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    }
-   ],
+   "outputs": [],
    "source": [
     "combined_weights = {}\n",
     "\n",
@@ -2481,43 +2310,6 @@
     "}"
    ]
   },
-  {
-   "cell_type": "markdown",
-   "id": "b5cd351b-d883-460d-9cdc-47e15ddb884a",
-   "metadata": {},
-   "source": [
-    "- Reduce the context length so the model would work fine on a MacBook Air (if you have more RAM, feel free to comment out the lines below):"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 10,
-   "id": "73f001a6-7ae0-4204-aa83-a27a8878dfd2",
-   "metadata": {
-    "id": "a8bc2370-39d2-4bfe-b4c1-6bdd75fe101c"
-   },
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "New RoPE theta: 31250.0\n"
-     ]
-    }
-   ],
-   "source": [
-    "old_context_length = LLAMA32_CONFIG_1B[\"context_length\"]\n",
-    "LLAMA32_CONFIG_1B[\"context_length\"] = 8192\n",
-    "\n",
-    "LLAMA32_CONFIG_1B[\"rope_base\"] = rescale_theta(\n",
-    "    LLAMA32_CONFIG_1B[\"rope_base\"],\n",
-    "    old_context_length,\n",
-    "    LLAMA32_CONFIG_1B[\"context_length\"]\n",
-    ")\n",
-    "\n",
-    "print(\"New RoPE theta:\", LLAMA32_CONFIG_1B[\"rope_base\"])"
-   ]
-  },
   {
    "cell_type": "markdown",
    "id": "Dl4_0EoJKKYv",
@@ -2612,20 +2404,6 @@
     "outputId": "35588405-e2e1-4871-a1db-1d4bcb852e49"
    },
    "outputs": [
-    {
-     "data": {
-      "application/vnd.jupyter.widget-view+json": {
-       "model_id": "c309c56a6cdf426e8ba7967b6a21864e",
-       "version_major": 2,
-       "version_minor": 0
-      },
-      "text/plain": [
-       "model.safetensors:   0%|          | 0.00/2.47G [00:00<?, ?B/s]"
-      ]
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    },
     {
      "name": "stdout",
      "output_type": "stream",
@@ -2688,7 +2466,7 @@
      "output_type": "stream",
      "text": [
       "Output text:\n",
-      " Every effort is made to ensure that the information on this website is accurate. However, we cannot guarantee that the information is accurate, complete\n"
+      " Every effort is made to ensure that the information on this website is accurate and up to date. However, the information is provided without any\n"
      ]
     }
    ],

+ 0 - 1881
ch05/07_gpt_to_llama/standalone-llama32-mem-opt.ipynb

@@ -1,1881 +0,0 @@
-{
- "cells": [
-  {
-   "cell_type": "markdown",
-   "id": "e1b280ab-b61f-4d1a-bf7e-44e5f9ed3a5c",
-   "metadata": {
-    "id": "e1b280ab-b61f-4d1a-bf7e-44e5f9ed3a5c"
-   },
-   "source": [
-    "<table style=\"width:100%\">\n",
-    "<tr>\n",
-    "<td style=\"vertical-align:middle; text-align:left;\">\n",
-    "<font size=\"2\">\n",
-    "Supplementary code for the <a href=\"http://mng.bz/orYv\">Build a Large Language Model From Scratch</a> book by <a href=\"https://sebastianraschka.com\">Sebastian Raschka</a><br>\n",
-    "<br>Code repository: <a href=\"https://github.com/rasbt/LLMs-from-scratch\">https://github.com/rasbt/LLMs-from-scratch</a>\n",
-    "</font>\n",
-    "</td>\n",
-    "<td style=\"vertical-align:middle; text-align:left;\">\n",
-    "<a href=\"http://mng.bz/orYv\"><img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/cover-small.webp\" width=\"100px\"></a>\n",
-    "</td>\n",
-    "</tr>\n",
-    "</table>"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "id": "efde77f2-6af3-4781-8597-89ecd3f41a52",
-   "metadata": {
-    "id": "efde77f2-6af3-4781-8597-89ecd3f41a52"
-   },
-   "source": [
-    "# Llama 3.2 From Scratch (A Standalone Notebook)"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "id": "55cdef4d-de59-4a65-89f9-fa2a8ef3471d",
-   "metadata": {
-    "id": "55cdef4d-de59-4a65-89f9-fa2a8ef3471d"
-   },
-   "source": [
-    "**Note: This notebook is an alternative to the [standalone-llama32.ipynb](standalone-llama32.ipynb) notebook but optimized for memory efficiency by using a global mask, cos, and sin. On an A100, based on a 8192 context length, this only uses 3.1 GB (vs 7.07 GB) VRAM.** \n",
-    "\n",
-    "\n",
-    "- This notebook is purposefully minimal and focuses on the code to implement the Llama 3.2 1B and 3B LLMs\n",
-    "- For a step-by-step guide that explains the individual components and the relationship between GPT, Llama 2, and Llama 3, please see the following companion notebooks:\n",
-    "  - [Converting a From-Scratch GPT Architecture to Llama 2](converting-gpt-to-llama2.ipynb)\n",
-    "  - [Converting Llama 2 to Llama 3.2 From Scratch](converting-llama2-to-llama3.ipynb)\n",
-    "  \n",
-    "\n",
-    "<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/bonus/gpt-to-llama/llama32.webp\" width=\"700px\">\n",
-    "  \n",
-    "  \n",
-    "- About the code:\n",
-    "  - all code is my own code, mapping the Llama 3 architecture onto the model code implemented in my [Build A Large Language Model (From Scratch)](http://mng.bz/orYv) book; the code is released under a permissive open-source Apache 2.0 license (see [LICENSE.txt](https://github.com/rasbt/LLMs-from-scratch/blob/main/LICENSE.txt))\n",
-    "  - the tokenizer code is inspired by the original [Llama 3 tokenizer code](https://github.com/meta-llama/llama3/blob/main/llama/tokenizer.py), which Meta AI used to extend the Tiktoken GPT-4 tokenizer\n",
-    "  - the RoPE rescaling section is inspired by the [_compute_llama3_parameters function](https://github.com/huggingface/transformers/blob/5c1027bf09717f664b579e01cbb8ec3ef5aeb140/src/transformers/modeling_rope_utils.py#L329-L347) in the `transformers` library"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 1,
-   "id": "7c201adb-747e-437b-9a62-442802941e01",
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "# pip install -r https://raw.githubusercontent.com/rasbt/LLMs-from-scratch/refs/heads/main/ch05/07_gpt_to_llama/requirements-extra.txt"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 2,
-   "id": "dd1b65a8-4301-444a-bd7c-a6f2bd1df9df",
-   "metadata": {
-    "colab": {
-     "base_uri": "https://localhost:8080/"
-    },
-    "id": "dd1b65a8-4301-444a-bd7c-a6f2bd1df9df",
-    "outputId": "4f762354-e0a3-4cc2-e5d4-e61a227a202c"
-   },
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "blobfile version: 3.0.0\n",
-      "huggingface_hub version: 0.29.3\n",
-      "tiktoken version: 0.9.0\n",
-      "torch version: 2.6.0\n"
-     ]
-    }
-   ],
-   "source": [
-    "from importlib.metadata import version\n",
-    "\n",
-    "pkgs = [\n",
-    "    \"blobfile\",         # to download pretrained weights\n",
-    "    \"huggingface_hub\",  # to download pretrained weights\n",
-    "    \"tiktoken\",         # to implement the tokenizer\n",
-    "    \"torch\",            # to implement the model\n",
-    "]\n",
-    "for p in pkgs:\n",
-    "    print(f\"{p} version: {version(p)}\")"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "id": "653410a6-dd2b-4eb2-a722-23d9782e726d",
-   "metadata": {
-    "id": "653410a6-dd2b-4eb2-a722-23d9782e726d"
-   },
-   "source": [
-    "&nbsp;\n",
-    "# 1. Architecture code"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 3,
-   "id": "82076c21-9331-4dcd-b017-42b046cf1a60",
-   "metadata": {
-    "id": "82076c21-9331-4dcd-b017-42b046cf1a60"
-   },
-   "outputs": [],
-   "source": [
-    "import torch\n",
-    "import torch.nn as nn\n",
-    "\n",
-    "\n",
-    "class FeedForward(nn.Module):\n",
-    "    def __init__(self, cfg):\n",
-    "        super().__init__()\n",
-    "        self.fc1 = nn.Linear(cfg[\"emb_dim\"], cfg[\"hidden_dim\"], dtype=cfg[\"dtype\"], bias=False)\n",
-    "        self.fc2 = nn.Linear(cfg[\"emb_dim\"], cfg[\"hidden_dim\"], dtype=cfg[\"dtype\"], bias=False)\n",
-    "        self.fc3 = nn.Linear(cfg[\"hidden_dim\"], cfg[\"emb_dim\"], dtype=cfg[\"dtype\"], bias=False)\n",
-    "\n",
-    "    def forward(self, x):\n",
-    "        x_fc1 = self.fc1(x)\n",
-    "        x_fc2 = self.fc2(x)\n",
-    "        x = nn.functional.silu(x_fc1) * x_fc2\n",
-    "        return self.fc3(x)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 4,
-   "id": "4b9a346f-5826-4083-9162-abd56afc03f0",
-   "metadata": {
-    "id": "4b9a346f-5826-4083-9162-abd56afc03f0"
-   },
-   "outputs": [],
-   "source": [
-    "def compute_rope_params(head_dim, theta_base=10_000, context_length=4096, freq_config=None, dtype=torch.float32):\n",
-    "    assert head_dim % 2 == 0, \"Embedding dimension must be even\"\n",
-    "\n",
-    "    # Compute the inverse frequencies\n",
-    "    inv_freq = 1.0 / (theta_base ** (torch.arange(0, head_dim, 2, dtype=dtype)[: (head_dim // 2)].float() / head_dim))\n",
-    "\n",
-    "    # Frequency adjustments\n",
-    "    if freq_config is not None:\n",
-    "        low_freq_wavelen = freq_config[\"original_context_length\"] / freq_config[\"low_freq_factor\"]\n",
-    "        high_freq_wavelen = freq_config[\"original_context_length\"] / freq_config[\"high_freq_factor\"]\n",
-    "\n",
-    "        wavelen = 2 * torch.pi / inv_freq\n",
-    "\n",
-    "        inv_freq_llama = torch.where(\n",
-    "            wavelen > low_freq_wavelen, inv_freq / freq_config[\"factor\"], inv_freq\n",
-    "        )\n",
-    "\n",
-    "        smooth_factor = (freq_config[\"original_context_length\"] / wavelen - freq_config[\"low_freq_factor\"]) / (\n",
-    "            freq_config[\"high_freq_factor\"] - freq_config[\"low_freq_factor\"]\n",
-    "        )\n",
-    "\n",
-    "        smoothed_inv_freq = (\n",
-    "            (1 - smooth_factor) * (inv_freq / freq_config[\"factor\"]) + smooth_factor * inv_freq\n",
-    "        )\n",
-    "\n",
-    "        is_medium_freq = (wavelen <= low_freq_wavelen) & (wavelen >= high_freq_wavelen)\n",
-    "        inv_freq_llama = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama)\n",
-    "        inv_freq = inv_freq_llama\n",
-    "\n",
-    "    # Generate position indices\n",
-    "    positions = torch.arange(context_length, dtype=dtype)\n",
-    "\n",
-    "    # Compute the angles\n",
-    "    angles = positions[:, None] * inv_freq[None, :]  # Shape: (context_length, head_dim // 2)\n",
-    "\n",
-    "    # Expand angles to match the head_dim\n",
-    "    angles = torch.cat([angles, angles], dim=1)  # Shape: (context_length, head_dim)\n",
-    "\n",
-    "    # Precompute sine and cosine\n",
-    "    cos = torch.cos(angles)\n",
-    "    sin = torch.sin(angles)\n",
-    "\n",
-    "    return cos, sin\n",
-    "\n",
-    "\n",
-    "def apply_rope(x, cos, sin):\n",
-    "    # x: (batch_size, num_heads, seq_len, head_dim)\n",
-    "    batch_size, num_heads, seq_len, head_dim = x.shape\n",
-    "    assert head_dim % 2 == 0, \"Head dimension must be even\"\n",
-    "\n",
-    "    # Split x into first half and second half\n",
-    "    x1 = x[..., : head_dim // 2]  # First half\n",
-    "    x2 = x[..., head_dim // 2 :]  # Second half\n",
-    "\n",
-    "    # Adjust sin and cos shapes\n",
-    "    cos = cos[:seq_len, :].unsqueeze(0).unsqueeze(0)  # Shape: (1, 1, seq_len, head_dim)\n",
-    "    sin = sin[:seq_len, :].unsqueeze(0).unsqueeze(0)\n",
-    "\n",
-    "    # Apply the rotary transformation\n",
-    "    rotated = torch.cat((-x2, x1), dim=-1)\n",
-    "    x_rotated = (x * cos) + (rotated * sin)\n",
-    "\n",
-    "    # It's ok to use lower-precision after applying cos and sin rotation\n",
-    "    return x_rotated.to(dtype=x.dtype)\n",
-    "\n",
-    "\n",
-    "def rescale_theta(theta_old, context_length_old, context_length_new):\n",
-    "    scaling_factor = context_length_new / context_length_old\n",
-    "    theta_new = theta_old * scaling_factor\n",
-    "    return theta_new"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 5,
-   "id": "e8169ab5-f976-4222-a2e1-eb1cabf267cb",
-   "metadata": {
-    "id": "e8169ab5-f976-4222-a2e1-eb1cabf267cb"
-   },
-   "outputs": [],
-   "source": [
-    "class GroupedQueryAttention(nn.Module):\n",
-    "    def __init__(\n",
-    "            self, d_in, d_out, num_heads,\n",
-    "            num_kv_groups,\n",
-    "            dtype=None\n",
-    "        ):\n",
-    "        super().__init__()\n",
-    "        assert d_out % num_heads == 0, \"d_out must be divisible by num_heads\"\n",
-    "        assert num_heads % num_kv_groups == 0, \"num_heads must be divisible by num_kv_groups\"\n",
-    "\n",
-    "        self.d_out = d_out\n",
-    "        self.num_heads = num_heads\n",
-    "        self.head_dim = d_out // num_heads\n",
-    "\n",
-    "        self.W_key = nn.Linear(d_in, num_kv_groups * self.head_dim, bias=False, dtype=dtype)\n",
-    "        self.W_value = nn.Linear(d_in, num_kv_groups * self.head_dim, bias=False, dtype=dtype)\n",
-    "        self.num_kv_groups = num_kv_groups\n",
-    "        self.group_size = num_heads // num_kv_groups\n",
-    "\n",
-    "        self.W_query = nn.Linear(d_in, d_out, bias=False, dtype=dtype)\n",
-    "        self.out_proj = nn.Linear(d_out, d_out, bias=False, dtype=dtype)\n",
-    "\n",
-    "    def forward(self, x, mask, cos, sin):\n",
-    "        b, num_tokens, d_in = x.shape\n",
-    "\n",
-    "        queries = self.W_query(x)  # Shape: (b, num_tokens, d_out)\n",
-    "        keys = self.W_key(x)  # Shape: (b, num_tokens, num_kv_groups * head_dim)\n",
-    "        values = self.W_value(x)  # Shape: (b, num_tokens, num_kv_groups * head_dim)\n",
-    "\n",
-    "        # Reshape queries, keys, and values\n",
-    "        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)\n",
-    "        keys = keys.view(b, num_tokens, self.num_kv_groups, self.head_dim)\n",
-    "        values = values.view(b, num_tokens, self.num_kv_groups, self.head_dim)\n",
-    "\n",
-    "        # Transpose keys, values, and queries\n",
-    "        keys = keys.transpose(1, 2)  # Shape: (b, num_heads, num_tokens, head_dim)\n",
-    "        values = values.transpose(1, 2)  # Shape: (b, num_heads, num_tokens, head_dim)\n",
-    "        queries = queries.transpose(1, 2)  # Shape: (b, num_query_groups, num_tokens, head_dim)\n",
-    "\n",
-    "        # Apply RoPE\n",
-    "        keys = apply_rope(keys, cos, sin)\n",
-    "        queries = apply_rope(queries, cos, sin)\n",
-    "\n",
-    "        # Expand keys and values to match the number of heads\n",
-    "        # Shape: (b, num_heads, num_tokens, head_dim)\n",
-    "        keys = keys.repeat_interleave(self.group_size, dim=1)  # Shape: (b, num_heads, num_tokens, head_dim)\n",
-    "        values = values.repeat_interleave(self.group_size, dim=1)  # Shape: (b, num_heads, num_tokens, head_dim)\n",
-    "        # For example, before repeat_interleave along dim=1 (query groups):\n",
-    "        #   [K1, K2]\n",
-    "        # After repeat_interleave (each query group is repeated group_size times):\n",
-    "        #   [K1, K1, K2, K2]\n",
-    "        # If we used regular repeat instead of repeat_interleave, we'd get:\n",
-    "        #   [K1, K2, K1, K2]\n",
-    "\n",
-    "        # Compute scaled dot-product attention (aka self-attention) with a causal mask\n",
-    "        # Shape: (b, num_heads, num_tokens, num_tokens)\n",
-    "        attn_scores = queries @ keys.transpose(2, 3)  # Dot product for each head\n",
-    "\n",
-    "        # Use the mask to fill attention scores\n",
-    "        attn_scores = attn_scores.masked_fill(mask[:num_tokens, :num_tokens], -torch.inf)\n",
-    "\n",
-    "        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)\n",
-    "        assert keys.shape[-1] == self.head_dim\n",
-    "\n",
-    "        # Shape: (b, num_tokens, num_heads, head_dim)\n",
-    "        context_vec = (attn_weights @ values).transpose(1, 2)\n",
-    "\n",
-    "        # Combine heads, where self.d_out = self.num_heads * self.head_dim\n",
-    "        context_vec = context_vec.reshape(b, num_tokens, self.d_out)\n",
-    "        context_vec = self.out_proj(context_vec)  # optional projection\n",
-    "\n",
-    "        return context_vec"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 6,
-   "id": "457cb2f8-50c1-4045-8a74-f181bfb5fea9",
-   "metadata": {
-    "id": "457cb2f8-50c1-4045-8a74-f181bfb5fea9"
-   },
-   "outputs": [],
-   "source": [
-    "class TransformerBlock(nn.Module):\n",
-    "    def __init__(self, cfg):\n",
-    "        super().__init__()\n",
-    "        self.att = GroupedQueryAttention(\n",
-    "            d_in=cfg[\"emb_dim\"],\n",
-    "            d_out=cfg[\"emb_dim\"],\n",
-    "            num_heads=cfg[\"n_heads\"],\n",
-    "            num_kv_groups=cfg[\"n_kv_groups\"],\n",
-    "            dtype=cfg[\"dtype\"]\n",
-    "        )\n",
-    "        self.ff = FeedForward(cfg)\n",
-    "        self.norm1 = nn.RMSNorm(cfg[\"emb_dim\"], eps=1e-5, dtype=cfg[\"dtype\"])\n",
-    "        self.norm2 = nn.RMSNorm(cfg[\"emb_dim\"], eps=1e-5, dtype=cfg[\"dtype\"])\n",
-    "\n",
-    "    def forward(self, x, mask, cos, sin):\n",
-    "        # Shortcut connection for attention block\n",
-    "        shortcut = x\n",
-    "        x = self.norm1(x)\n",
-    "        x = self.att(x, mask, cos, sin)  # Shape [batch_size, num_tokens, emb_size]\n",
-    "        x = x + shortcut  # Add the original input back\n",
-    "\n",
-    "        # Shortcut connection for feed-forward block\n",
-    "        shortcut = x\n",
-    "        x = self.norm2(x)\n",
-    "        x = self.ff(x)\n",
-    "        x = x + shortcut  # Add the original input back\n",
-    "\n",
-    "        return x"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 7,
-   "id": "e88de3e3-9f07-42cc-816b-28dbd46e96c4",
-   "metadata": {
-    "id": "e88de3e3-9f07-42cc-816b-28dbd46e96c4"
-   },
-   "outputs": [],
-   "source": [
-    "class Llama3Model(nn.Module):\n",
-    "    def __init__(self, cfg):\n",
-    "        super().__init__()\n",
-    "\n",
-    "        # Main model parameters\n",
-    "        self.tok_emb = nn.Embedding(cfg[\"vocab_size\"], cfg[\"emb_dim\"], dtype=cfg[\"dtype\"])\n",
-    "\n",
-    "        self.trf_blocks = nn.ModuleList(  # ModuleList since Sequential can only accept one input, and we need `x, mask, cos, sin`\n",
-    "            [TransformerBlock(cfg) for _ in range(cfg[\"n_layers\"])]\n",
-    "        )\n",
-    "\n",
-    "        self.final_norm = nn.RMSNorm(cfg[\"emb_dim\"], eps=1e-5, dtype=cfg[\"dtype\"])\n",
-    "        self.out_head = nn.Linear(cfg[\"emb_dim\"], cfg[\"vocab_size\"], bias=False, dtype=cfg[\"dtype\"])\n",
-    "\n",
-    "        # Reusuable utilities\n",
-    "        self.register_buffer(\n",
-    "            \"mask\", torch.triu(torch.ones(cfg[\"context_length\"], cfg[\"context_length\"]), diagonal=1).bool(),\n",
-    "            persistent=False\n",
-    "        )\n",
-    "        cfg[\"rope_base\"] = rescale_theta(\n",
-    "                        cfg[\"rope_base\"],\n",
-    "                        cfg[\"orig_context_length\"],\n",
-    "                        cfg[\"context_length\"]\n",
-    "                    )\n",
-    "        cos, sin = compute_rope_params(\n",
-    "            head_dim=cfg[\"emb_dim\"] // cfg[\"n_heads\"],\n",
-    "            theta_base=cfg[\"rope_base\"],\n",
-    "            context_length=cfg[\"context_length\"],\n",
-    "            freq_config=cfg[\"rope_freq\"]\n",
-    "        )\n",
-    "        self.register_buffer(\"cos\", cos)\n",
-    "        self.register_buffer(\"sin\", sin)\n",
-    "        self.cfg = cfg\n",
-    "\n",
-    "\n",
-    "    def forward(self, in_idx):\n",
-    "        # Forward pass\n",
-    "        tok_embeds = self.tok_emb(in_idx)\n",
-    "        x = tok_embeds\n",
-    "        \n",
-    "        for block in self.trf_blocks:\n",
-    "            x = block(x, self.mask, self.cos, self.sin)\n",
-    "        x = self.final_norm(x)\n",
-    "        logits = self.out_head(x.to(self.cfg[\"dtype\"]))\n",
-    "        return logits"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "id": "be2d201f-74ad-4d63-ab9c-601b00674a48",
-   "metadata": {
-    "id": "be2d201f-74ad-4d63-ab9c-601b00674a48"
-   },
-   "source": [
-    "&nbsp;\n",
-    "# 2. Initialize model"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "id": "23dea40c-fe20-4a75-be25-d6fce5863c01",
-   "metadata": {
-    "id": "23dea40c-fe20-4a75-be25-d6fce5863c01"
-   },
-   "source": [
-    "- The remainder of this notebook uses the Llama 3.2 1B model; to use the 3B model variant, just uncomment the second configuration file in the following code cell"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 8,
-   "id": "caa142fa-b375-4e78-b392-2072ced666f3",
-   "metadata": {
-    "id": "caa142fa-b375-4e78-b392-2072ced666f3"
-   },
-   "outputs": [],
-   "source": [
-    "# Llama 3.2 1B\n",
-    "\n",
-    "LLAMA32_CONFIG = {\n",
-    "    \"vocab_size\": 128_256,           # Vocabulary size\n",
-    "    \"context_length\": 8192,          # Maximum context length to use (reduced to save memory)\n",
-    "    \"orig_context_length\": 131_072,  # Context length that was used to train the model\n",
-    "    \"emb_dim\": 2048,                 # Embedding dimension\n",
-    "    \"n_heads\": 32,                   # Number of attention heads\n",
-    "    \"n_layers\": 16,                  # Number of layers\n",
-    "    \"hidden_dim\": 8192,              # Size of the intermediate dimension in FeedForward\n",
-    "    \"n_kv_groups\": 8,                # Key-Value groups for grouped-query attention\n",
-    "    \"rope_base\": 500_000.0,          # The base in RoPE's \"theta\"\n",
-    "    \"dtype\": torch.bfloat16,         # Lower-precision dtype to reduce memory usage\n",
-    "    \"rope_freq\": {                   # RoPE frequency scaling\n",
-    "        \"factor\": 32.0,\n",
-    "        \"low_freq_factor\": 1.0,\n",
-    "        \"high_freq_factor\": 4.0,\n",
-    "        \"original_context_length\": 8192,\n",
-    "    }\n",
-    "}\n",
-    "\n",
-    "# Llama 3.2 3B\n",
-    "\n",
-    "# LLAMA32_CONFIG = {\n",
-    "#     \"vocab_size\": 128_256,           # Vocabulary size\n",
-    "#     \"context_length\": 8192,          # Maximum context length to use (reduced to save memory)\n",
-    "#     \"orig_context_length\": 131_072,  # Context length that was used to train the model\n",
-    "#     \"emb_dim\": 3072,                 # Embedding dimension\n",
-    "#     \"n_heads\": 24,                   # Number of attention heads\n",
-    "#     \"n_layers\": 28,                  # Number of layers\n",
-    "#     \"hidden_dim\": 8192,              # Size of the intermediate dimension in FeedForward\n",
-    "#     \"n_kv_groups\": 8,                # Key-Value groups for grouped-query attention\n",
-    "#     \"rope_base\": 500_000.0,          # The base in RoPE's \"theta\"\n",
-    "#     \"dtype\": torch.bfloat16,         # Lower-precision dtype to reduce memory usage\n",
-    "#     \"rope_freq\": {                   # RoPE frequency scaling\n",
-    "#         \"factor\": 32.0,\n",
-    "#         \"low_freq_factor\": 1.0,\n",
-    "#         \"high_freq_factor\": 4.0,\n",
-    "#         \"original_context_length\": 8192,\n",
-    "#     }\n",
-    "# }\n",
-    "\n",
-    "LLAMA_SIZE_STR = \"1B\" if LLAMA32_CONFIG[\"emb_dim\"] == 2048 else \"3B\""
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 9,
-   "id": "156253fe-aacd-4da2-8f13-705f05c4b11e",
-   "metadata": {
-    "id": "156253fe-aacd-4da2-8f13-705f05c4b11e"
-   },
-   "outputs": [],
-   "source": [
-    "model = Llama3Model(LLAMA32_CONFIG)"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "id": "19de6c2c-83ce-456d-8be9-6ec415fe9eb1",
-   "metadata": {
-    "id": "19de6c2c-83ce-456d-8be9-6ec415fe9eb1"
-   },
-   "source": [
-    "- The following is expected to print True to confirm buffers are reused instead of being (wastefully) recreated:"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 10,
-   "id": "364e76ca-52f8-4fa5-af37-c4069f9694bc",
-   "metadata": {
-    "colab": {
-     "base_uri": "https://localhost:8080/"
-    },
-    "id": "364e76ca-52f8-4fa5-af37-c4069f9694bc",
-    "outputId": "00d7e983-262e-4c65-f322-f4d999311988"
-   },
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "Total number of parameters: 1,498,482,688\n",
-      "\n",
-      "Total number of unique parameters: 1,235,814,400\n"
-     ]
-    }
-   ],
-   "source": [
-    "total_params = sum(p.numel() for p in model.parameters())\n",
-    "print(f\"Total number of parameters: {total_params:,}\")\n",
-    "\n",
-    "# Account for weight tying\n",
-    "total_params_normalized = total_params - model.tok_emb.weight.numel()\n",
-    "print(f\"\\nTotal number of unique parameters: {total_params_normalized:,}\")"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 11,
-   "id": "fd5efb03-5a07-46e8-8607-93ed47549d2b",
-   "metadata": {
-    "colab": {
-     "base_uri": "https://localhost:8080/"
-    },
-    "id": "fd5efb03-5a07-46e8-8607-93ed47549d2b",
-    "outputId": "65c1a95e-b502-4150-9e2e-da619d9053d5"
-   },
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "float32 (PyTorch default): 11.42 GB\n",
-      "bfloat16: 5.71 GB\n"
-     ]
-    }
-   ],
-   "source": [
-    "def model_memory_size(model, input_dtype=torch.float32):\n",
-    "    total_params = 0\n",
-    "    total_grads = 0\n",
-    "    for param in model.parameters():\n",
-    "        # Calculate total number of elements per parameter\n",
-    "        param_size = param.numel()\n",
-    "        total_params += param_size\n",
-    "        # Check if gradients are stored for this parameter\n",
-    "        if param.requires_grad:\n",
-    "            total_grads += param_size\n",
-    "\n",
-    "    # Calculate buffer size (non-parameters that require memory)\n",
-    "    total_buffers = sum(buf.numel() for buf in model.buffers())\n",
-    "\n",
-    "    # Size in bytes = (Number of elements) * (Size of each element in bytes)\n",
-    "    # We assume parameters and gradients are stored in the same type as input dtype\n",
-    "    element_size = torch.tensor(0, dtype=input_dtype).element_size()\n",
-    "    total_memory_bytes = (total_params + total_grads + total_buffers) * element_size\n",
-    "\n",
-    "    # Convert bytes to gigabytes\n",
-    "    total_memory_gb = total_memory_bytes / (1024**3)\n",
-    "\n",
-    "    return total_memory_gb\n",
-    "\n",
-    "print(f\"float32 (PyTorch default): {model_memory_size(model, input_dtype=torch.float32):.2f} GB\")\n",
-    "print(f\"bfloat16: {model_memory_size(model, input_dtype=torch.bfloat16):.2f} GB\")"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 12,
-   "id": "31f12baf-f79b-499f-85c0-51328a6a20f5",
-   "metadata": {
-    "id": "31f12baf-f79b-499f-85c0-51328a6a20f5"
-   },
-   "outputs": [],
-   "source": [
-    "if torch.cuda.is_available():\n",
-    "    device = torch.device(\"cuda\")\n",
-    "elif torch.backends.mps.is_available():\n",
-    "    device = torch.device(\"mps\")\n",
-    "else:\n",
-    "    device = torch.device(\"cpu\")\n",
-    "\n",
-    "model.to(device);"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "id": "78e091e1-afa8-4d23-9aea-cced86181bfd",
-   "metadata": {
-    "id": "78e091e1-afa8-4d23-9aea-cced86181bfd"
-   },
-   "source": [
-    "&nbsp;\n",
-    "# 3. Load tokenizer"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 13,
-   "id": "9482b01c-49f9-48e4-ab2c-4a4c75240e77",
-   "metadata": {
-    "id": "9482b01c-49f9-48e4-ab2c-4a4c75240e77"
-   },
-   "outputs": [],
-   "source": [
-    "import os\n",
-    "from pathlib import Path\n",
-    "\n",
-    "import tiktoken\n",
-    "from tiktoken.load import load_tiktoken_bpe\n",
-    "\n",
-    "\n",
-    "class Tokenizer:\n",
-    "    def __init__(self, model_path):\n",
-    "        assert os.path.isfile(model_path), f\"Model file {model_path} not found\"\n",
-    "        mergeable_ranks = load_tiktoken_bpe(model_path)\n",
-    "\n",
-    "        self.special_tokens = {\n",
-    "            \"<|begin_of_text|>\": 128000,\n",
-    "            \"<|end_of_text|>\": 128001,\n",
-    "            \"<|start_header_id|>\": 128006,\n",
-    "            \"<|end_header_id|>\": 128007,\n",
-    "            \"<|eot_id|>\": 128009,\n",
-    "        }\n",
-    "        self.special_tokens.update({\n",
-    "            f\"<|reserved_{i}|>\": 128002 + i for i in range(256) if (128002 + i) not in self.special_tokens.values()\n",
-    "        })\n",
-    "\n",
-    "        self.model = tiktoken.Encoding(\n",
-    "            name=Path(model_path).name,\n",
-    "            pat_str=r\"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+\",\n",
-    "            mergeable_ranks=mergeable_ranks,\n",
-    "            special_tokens=self.special_tokens\n",
-    "        )\n",
-    "\n",
-    "\n",
-    "    def encode(self, text, bos=False, eos=False, allowed_special=set(), disallowed_special=()):\n",
-    "        if bos:\n",
-    "            tokens = [self.special_tokens[\"<|begin_of_text|>\"]]\n",
-    "        else:\n",
-    "            tokens = []\n",
-    "\n",
-    "        tokens += self.model.encode(text, allowed_special=allowed_special, disallowed_special=disallowed_special)\n",
-    "\n",
-    "        if eos:\n",
-    "            tokens.append(self.special_tokens[\"<|end_of_text|>\"])\n",
-    "        return tokens\n",
-    "\n",
-    "    def decode(self, tokens):\n",
-    "        return self.model.decode(tokens)\n",
-    "\n",
-    "\n",
-    "class ChatFormat:\n",
-    "    def __init__(self, tokenizer):\n",
-    "        self.tokenizer = tokenizer\n",
-    "\n",
-    "    def encode_header(self, message):\n",
-    "        tokens = []\n",
-    "        tokens.append(self.tokenizer.special_tokens[\"<|start_header_id|>\"])\n",
-    "        tokens.extend(self.tokenizer.encode(message[\"role\"], bos=False, eos=False))\n",
-    "        tokens.append(self.tokenizer.special_tokens[\"<|end_header_id|>\"])\n",
-    "        tokens.extend(self.tokenizer.encode(\"\\n\\n\", bos=False, eos=False))\n",
-    "        return tokens\n",
-    "\n",
-    "    def encode(self, text):\n",
-    "        message = {\n",
-    "            \"role\": \"user\",\n",
-    "            \"content\": text\n",
-    "        }\n",
-    "\n",
-    "        tokens = self.encode_header(message)\n",
-    "        tokens.extend(\n",
-    "            self.tokenizer.encode(message[\"content\"].strip(), bos=False, eos=False)\n",
-    "        )\n",
-    "        tokens.append(self.tokenizer.special_tokens[\"<|eot_id|>\"])\n",
-    "        return tokens\n",
-    "\n",
-    "    def decode(self, token_ids):\n",
-    "        return self.tokenizer.decode(token_ids)"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "id": "b771b60c-c198-4b30-bf10-42031197ae86",
-   "metadata": {
-    "id": "b771b60c-c198-4b30-bf10-42031197ae86"
-   },
-   "source": [
-    "- Please note that Meta AI requires that you accept the Llama 3.2 licensing terms before you can download the files; to do this, you have to create a Hugging Face Hub account and visit the [meta-llama/Llama-3.2-1B](https://huggingface.co/meta-llama/Llama-3.2-1B) repository to accept the terms\n",
-    "- Next, you will need to create an access token; to generate an access token with READ permissions, click on the profile picture in the upper right and click on \"Settings\"\n",
-    "\n",
-    "\n",
-    "<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/bonus/gpt-to-llama/settings.webp?1\" width=\"300px\">\n",
-    "\n",
-    "- Then, create and copy the access token so you can copy & paste it into the next code cell\n",
-    "\n",
-    "<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/bonus/gpt-to-llama/access-token.webp?1\" width=\"600px\">"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 14,
-   "id": "e9d96dc8-603a-4cb5-8c3e-4d2ca56862ed",
-   "metadata": {
-    "colab": {
-     "base_uri": "https://localhost:8080/"
-    },
-    "id": "e9d96dc8-603a-4cb5-8c3e-4d2ca56862ed",
-    "outputId": "e6e6dc05-7330-45bc-a9a7-331919155bdd"
-   },
-   "outputs": [],
-   "source": [
-    "from huggingface_hub import login\n",
-    "\n",
-    "login()"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 15,
-   "id": "986bc1a0-804f-4154-80f8-44cefbee1368",
-   "metadata": {
-    "colab": {
-     "base_uri": "https://localhost:8080/",
-     "height": 141,
-     "referenced_widgets": [
-      "a1608feac06d4687967a3e398f01c489",
-      "518fb202e4b44aaba47f07d1a61b6762",
-      "672cdc5aea954de3af851c001a667ad3",
-      "eebf8874618746b39cf4a21a2728dc7f",
-      "5176834aa8784bba9ec21234b87a8948",
-      "e2dc407afcd945c798e30597fddfcb3c",
-      "0dccd57dcc5c43a588157cef957c07e8",
-      "33ca0cdf2c7f41598a381c4ebe6a4ee1",
-      "ee44487f58454dacb522b1e084ffb733",
-      "d2c41e71a3f441deaed091b620ac5603",
-      "3326b6141a1a4eba9f316df528a9b99a"
-     ]
-    },
-    "id": "986bc1a0-804f-4154-80f8-44cefbee1368",
-    "outputId": "5dd7334b-4c71-465a-94d2-c3e95b9ddc58"
-   },
-   "outputs": [],
-   "source": [
-    "from huggingface_hub import hf_hub_download\n",
-    "\n",
-    "tokenizer_file_path = hf_hub_download(\n",
-    "    repo_id=f\"meta-llama/Llama-3.2-{LLAMA_SIZE_STR}-Instruct\",\n",
-    "    filename=\"original/tokenizer.model\",\n",
-    "    local_dir=f\"Llama-3.2-{LLAMA_SIZE_STR}-Instruct\"\n",
-    ")"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 16,
-   "id": "_gBhxDtU_nxo",
-   "metadata": {
-    "id": "_gBhxDtU_nxo"
-   },
-   "outputs": [],
-   "source": [
-    "tokenizer = Tokenizer(tokenizer_file_path)\n",
-    "chat_tokenizer = ChatFormat(tokenizer)"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "id": "c172f89f-d301-439f-b809-46169e5f5945",
-   "metadata": {
-    "id": "c172f89f-d301-439f-b809-46169e5f5945"
-   },
-   "source": [
-    "&nbsp;\n",
-    "# 4. Load pretrained weights"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 17,
-   "id": "75166128-5899-4995-9b88-9672e135650e",
-   "metadata": {
-    "id": "75166128-5899-4995-9b88-9672e135650e"
-   },
-   "outputs": [],
-   "source": [
-    "def assign(left, right, tensor_name=\"unknown\"):\n",
-    "    if left.shape != right.shape:\n",
-    "        raise ValueError(f\"Shape mismatch in tensor '{tensor_name}'. Left: {left.shape}, Right: {right.shape}\")\n",
-    "\n",
-    "    if isinstance(right, torch.Tensor):\n",
-    "        return torch.nn.Parameter(right.clone().detach())\n",
-    "    else:\n",
-    "        return torch.nn.Parameter(torch.tensor(right))\n",
-    "\n",
-    "\n",
-    "def load_weights_into_llama(model, param_config, params):\n",
-    "    model.tok_emb.weight = assign(model.tok_emb.weight, params[\"model.embed_tokens.weight\"], \"model.embed_tokens.weight\")\n",
-    "\n",
-    "    for l in range(param_config[\"n_layers\"]):\n",
-    "\n",
-    "        # Load attention weights\n",
-    "        model.trf_blocks[l].att.W_query.weight = assign(\n",
-    "            model.trf_blocks[l].att.W_query.weight,\n",
-    "            params[f\"model.layers.{l}.self_attn.q_proj.weight\"],\n",
-    "            f\"model.layers.{l}.self_attn.q_proj.weight\"\n",
-    "        )\n",
-    "        model.trf_blocks[l].att.W_key.weight = assign(\n",
-    "            model.trf_blocks[l].att.W_key.weight,\n",
-    "            params[f\"model.layers.{l}.self_attn.k_proj.weight\"],\n",
-    "            f\"model.layers.{l}.self_attn.k_proj.weight\"\n",
-    "        )\n",
-    "        model.trf_blocks[l].att.W_value.weight = assign(\n",
-    "            model.trf_blocks[l].att.W_value.weight,\n",
-    "            params[f\"model.layers.{l}.self_attn.v_proj.weight\"],\n",
-    "            f\"model.layers.{l}.self_attn.v_proj.weight\"\n",
-    "        )\n",
-    "        model.trf_blocks[l].att.out_proj.weight = assign(\n",
-    "            model.trf_blocks[l].att.out_proj.weight,\n",
-    "            params[f\"model.layers.{l}.self_attn.o_proj.weight\"],\n",
-    "            f\"model.layers.{l}.self_attn.o_proj.weight\"\n",
-    "        )\n",
-    "        model.trf_blocks[l].norm1.weight = assign(\n",
-    "            model.trf_blocks[l].norm1.weight,\n",
-    "            params[f\"model.layers.{l}.input_layernorm.weight\"],\n",
-    "            f\"model.layers.{l}.input_layernorm.weight\"\n",
-    "        )\n",
-    "\n",
-    "        # Load FeedForward weights\n",
-    "        model.trf_blocks[l].ff.fc1.weight = assign(\n",
-    "            model.trf_blocks[l].ff.fc1.weight,\n",
-    "            params[f\"model.layers.{l}.mlp.gate_proj.weight\"],\n",
-    "            f\"model.layers.{l}.mlp.gate_proj.weight\"\n",
-    "        )\n",
-    "        model.trf_blocks[l].ff.fc2.weight = assign(\n",
-    "            model.trf_blocks[l].ff.fc2.weight,\n",
-    "            params[f\"model.layers.{l}.mlp.up_proj.weight\"],\n",
-    "            f\"model.layers.{l}.mlp.up_proj.weight\"\n",
-    "        )\n",
-    "        model.trf_blocks[l].ff.fc3.weight = assign(\n",
-    "            model.trf_blocks[l].ff.fc3.weight,\n",
-    "            params[f\"model.layers.{l}.mlp.down_proj.weight\"],\n",
-    "            f\"model.layers.{l}.mlp.down_proj.weight\"\n",
-    "        )\n",
-    "        model.trf_blocks[l].norm2.weight = assign(\n",
-    "            model.trf_blocks[l].norm2.weight,\n",
-    "            params[f\"model.layers.{l}.post_attention_layernorm.weight\"],\n",
-    "            f\"model.layers.{l}.post_attention_layernorm.weight\"\n",
-    "        )\n",
-    "\n",
-    "    # Load output layer weights\n",
-    "    model.final_norm.weight = assign(model.final_norm.weight, params[\"model.norm.weight\"], \"model.norm.weight\")\n",
-    "\n",
-    "    if \"lm_head.weight\" in params.keys():\n",
-    "        model.out_head.weight = assign(model.out_head.weight, params[\"lm_head.weight\"], \"lm_head.weight\")\n",
-    "    else:\n",
-    "        model.out_head.weight = assign(model.out_head.weight, params[\"model.embed_tokens.weight\"], \"model.embed_tokens.weight\")\n",
-    "        print(\"Model uses weight tying.\")"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 18,
-   "id": "699cb1b8-a67d-49fb-80a6-0dad9d81f392",
-   "metadata": {
-    "colab": {
-     "base_uri": "https://localhost:8080/",
-     "height": 17,
-     "referenced_widgets": [
-      "9881b6995c3f49dc89e6992fd9ab660b",
-      "17a3174e65c54476b2e0d1faf8f011ca",
-      "1bbf2e62c0754d1593beb4105a7f1ac1",
-      "b82112e1dec645d98aa1c1ba64abcb61",
-      "271e2bd6a35e4a8b92de8697f7c0be5f",
-      "90a79523187446dfa692723b2e5833a7",
-      "431ffb83b8c14bf182f0430e07ea6154",
-      "a8f1b72a33dd4b548de23fbd95e0da18",
-      "25cc36132d384189acfbecc59483134b",
-      "bfd06423ad544218968648016e731a46",
-      "d029630b63ff44cf807ade428d2eb421"
-     ]
-    },
-    "id": "699cb1b8-a67d-49fb-80a6-0dad9d81f392",
-    "outputId": "55b2f28c-142f-4698-9d23-d27456d3ed6d"
-   },
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "Model uses weight tying.\n"
-     ]
-    }
-   ],
-   "source": [
-    "from safetensors.torch import load_file\n",
-    "\n",
-    "\n",
-    "if LLAMA_SIZE_STR == \"1B\":\n",
-    "    weights_file = hf_hub_download(\n",
-    "        repo_id=f\"meta-llama/Llama-3.2-{LLAMA_SIZE_STR}-Instruct\",\n",
-    "        filename=\"model.safetensors\",\n",
-    "        local_dir=f\"Llama-3.2-{LLAMA_SIZE_STR}-Instruct\"\n",
-    "    )\n",
-    "    combined_weights = load_file(weights_file)\n",
-    "\n",
-    "\n",
-    "else:\n",
-    "    combined_weights = {}\n",
-    "    for i in range(1, 3):\n",
-    "        weights_file = hf_hub_download(\n",
-    "            repo_id=f\"meta-llama/Llama-3.2-{LLAMA_SIZE_STR}-Instruct\",\n",
-    "            filename=f\"model-0000{i}-of-00002.safetensors\",\n",
-    "            local_dir=f\"Llama-3.2-{LLAMA_SIZE_STR}-Instruct\"\n",
-    "        )\n",
-    "        current_weights = load_file(weights_file)\n",
-    "        combined_weights.update(current_weights)\n",
-    "\n",
-    "\n",
-    "load_weights_into_llama(model, LLAMA32_CONFIG, combined_weights)\n",
-    "model.to(device)\n",
-    "del combined_weights  # free up memory"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 19,
-   "id": "7f9f7ccc-70cb-41ff-9c25-44336042fc37",
-   "metadata": {
-    "id": "7f9f7ccc-70cb-41ff-9c25-44336042fc37"
-   },
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "Weight tying: True\n"
-     ]
-    }
-   ],
-   "source": [
-    "print(\"Weight tying:\", torch.equal(model.tok_emb.weight, model.out_head.weight))"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "id": "57d07df1-4401-4792-b549-7c4cc5632323",
-   "metadata": {
-    "id": "57d07df1-4401-4792-b549-7c4cc5632323"
-   },
-   "source": [
-    "&nbsp;\n",
-    "# 5. Generate text"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 20,
-   "id": "7b8401c6-e244-4cb7-9849-2ba71ce758d5",
-   "metadata": {
-    "id": "7b8401c6-e244-4cb7-9849-2ba71ce758d5"
-   },
-   "outputs": [],
-   "source": [
-    "def text_to_token_ids(text, tokenizer):\n",
-    "    encoded = tokenizer.encode(text)\n",
-    "    encoded_tensor = torch.tensor(encoded).unsqueeze(0)  # add batch dimension\n",
-    "    return encoded_tensor\n",
-    "\n",
-    "\n",
-    "def token_ids_to_text(token_ids, tokenizer):\n",
-    "    flat = token_ids.squeeze(0)  # remove batch dimension\n",
-    "    return tokenizer.decode(flat.tolist())\n",
-    "\n",
-    "\n",
-    "def generate(model, idx, max_new_tokens, context_size, temperature=0.0, top_k=None, eos_id=None):\n",
-    "\n",
-    "    # For-loop is the same as before: Get logits, and only focus on last time step\n",
-    "    for _ in range(max_new_tokens):\n",
-    "        idx_cond = idx[:, -context_size:]\n",
-    "        with torch.no_grad():\n",
-    "            logits = model(idx_cond)\n",
-    "        logits = logits[:, -1, :]\n",
-    "\n",
-    "        # New: Filter logits with top_k sampling\n",
-    "        if top_k is not None:\n",
-    "            # Keep only top_k values\n",
-    "            top_logits, _ = torch.topk(logits, top_k)\n",
-    "            min_val = top_logits[:, -1]\n",
-    "            logits = torch.where(logits < min_val, torch.tensor(float('-inf')).to(logits.device), logits)\n",
-    "\n",
-    "        # New: Apply temperature scaling\n",
-    "        if temperature > 0.0:\n",
-    "            logits = logits / temperature\n",
-    "\n",
-    "            # Apply softmax to get probabilities\n",
-    "            probs = torch.softmax(logits, dim=-1)  # (batch_size, context_len)\n",
-    "\n",
-    "            # Sample from the distribution\n",
-    "            idx_next = torch.multinomial(probs, num_samples=1)  # (batch_size, 1)\n",
-    "\n",
-    "        # Otherwise same as before: get idx of the vocab entry with the highest logits value\n",
-    "        else:\n",
-    "            idx_next = torch.argmax(logits, dim=-1, keepdim=True)  # (batch_size, 1)\n",
-    "\n",
-    "        if idx_next == eos_id:  # Stop generating early if end-of-sequence token is encountered and eos_id is specified\n",
-    "            break\n",
-    "\n",
-    "        # Same as before: append sampled index to the running sequence\n",
-    "        idx = torch.cat((idx, idx_next), dim=1)  # (batch_size, num_tokens+1)\n",
-    "\n",
-    "    return idx"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 24,
-   "id": "1c7a04fa-6aac-416b-8f63-f1e19227633d",
-   "metadata": {
-    "id": "1c7a04fa-6aac-416b-8f63-f1e19227633d"
-   },
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "Time: 19.49 sec\n",
-      "\n",
-      "\n",
-      "Output text:\n",
-      "\n",
-      " Llamas are herbivores, which means they primarily eat plants and plant-based foods. Their diet typically consists of:\n",
-      "\n",
-      "1. Grasses: Llamas love to graze on various types of grasses, including tall grasses and short grasses.\n",
-      "2. Hay: Llamas also eat hay, which is a dry, compressed form of grass or other plants.\n",
-      "3. Alfalfa: Alfalfa is a legume that is commonly used as a hay substitute in llama feed.\n",
-      "4. Other plants: Llamas will also eat other plants, such as clover, dandelions, and wild grasses.\n",
-      "\n",
-      "It's worth noting that the specific diet of llamas can vary depending on factors such as\n"
-     ]
-    }
-   ],
-   "source": [
-    "import time\n",
-    "\n",
-    "\n",
-    "PROMPT = \"What do llamas eat?\"\n",
-    "\n",
-    "torch.manual_seed(123)\n",
-    "\n",
-    "start = time.time()\n",
-    "\n",
-    "token_ids = generate(\n",
-    "    model=model,\n",
-    "    idx=text_to_token_ids(PROMPT, chat_tokenizer).to(device),\n",
-    "    max_new_tokens=150,\n",
-    "    context_size=LLAMA32_CONFIG[\"context_length\"],\n",
-    "    top_k=1,\n",
-    "    temperature=0.\n",
-    ")\n",
-    "\n",
-    "print(f\"Time: {time.time() - start:.2f} sec\")\n",
-    "\n",
-    "if torch.cuda.is_available():\n",
-    "    max_mem_bytes = torch.cuda.max_memory_allocated()\n",
-    "    max_mem_gb = max_mem_bytes / (1024 ** 3)\n",
-    "    print(f\"Max memory allocated: {max_mem_gb:.2f} GB\")\n",
-    "\n",
-    "output_text = token_ids_to_text(token_ids, tokenizer)\n",
-    "\n",
-    "\n",
-    "def clean_text(text, header_end=\"assistant<|end_header_id|>\\n\\n\"):\n",
-    "    # Find the index of the first occurrence of \"<|end_header_id|>\"\n",
-    "    index = text.find(header_end)\n",
-    "\n",
-    "    if index != -1:\n",
-    "        # Return the substring starting after \"<|end_header_id|>\"\n",
-    "        return text[index + len(header_end):].strip()  # Strip removes leading/trailing whitespace\n",
-    "    else:\n",
-    "        # If the token is not found, return the original text\n",
-    "        return text\n",
-    "\n",
-    "print(\"\\n\\nOutput text:\\n\\n\", clean_text(output_text))"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "id": "549324d6-5c71-4147-ae21-2e67675faa3d",
-   "metadata": {
-    "id": "549324d6-5c71-4147-ae21-2e67675faa3d"
-   },
-   "source": [
-    "&nbsp;\n",
-    "# What's next?"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "id": "e6edaaae-2de1-406c-8ffa-897cdfa3808c",
-   "metadata": {
-    "id": "e6edaaae-2de1-406c-8ffa-897cdfa3808c"
-   },
-   "source": [
-    "- The notebook was kept purposefully minimal; if you are interested in additional explanation about the individual components, check out the following two companion notebooks:\n",
-    "\n",
-    "<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/bonus/gpt-to-llama/gpt-and-all-llamas.webp\">\n",
-    "\n",
-    "  1. [Converting a From-Scratch GPT Architecture to Llama 2](converting-gpt-to-llama2.ipynb)\n",
-    "  2. [Converting Llama 2 to Llama 3.2 From Scratch](converting-llama2-to-llama3.ipynb)\n",
-    "  \n",
-    "- For those interested in a comprehensive guide on building a large language model from scratch and gaining a deeper understanding of its mechanics, you might like my [Build a Large Language Model (From Scratch)](http://mng.bz/orYv)\n",
-    "\n",
-    "<a href=\"http://mng.bz/orYv\"><img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/cover-small.webp\" width=\"100px\"></a>"
-   ]
-  }
- ],
- "metadata": {
-  "accelerator": "GPU",
-  "colab": {
-   "gpuType": "A100",
-   "provenance": []
-  },
-  "kernelspec": {
-   "display_name": "Python 3 (ipykernel)",
-   "language": "python",
-   "name": "python3"
-  },
-  "language_info": {
-   "codemirror_mode": {
-    "name": "ipython",
-    "version": 3
-   },
-   "file_extension": ".py",
-   "mimetype": "text/x-python",
-   "name": "python",
-   "nbconvert_exporter": "python",
-   "pygments_lexer": "ipython3",
-   "version": "3.10.16"
-  },
-  "widgets": {
-   "application/vnd.jupyter.widget-state+json": {
-    "0dccd57dcc5c43a588157cef957c07e8": {
-     "model_module": "@jupyter-widgets/controls",
-     "model_module_version": "2.0.0",
-     "model_name": "HTMLStyleModel",
-     "state": {
-      "_model_module": "@jupyter-widgets/controls",
-      "_model_module_version": "2.0.0",
-      "_model_name": "HTMLStyleModel",
-      "_view_count": null,
-      "_view_module": "@jupyter-widgets/base",
-      "_view_module_version": "2.0.0",
-      "_view_name": "StyleView",
-      "background": null,
-      "description_width": "",
-      "font_size": null,
-      "text_color": null
-     }
-    },
-    "17a3174e65c54476b2e0d1faf8f011ca": {
-     "model_module": "@jupyter-widgets/controls",
-     "model_module_version": "2.0.0",
-     "model_name": "HTMLModel",
-     "state": {
-      "_dom_classes": [],
-      "_model_module": "@jupyter-widgets/controls",
-      "_model_module_version": "2.0.0",
-      "_model_name": "HTMLModel",
-      "_view_count": null,
-      "_view_module": "@jupyter-widgets/controls",
-      "_view_module_version": "2.0.0",
-      "_view_name": "HTMLView",
-      "description": "",
-      "description_allow_html": false,
-      "layout": "IPY_MODEL_90a79523187446dfa692723b2e5833a7",
-      "placeholder": "​",
-      "style": "IPY_MODEL_431ffb83b8c14bf182f0430e07ea6154",
-      "tabbable": null,
-      "tooltip": null,
-      "value": "model.safetensors:  35%"
-     }
-    },
-    "1bbf2e62c0754d1593beb4105a7f1ac1": {
-     "model_module": "@jupyter-widgets/controls",
-     "model_module_version": "2.0.0",
-     "model_name": "FloatProgressModel",
-     "state": {
-      "_dom_classes": [],
-      "_model_module": "@jupyter-widgets/controls",
-      "_model_module_version": "2.0.0",
-      "_model_name": "FloatProgressModel",
-      "_view_count": null,
-      "_view_module": "@jupyter-widgets/controls",
-      "_view_module_version": "2.0.0",
-      "_view_name": "ProgressView",
-      "bar_style": "",
-      "description": "",
-      "description_allow_html": false,
-      "layout": "IPY_MODEL_a8f1b72a33dd4b548de23fbd95e0da18",
-      "max": 2471645608,
-      "min": 0,
-      "orientation": "horizontal",
-      "style": "IPY_MODEL_25cc36132d384189acfbecc59483134b",
-      "tabbable": null,
-      "tooltip": null,
-      "value": 880803840
-     }
-    },
-    "25cc36132d384189acfbecc59483134b": {
-     "model_module": "@jupyter-widgets/controls",
-     "model_module_version": "2.0.0",
-     "model_name": "ProgressStyleModel",
-     "state": {
-      "_model_module": "@jupyter-widgets/controls",
-      "_model_module_version": "2.0.0",
-      "_model_name": "ProgressStyleModel",
-      "_view_count": null,
-      "_view_module": "@jupyter-widgets/base",
-      "_view_module_version": "2.0.0",
-      "_view_name": "StyleView",
-      "bar_color": null,
-      "description_width": ""
-     }
-    },
-    "271e2bd6a35e4a8b92de8697f7c0be5f": {
-     "model_module": "@jupyter-widgets/base",
-     "model_module_version": "2.0.0",
-     "model_name": "LayoutModel",
-     "state": {
-      "_model_module": "@jupyter-widgets/base",
-      "_model_module_version": "2.0.0",
-      "_model_name": "LayoutModel",
-      "_view_count": null,
-      "_view_module": "@jupyter-widgets/base",
-      "_view_module_version": "2.0.0",
-      "_view_name": "LayoutView",
-      "align_content": null,
-      "align_items": null,
-      "align_self": null,
-      "border_bottom": null,
-      "border_left": null,
-      "border_right": null,
-      "border_top": null,
-      "bottom": null,
-      "display": null,
-      "flex": null,
-      "flex_flow": null,
-      "grid_area": null,
-      "grid_auto_columns": null,
-      "grid_auto_flow": null,
-      "grid_auto_rows": null,
-      "grid_column": null,
-      "grid_gap": null,
-      "grid_row": null,
-      "grid_template_areas": null,
-      "grid_template_columns": null,
-      "grid_template_rows": null,
-      "height": null,
-      "justify_content": null,
-      "justify_items": null,
-      "left": null,
-      "margin": null,
-      "max_height": null,
-      "max_width": null,
-      "min_height": null,
-      "min_width": null,
-      "object_fit": null,
-      "object_position": null,
-      "order": null,
-      "overflow": null,
-      "padding": null,
-      "right": null,
-      "top": null,
-      "visibility": null,
-      "width": null
-     }
-    },
-    "3326b6141a1a4eba9f316df528a9b99a": {
-     "model_module": "@jupyter-widgets/controls",
-     "model_module_version": "2.0.0",
-     "model_name": "HTMLStyleModel",
-     "state": {
-      "_model_module": "@jupyter-widgets/controls",
-      "_model_module_version": "2.0.0",
-      "_model_name": "HTMLStyleModel",
-      "_view_count": null,
-      "_view_module": "@jupyter-widgets/base",
-      "_view_module_version": "2.0.0",
-      "_view_name": "StyleView",
-      "background": null,
-      "description_width": "",
-      "font_size": null,
-      "text_color": null
-     }
-    },
-    "33ca0cdf2c7f41598a381c4ebe6a4ee1": {
-     "model_module": "@jupyter-widgets/base",
-     "model_module_version": "2.0.0",
-     "model_name": "LayoutModel",
-     "state": {
-      "_model_module": "@jupyter-widgets/base",
-      "_model_module_version": "2.0.0",
-      "_model_name": "LayoutModel",
-      "_view_count": null,
-      "_view_module": "@jupyter-widgets/base",
-      "_view_module_version": "2.0.0",
-      "_view_name": "LayoutView",
-      "align_content": null,
-      "align_items": null,
-      "align_self": null,
-      "border_bottom": null,
-      "border_left": null,
-      "border_right": null,
-      "border_top": null,
-      "bottom": null,
-      "display": null,
-      "flex": null,
-      "flex_flow": null,
-      "grid_area": null,
-      "grid_auto_columns": null,
-      "grid_auto_flow": null,
-      "grid_auto_rows": null,
-      "grid_column": null,
-      "grid_gap": null,
-      "grid_row": null,
-      "grid_template_areas": null,
-      "grid_template_columns": null,
-      "grid_template_rows": null,
-      "height": null,
-      "justify_content": null,
-      "justify_items": null,
-      "left": null,
-      "margin": null,
-      "max_height": null,
-      "max_width": null,
-      "min_height": null,
-      "min_width": null,
-      "object_fit": null,
-      "object_position": null,
-      "order": null,
-      "overflow": null,
-      "padding": null,
-      "right": null,
-      "top": null,
-      "visibility": null,
-      "width": null
-     }
-    },
-    "431ffb83b8c14bf182f0430e07ea6154": {
-     "model_module": "@jupyter-widgets/controls",
-     "model_module_version": "2.0.0",
-     "model_name": "HTMLStyleModel",
-     "state": {
-      "_model_module": "@jupyter-widgets/controls",
-      "_model_module_version": "2.0.0",
-      "_model_name": "HTMLStyleModel",
-      "_view_count": null,
-      "_view_module": "@jupyter-widgets/base",
-      "_view_module_version": "2.0.0",
-      "_view_name": "StyleView",
-      "background": null,
-      "description_width": "",
-      "font_size": null,
-      "text_color": null
-     }
-    },
-    "5176834aa8784bba9ec21234b87a8948": {
-     "model_module": "@jupyter-widgets/base",
-     "model_module_version": "2.0.0",
-     "model_name": "LayoutModel",
-     "state": {
-      "_model_module": "@jupyter-widgets/base",
-      "_model_module_version": "2.0.0",
-      "_model_name": "LayoutModel",
-      "_view_count": null,
-      "_view_module": "@jupyter-widgets/base",
-      "_view_module_version": "2.0.0",
-      "_view_name": "LayoutView",
-      "align_content": null,
-      "align_items": null,
-      "align_self": null,
-      "border_bottom": null,
-      "border_left": null,
-      "border_right": null,
-      "border_top": null,
-      "bottom": null,
-      "display": null,
-      "flex": null,
-      "flex_flow": null,
-      "grid_area": null,
-      "grid_auto_columns": null,
-      "grid_auto_flow": null,
-      "grid_auto_rows": null,
-      "grid_column": null,
-      "grid_gap": null,
-      "grid_row": null,
-      "grid_template_areas": null,
-      "grid_template_columns": null,
-      "grid_template_rows": null,
-      "height": null,
-      "justify_content": null,
-      "justify_items": null,
-      "left": null,
-      "margin": null,
-      "max_height": null,
-      "max_width": null,
-      "min_height": null,
-      "min_width": null,
-      "object_fit": null,
-      "object_position": null,
-      "order": null,
-      "overflow": null,
-      "padding": null,
-      "right": null,
-      "top": null,
-      "visibility": null,
-      "width": null
-     }
-    },
-    "518fb202e4b44aaba47f07d1a61b6762": {
-     "model_module": "@jupyter-widgets/controls",
-     "model_module_version": "2.0.0",
-     "model_name": "HTMLModel",
-     "state": {
-      "_dom_classes": [],
-      "_model_module": "@jupyter-widgets/controls",
-      "_model_module_version": "2.0.0",
-      "_model_name": "HTMLModel",
-      "_view_count": null,
-      "_view_module": "@jupyter-widgets/controls",
-      "_view_module_version": "2.0.0",
-      "_view_name": "HTMLView",
-      "description": "",
-      "description_allow_html": false,
-      "layout": "IPY_MODEL_e2dc407afcd945c798e30597fddfcb3c",
-      "placeholder": "​",
-      "style": "IPY_MODEL_0dccd57dcc5c43a588157cef957c07e8",
-      "tabbable": null,
-      "tooltip": null,
-      "value": "tokenizer.model: 100%"
-     }
-    },
-    "672cdc5aea954de3af851c001a667ad3": {
-     "model_module": "@jupyter-widgets/controls",
-     "model_module_version": "2.0.0",
-     "model_name": "FloatProgressModel",
-     "state": {
-      "_dom_classes": [],
-      "_model_module": "@jupyter-widgets/controls",
-      "_model_module_version": "2.0.0",
-      "_model_name": "FloatProgressModel",
-      "_view_count": null,
-      "_view_module": "@jupyter-widgets/controls",
-      "_view_module_version": "2.0.0",
-      "_view_name": "ProgressView",
-      "bar_style": "success",
-      "description": "",
-      "description_allow_html": false,
-      "layout": "IPY_MODEL_33ca0cdf2c7f41598a381c4ebe6a4ee1",
-      "max": 2183982,
-      "min": 0,
-      "orientation": "horizontal",
-      "style": "IPY_MODEL_ee44487f58454dacb522b1e084ffb733",
-      "tabbable": null,
-      "tooltip": null,
-      "value": 2183982
-     }
-    },
-    "90a79523187446dfa692723b2e5833a7": {
-     "model_module": "@jupyter-widgets/base",
-     "model_module_version": "2.0.0",
-     "model_name": "LayoutModel",
-     "state": {
-      "_model_module": "@jupyter-widgets/base",
-      "_model_module_version": "2.0.0",
-      "_model_name": "LayoutModel",
-      "_view_count": null,
-      "_view_module": "@jupyter-widgets/base",
-      "_view_module_version": "2.0.0",
-      "_view_name": "LayoutView",
-      "align_content": null,
-      "align_items": null,
-      "align_self": null,
-      "border_bottom": null,
-      "border_left": null,
-      "border_right": null,
-      "border_top": null,
-      "bottom": null,
-      "display": null,
-      "flex": null,
-      "flex_flow": null,
-      "grid_area": null,
-      "grid_auto_columns": null,
-      "grid_auto_flow": null,
-      "grid_auto_rows": null,
-      "grid_column": null,
-      "grid_gap": null,
-      "grid_row": null,
-      "grid_template_areas": null,
-      "grid_template_columns": null,
-      "grid_template_rows": null,
-      "height": null,
-      "justify_content": null,
-      "justify_items": null,
-      "left": null,
-      "margin": null,
-      "max_height": null,
-      "max_width": null,
-      "min_height": null,
-      "min_width": null,
-      "object_fit": null,
-      "object_position": null,
-      "order": null,
-      "overflow": null,
-      "padding": null,
-      "right": null,
-      "top": null,
-      "visibility": null,
-      "width": null
-     }
-    },
-    "9881b6995c3f49dc89e6992fd9ab660b": {
-     "model_module": "@jupyter-widgets/controls",
-     "model_module_version": "2.0.0",
-     "model_name": "HBoxModel",
-     "state": {
-      "_dom_classes": [],
-      "_model_module": "@jupyter-widgets/controls",
-      "_model_module_version": "2.0.0",
-      "_model_name": "HBoxModel",
-      "_view_count": null,
-      "_view_module": "@jupyter-widgets/controls",
-      "_view_module_version": "2.0.0",
-      "_view_name": "HBoxView",
-      "box_style": "",
-      "children": [
-       "IPY_MODEL_17a3174e65c54476b2e0d1faf8f011ca",
-       "IPY_MODEL_1bbf2e62c0754d1593beb4105a7f1ac1",
-       "IPY_MODEL_b82112e1dec645d98aa1c1ba64abcb61"
-      ],
-      "layout": "IPY_MODEL_271e2bd6a35e4a8b92de8697f7c0be5f",
-      "tabbable": null,
-      "tooltip": null
-     }
-    },
-    "a1608feac06d4687967a3e398f01c489": {
-     "model_module": "@jupyter-widgets/controls",
-     "model_module_version": "2.0.0",
-     "model_name": "HBoxModel",
-     "state": {
-      "_dom_classes": [],
-      "_model_module": "@jupyter-widgets/controls",
-      "_model_module_version": "2.0.0",
-      "_model_name": "HBoxModel",
-      "_view_count": null,
-      "_view_module": "@jupyter-widgets/controls",
-      "_view_module_version": "2.0.0",
-      "_view_name": "HBoxView",
-      "box_style": "",
-      "children": [
-       "IPY_MODEL_518fb202e4b44aaba47f07d1a61b6762",
-       "IPY_MODEL_672cdc5aea954de3af851c001a667ad3",
-       "IPY_MODEL_eebf8874618746b39cf4a21a2728dc7f"
-      ],
-      "layout": "IPY_MODEL_5176834aa8784bba9ec21234b87a8948",
-      "tabbable": null,
-      "tooltip": null
-     }
-    },
-    "a8f1b72a33dd4b548de23fbd95e0da18": {
-     "model_module": "@jupyter-widgets/base",
-     "model_module_version": "2.0.0",
-     "model_name": "LayoutModel",
-     "state": {
-      "_model_module": "@jupyter-widgets/base",
-      "_model_module_version": "2.0.0",
-      "_model_name": "LayoutModel",
-      "_view_count": null,
-      "_view_module": "@jupyter-widgets/base",
-      "_view_module_version": "2.0.0",
-      "_view_name": "LayoutView",
-      "align_content": null,
-      "align_items": null,
-      "align_self": null,
-      "border_bottom": null,
-      "border_left": null,
-      "border_right": null,
-      "border_top": null,
-      "bottom": null,
-      "display": null,
-      "flex": null,
-      "flex_flow": null,
-      "grid_area": null,
-      "grid_auto_columns": null,
-      "grid_auto_flow": null,
-      "grid_auto_rows": null,
-      "grid_column": null,
-      "grid_gap": null,
-      "grid_row": null,
-      "grid_template_areas": null,
-      "grid_template_columns": null,
-      "grid_template_rows": null,
-      "height": null,
-      "justify_content": null,
-      "justify_items": null,
-      "left": null,
-      "margin": null,
-      "max_height": null,
-      "max_width": null,
-      "min_height": null,
-      "min_width": null,
-      "object_fit": null,
-      "object_position": null,
-      "order": null,
-      "overflow": null,
-      "padding": null,
-      "right": null,
-      "top": null,
-      "visibility": null,
-      "width": null
-     }
-    },
-    "b82112e1dec645d98aa1c1ba64abcb61": {
-     "model_module": "@jupyter-widgets/controls",
-     "model_module_version": "2.0.0",
-     "model_name": "HTMLModel",
-     "state": {
-      "_dom_classes": [],
-      "_model_module": "@jupyter-widgets/controls",
-      "_model_module_version": "2.0.0",
-      "_model_name": "HTMLModel",
-      "_view_count": null,
-      "_view_module": "@jupyter-widgets/controls",
-      "_view_module_version": "2.0.0",
-      "_view_name": "HTMLView",
-      "description": "",
-      "description_allow_html": false,
-      "layout": "IPY_MODEL_bfd06423ad544218968648016e731a46",
-      "placeholder": "​",
-      "style": "IPY_MODEL_d029630b63ff44cf807ade428d2eb421",
-      "tabbable": null,
-      "tooltip": null,
-      "value": " 870M/2.47G [00:20&lt;00:37, 42.8MB/s]"
-     }
-    },
-    "bfd06423ad544218968648016e731a46": {
-     "model_module": "@jupyter-widgets/base",
-     "model_module_version": "2.0.0",
-     "model_name": "LayoutModel",
-     "state": {
-      "_model_module": "@jupyter-widgets/base",
-      "_model_module_version": "2.0.0",
-      "_model_name": "LayoutModel",
-      "_view_count": null,
-      "_view_module": "@jupyter-widgets/base",
-      "_view_module_version": "2.0.0",
-      "_view_name": "LayoutView",
-      "align_content": null,
-      "align_items": null,
-      "align_self": null,
-      "border_bottom": null,
-      "border_left": null,
-      "border_right": null,
-      "border_top": null,
-      "bottom": null,
-      "display": null,
-      "flex": null,
-      "flex_flow": null,
-      "grid_area": null,
-      "grid_auto_columns": null,
-      "grid_auto_flow": null,
-      "grid_auto_rows": null,
-      "grid_column": null,
-      "grid_gap": null,
-      "grid_row": null,
-      "grid_template_areas": null,
-      "grid_template_columns": null,
-      "grid_template_rows": null,
-      "height": null,
-      "justify_content": null,
-      "justify_items": null,
-      "left": null,
-      "margin": null,
-      "max_height": null,
-      "max_width": null,
-      "min_height": null,
-      "min_width": null,
-      "object_fit": null,
-      "object_position": null,
-      "order": null,
-      "overflow": null,
-      "padding": null,
-      "right": null,
-      "top": null,
-      "visibility": null,
-      "width": null
-     }
-    },
-    "d029630b63ff44cf807ade428d2eb421": {
-     "model_module": "@jupyter-widgets/controls",
-     "model_module_version": "2.0.0",
-     "model_name": "HTMLStyleModel",
-     "state": {
-      "_model_module": "@jupyter-widgets/controls",
-      "_model_module_version": "2.0.0",
-      "_model_name": "HTMLStyleModel",
-      "_view_count": null,
-      "_view_module": "@jupyter-widgets/base",
-      "_view_module_version": "2.0.0",
-      "_view_name": "StyleView",
-      "background": null,
-      "description_width": "",
-      "font_size": null,
-      "text_color": null
-     }
-    },
-    "d2c41e71a3f441deaed091b620ac5603": {
-     "model_module": "@jupyter-widgets/base",
-     "model_module_version": "2.0.0",
-     "model_name": "LayoutModel",
-     "state": {
-      "_model_module": "@jupyter-widgets/base",
-      "_model_module_version": "2.0.0",
-      "_model_name": "LayoutModel",
-      "_view_count": null,
-      "_view_module": "@jupyter-widgets/base",
-      "_view_module_version": "2.0.0",
-      "_view_name": "LayoutView",
-      "align_content": null,
-      "align_items": null,
-      "align_self": null,
-      "border_bottom": null,
-      "border_left": null,
-      "border_right": null,
-      "border_top": null,
-      "bottom": null,
-      "display": null,
-      "flex": null,
-      "flex_flow": null,
-      "grid_area": null,
-      "grid_auto_columns": null,
-      "grid_auto_flow": null,
-      "grid_auto_rows": null,
-      "grid_column": null,
-      "grid_gap": null,
-      "grid_row": null,
-      "grid_template_areas": null,
-      "grid_template_columns": null,
-      "grid_template_rows": null,
-      "height": null,
-      "justify_content": null,
-      "justify_items": null,
-      "left": null,
-      "margin": null,
-      "max_height": null,
-      "max_width": null,
-      "min_height": null,
-      "min_width": null,
-      "object_fit": null,
-      "object_position": null,
-      "order": null,
-      "overflow": null,
-      "padding": null,
-      "right": null,
-      "top": null,
-      "visibility": null,
-      "width": null
-     }
-    },
-    "e2dc407afcd945c798e30597fddfcb3c": {
-     "model_module": "@jupyter-widgets/base",
-     "model_module_version": "2.0.0",
-     "model_name": "LayoutModel",
-     "state": {
-      "_model_module": "@jupyter-widgets/base",
-      "_model_module_version": "2.0.0",
-      "_model_name": "LayoutModel",
-      "_view_count": null,
-      "_view_module": "@jupyter-widgets/base",
-      "_view_module_version": "2.0.0",
-      "_view_name": "LayoutView",
-      "align_content": null,
-      "align_items": null,
-      "align_self": null,
-      "border_bottom": null,
-      "border_left": null,
-      "border_right": null,
-      "border_top": null,
-      "bottom": null,
-      "display": null,
-      "flex": null,
-      "flex_flow": null,
-      "grid_area": null,
-      "grid_auto_columns": null,
-      "grid_auto_flow": null,
-      "grid_auto_rows": null,
-      "grid_column": null,
-      "grid_gap": null,
-      "grid_row": null,
-      "grid_template_areas": null,
-      "grid_template_columns": null,
-      "grid_template_rows": null,
-      "height": null,
-      "justify_content": null,
-      "justify_items": null,
-      "left": null,
-      "margin": null,
-      "max_height": null,
-      "max_width": null,
-      "min_height": null,
-      "min_width": null,
-      "object_fit": null,
-      "object_position": null,
-      "order": null,
-      "overflow": null,
-      "padding": null,
-      "right": null,
-      "top": null,
-      "visibility": null,
-      "width": null
-     }
-    },
-    "ee44487f58454dacb522b1e084ffb733": {
-     "model_module": "@jupyter-widgets/controls",
-     "model_module_version": "2.0.0",
-     "model_name": "ProgressStyleModel",
-     "state": {
-      "_model_module": "@jupyter-widgets/controls",
-      "_model_module_version": "2.0.0",
-      "_model_name": "ProgressStyleModel",
-      "_view_count": null,
-      "_view_module": "@jupyter-widgets/base",
-      "_view_module_version": "2.0.0",
-      "_view_name": "StyleView",
-      "bar_color": null,
-      "description_width": ""
-     }
-    },
-    "eebf8874618746b39cf4a21a2728dc7f": {
-     "model_module": "@jupyter-widgets/controls",
-     "model_module_version": "2.0.0",
-     "model_name": "HTMLModel",
-     "state": {
-      "_dom_classes": [],
-      "_model_module": "@jupyter-widgets/controls",
-      "_model_module_version": "2.0.0",
-      "_model_name": "HTMLModel",
-      "_view_count": null,
-      "_view_module": "@jupyter-widgets/controls",
-      "_view_module_version": "2.0.0",
-      "_view_name": "HTMLView",
-      "description": "",
-      "description_allow_html": false,
-      "layout": "IPY_MODEL_d2c41e71a3f441deaed091b620ac5603",
-      "placeholder": "​",
-      "style": "IPY_MODEL_3326b6141a1a4eba9f316df528a9b99a",
-      "tabbable": null,
-      "tooltip": null,
-      "value": " 2.18M/2.18M [00:00&lt;00:00, 9.47MB/s]"
-     }
-    }
-   }
-  }
- },
- "nbformat": 4,
- "nbformat_minor": 5
-}

+ 175 - 234
ch05/07_gpt_to_llama/standalone-llama32.ipynb

@@ -56,7 +56,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 1,
    "id": "7c201adb-747e-437b-9a62-442802941e01",
    "metadata": {},
    "outputs": [],
@@ -66,7 +66,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 3,
+   "execution_count": 2,
    "id": "dd1b65a8-4301-444a-bd7c-a6f2bd1df9df",
    "metadata": {
     "colab": {
@@ -81,9 +81,9 @@
      "output_type": "stream",
      "text": [
       "blobfile version: 3.0.0\n",
-      "huggingface_hub version: 0.25.2\n",
-      "tiktoken version: 0.8.0\n",
-      "torch version: 2.5.0\n"
+      "huggingface_hub version: 0.30.1\n",
+      "tiktoken version: 0.9.0\n",
+      "torch version: 2.6.0\n"
      ]
     }
    ],
@@ -113,7 +113,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 4,
+   "execution_count": 3,
    "id": "82076c21-9331-4dcd-b017-42b046cf1a60",
    "metadata": {
     "id": "82076c21-9331-4dcd-b017-42b046cf1a60"
@@ -140,18 +140,18 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 5,
+   "execution_count": 4,
    "id": "4b9a346f-5826-4083-9162-abd56afc03f0",
    "metadata": {
     "id": "4b9a346f-5826-4083-9162-abd56afc03f0"
    },
    "outputs": [],
    "source": [
-    "def precompute_rope_params(head_dim, theta_base=10_000, context_length=4096, freq_config=None):\n",
+    "def compute_rope_params(head_dim, theta_base=10_000, context_length=4096, freq_config=None, dtype=torch.float32):\n",
     "    assert head_dim % 2 == 0, \"Embedding dimension must be even\"\n",
     "\n",
     "    # Compute the inverse frequencies\n",
-    "    inv_freq = 1.0 / (theta_base ** (torch.arange(0, head_dim, 2)[: (head_dim // 2)].float() / head_dim))\n",
+    "    inv_freq = 1.0 / (theta_base ** (torch.arange(0, head_dim, 2, dtype=dtype)[: (head_dim // 2)].float() / head_dim))\n",
     "\n",
     "    # Frequency adjustments\n",
     "    if freq_config is not None:\n",
@@ -177,7 +177,7 @@
     "        inv_freq = inv_freq_llama\n",
     "\n",
     "    # Generate position indices\n",
-    "    positions = torch.arange(context_length)\n",
+    "    positions = torch.arange(context_length, dtype=dtype)\n",
     "\n",
     "    # Compute the angles\n",
     "    angles = positions[:, None] * inv_freq[None, :]  # Shape: (context_length, head_dim // 2)\n",
@@ -192,7 +192,7 @@
     "    return cos, sin\n",
     "\n",
     "\n",
-    "def compute_rope(x, cos, sin):\n",
+    "def apply_rope(x, cos, sin):\n",
     "    # x: (batch_size, num_heads, seq_len, head_dim)\n",
     "    batch_size, num_heads, seq_len, head_dim = x.shape\n",
     "    assert head_dim % 2 == 0, \"Head dimension must be even\"\n",
@@ -209,43 +209,23 @@
     "    rotated = torch.cat((-x2, x1), dim=-1)\n",
     "    x_rotated = (x * cos) + (rotated * sin)\n",
     "\n",
+    "    # It's ok to use lower-precision after applying cos and sin rotation\n",
     "    return x_rotated.to(dtype=x.dtype)"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 6,
+   "execution_count": 5,
    "id": "e8169ab5-f976-4222-a2e1-eb1cabf267cb",
    "metadata": {
     "id": "e8169ab5-f976-4222-a2e1-eb1cabf267cb"
    },
    "outputs": [],
    "source": [
-    "class SharedBuffers:\n",
-    "    _buffers = {}\n",
-    "\n",
-    "    @staticmethod\n",
-    "    def get_buffers(context_length, head_dim, rope_base, freq_config, dtype=torch.float32):\n",
-    "        key = (context_length, head_dim, rope_base, tuple(freq_config.values()) if freq_config else freq_config, dtype)\n",
-    "\n",
-    "        if key not in SharedBuffers._buffers:\n",
-    "            # Create or fetch the buffers\n",
-    "            mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)\n",
-    "            cos, sin = precompute_rope_params(head_dim, rope_base, context_length, freq_config)\n",
-    "            if dtype is not None:\n",
-    "                cos = cos.to(dtype)\n",
-    "                sin = sin.to(dtype)\n",
-    "            SharedBuffers._buffers[key] = (mask, cos, sin)\n",
-    "\n",
-    "        return SharedBuffers._buffers[key]\n",
-    "\n",
-    "\n",
     "class GroupedQueryAttention(nn.Module):\n",
     "    def __init__(\n",
-    "            self, d_in, d_out, context_length, num_heads,\n",
+    "            self, d_in, d_out, num_heads,\n",
     "            num_kv_groups,\n",
-    "            rope_base=10_000,\n",
-    "            rope_config=None,\n",
     "            dtype=None\n",
     "        ):\n",
     "        super().__init__()\n",
@@ -264,14 +244,7 @@
     "        self.W_query = nn.Linear(d_in, d_out, bias=False, dtype=dtype)\n",
     "        self.out_proj = nn.Linear(d_out, d_out, bias=False, dtype=dtype)\n",
     "\n",
-    "        # Fetch buffers using SharedBuffers\n",
-    "        mask, cos, sin = SharedBuffers.get_buffers(context_length, self.head_dim, rope_base, rope_config, dtype)\n",
-    "        self.register_buffer(\"mask\", mask, persistent=False)\n",
-    "\n",
-    "        self.register_buffer(\"cos\", cos, persistent=False)\n",
-    "        self.register_buffer(\"sin\", sin, persistent=False)\n",
-    "\n",
-    "    def forward(self, x):\n",
+    "    def forward(self, x, mask, cos, sin):\n",
     "        b, num_tokens, d_in = x.shape\n",
     "\n",
     "        queries = self.W_query(x)  # Shape: (b, num_tokens, d_out)\n",
@@ -289,8 +262,8 @@
     "        queries = queries.transpose(1, 2)  # Shape: (b, num_query_groups, num_tokens, head_dim)\n",
     "\n",
     "        # Apply RoPE\n",
-    "        keys = compute_rope(keys, self.cos, self.sin)\n",
-    "        queries = compute_rope(queries, self.cos, self.sin)\n",
+    "        keys = apply_rope(keys, cos, sin)\n",
+    "        queries = apply_rope(queries, cos, sin)\n",
     "\n",
     "        # Expand keys and values to match the number of heads\n",
     "        # Shape: (b, num_heads, num_tokens, head_dim)\n",
@@ -307,11 +280,8 @@
     "        # Shape: (b, num_heads, num_tokens, num_tokens)\n",
     "        attn_scores = queries @ keys.transpose(2, 3)  # Dot product for each head\n",
     "\n",
-    "        # Original mask truncated to the number of tokens and converted to boolean\n",
-    "        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]\n",
-    "\n",
-    "        # Use the mask to fill attention scores\n",
-    "        attn_scores.masked_fill_(mask_bool, -torch.inf)\n",
+    "        # Compute attention scores\n",
+    "        attn_scores = attn_scores.masked_fill(mask, -torch.inf)\n",
     "\n",
     "        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)\n",
     "        assert keys.shape[-1] == self.head_dim\n",
@@ -328,7 +298,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 7,
+   "execution_count": 6,
    "id": "457cb2f8-50c1-4045-8a74-f181bfb5fea9",
    "metadata": {
     "id": "457cb2f8-50c1-4045-8a74-f181bfb5fea9"
@@ -338,31 +308,28 @@
     "class TransformerBlock(nn.Module):\n",
     "    def __init__(self, cfg):\n",
     "        super().__init__()\n",
-    "        self.att =  GroupedQueryAttention(\n",
+    "        self.att = GroupedQueryAttention(\n",
     "            d_in=cfg[\"emb_dim\"],\n",
     "            d_out=cfg[\"emb_dim\"],\n",
-    "            context_length=cfg[\"context_length\"],\n",
     "            num_heads=cfg[\"n_heads\"],\n",
     "            num_kv_groups=cfg[\"n_kv_groups\"],\n",
-    "            rope_base=cfg[\"rope_base\"],\n",
-    "            rope_config=cfg[\"rope_freq\"],\n",
     "            dtype=cfg[\"dtype\"]\n",
     "        )\n",
     "        self.ff = FeedForward(cfg)\n",
-    "        self.norm1 = nn.RMSNorm(cfg[\"emb_dim\"], eps=1e-5)\n",
-    "        self.norm2 = nn.RMSNorm(cfg[\"emb_dim\"], eps=1e-5)\n",
+    "        self.norm1 = nn.RMSNorm(cfg[\"emb_dim\"], eps=1e-5, dtype=cfg[\"dtype\"])\n",
+    "        self.norm2 = nn.RMSNorm(cfg[\"emb_dim\"], eps=1e-5, dtype=cfg[\"dtype\"])\n",
     "\n",
-    "    def forward(self, x):\n",
+    "    def forward(self, x, mask, cos, sin):\n",
     "        # Shortcut connection for attention block\n",
     "        shortcut = x\n",
     "        x = self.norm1(x)\n",
-    "        x = self.att(x.to(torch.bfloat16))   # Shape [batch_size, num_tokens, emb_size]\n",
+    "        x = self.att(x, mask, cos, sin)  # Shape [batch_size, num_tokens, emb_size]\n",
     "        x = x + shortcut  # Add the original input back\n",
     "\n",
     "        # Shortcut connection for feed-forward block\n",
     "        shortcut = x\n",
     "        x = self.norm2(x)\n",
-    "        x = self.ff(x.to(torch.bfloat16))\n",
+    "        x = self.ff(x)\n",
     "        x = x + shortcut  # Add the original input back\n",
     "\n",
     "        return x"
@@ -370,7 +337,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 8,
+   "execution_count": 7,
    "id": "e88de3e3-9f07-42cc-816b-28dbd46e96c4",
    "metadata": {
     "id": "e88de3e3-9f07-42cc-816b-28dbd46e96c4"
@@ -380,20 +347,41 @@
     "class Llama3Model(nn.Module):\n",
     "    def __init__(self, cfg):\n",
     "        super().__init__()\n",
+    "\n",
+    "        # Main model parameters\n",
     "        self.tok_emb = nn.Embedding(cfg[\"vocab_size\"], cfg[\"emb_dim\"], dtype=cfg[\"dtype\"])\n",
     "\n",
-    "        self.trf_blocks = nn.Sequential(\n",
-    "            *[TransformerBlock(cfg) for _ in range(cfg[\"n_layers\"])])\n",
+    "        self.trf_blocks = nn.ModuleList(  # ModuleList since Sequential can only accept one input, and we need `x, mask, cos, sin`\n",
+    "            [TransformerBlock(cfg) for _ in range(cfg[\"n_layers\"])]\n",
+    "        )\n",
     "\n",
-    "        self.final_norm = nn.RMSNorm(cfg[\"emb_dim\"], eps=1e-5)\n",
+    "        self.final_norm = nn.RMSNorm(cfg[\"emb_dim\"], eps=1e-5, dtype=cfg[\"dtype\"])\n",
     "        self.out_head = nn.Linear(cfg[\"emb_dim\"], cfg[\"vocab_size\"], bias=False, dtype=cfg[\"dtype\"])\n",
     "\n",
+    "        # Reusuable utilities\n",
+    "        cos, sin = compute_rope_params(\n",
+    "            head_dim=cfg[\"emb_dim\"] // cfg[\"n_heads\"],\n",
+    "            theta_base=cfg[\"rope_base\"],\n",
+    "            context_length=cfg[\"context_length\"],\n",
+    "            freq_config=cfg[\"rope_freq\"]\n",
+    "        )\n",
+    "        self.register_buffer(\"cos\", cos, persistent=False)\n",
+    "        self.register_buffer(\"sin\", sin, persistent=False)\n",
+    "        self.cfg = cfg\n",
+    "\n",
+    "\n",
     "    def forward(self, in_idx):\n",
+    "        # Forward pass\n",
     "        tok_embeds = self.tok_emb(in_idx)\n",
     "        x = tok_embeds\n",
-    "        x = self.trf_blocks(x)\n",
+    "\n",
+    "        num_tokens = x.shape[1]\n",
+    "        mask = torch.triu(torch.ones(num_tokens, num_tokens, device=x.device, dtype=torch.bool), diagonal=1)\n",
+    "        \n",
+    "        for block in self.trf_blocks:\n",
+    "            x = block(x, mask, self.cos, self.sin)\n",
     "        x = self.final_norm(x)\n",
-    "        logits = self.out_head(x.to(torch.bfloat16))\n",
+    "        logits = self.out_head(x.to(self.cfg[\"dtype\"]))\n",
     "        return logits"
    ]
   },
@@ -420,7 +408,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 9,
+   "execution_count": 8,
    "id": "caa142fa-b375-4e78-b392-2072ced666f3",
    "metadata": {
     "id": "caa142fa-b375-4e78-b392-2072ced666f3"
@@ -430,16 +418,16 @@
     "# Llama 3.2 1B\n",
     "\n",
     "LLAMA32_CONFIG = {\n",
-    "    \"vocab_size\": 128_256,      # Vocabulary size\n",
-    "    \"context_length\": 131_072,  # Context length\n",
-    "    \"emb_dim\": 2048,            # Embedding dimension\n",
-    "    \"n_heads\": 32,              # Number of attention heads\n",
-    "    \"n_layers\": 16,             # Number of layers\n",
-    "    \"hidden_dim\": 8192,         # Size of the intermediate dimension in FeedForward\n",
-    "    \"n_kv_groups\": 8,           # Key-Value groups for grouped-query attention\n",
-    "    \"rope_base\": 500_000.0,     # The base in RoPE's \"theta\"\n",
-    "    \"dtype\": torch.bfloat16,    # Lower-precision dtype to reduce memory usage\n",
-    "    \"rope_freq\": {              # RoPE frequency scaling\n",
+    "    \"vocab_size\": 128_256,           # Vocabulary size\n",
+    "    \"context_length\": 131_072,       # Context length that was used to train the model\n",
+    "    \"emb_dim\": 2048,                 # Embedding dimension\n",
+    "    \"n_heads\": 32,                   # Number of attention heads\n",
+    "    \"n_layers\": 16,                  # Number of layers\n",
+    "    \"hidden_dim\": 8192,              # Size of the intermediate dimension in FeedForward\n",
+    "    \"n_kv_groups\": 8,                # Key-Value groups for grouped-query attention\n",
+    "    \"rope_base\": 500_000.0,          # The base in RoPE's \"theta\"\n",
+    "    \"dtype\": torch.bfloat16,         # Lower-precision dtype to reduce memory usage\n",
+    "    \"rope_freq\": {                   # RoPE frequency scaling\n",
     "        \"factor\": 32.0,\n",
     "        \"low_freq_factor\": 1.0,\n",
     "        \"high_freq_factor\": 4.0,\n",
@@ -450,16 +438,16 @@
     "# Llama 3.2 3B\n",
     "\n",
     "# LLAMA32_CONFIG = {\n",
-    "#     \"vocab_size\": 128_256,      # Vocabulary size\n",
-    "#     \"context_length\": 131_072,  # Context length\n",
-    "#     \"emb_dim\": 3072,            # Embedding dimension\n",
-    "#     \"n_heads\": 24,              # Number of attention heads\n",
-    "#     \"n_layers\": 28,             # Number of layers\n",
-    "#     \"hidden_dim\": 8192,         # Size of the intermediate dimension in FeedForward\n",
-    "#     \"n_kv_groups\": 8,           # Key-Value groups for grouped-query attention\n",
-    "#     \"rope_base\": 500_000.0,     # The base in RoPE's \"theta\"\n",
-    "#     \"dtype\": torch.bfloat16,    # Lower-precision dtype to reduce memory usage\n",
-    "#     \"rope_freq\": {              # RoPE frequency scaling\n",
+    "#     \"vocab_size\": 128_256,           # Vocabulary size\n",
+    "#     \"context_length\": 131_072,       # Context length that was used to train the model\n",
+    "#     \"emb_dim\": 3072,                 # Embedding dimension\n",
+    "#     \"n_heads\": 24,                   # Number of attention heads\n",
+    "#     \"n_layers\": 28,                  # Number of layers\n",
+    "#     \"hidden_dim\": 8192,              # Size of the intermediate dimension in FeedForward\n",
+    "#     \"n_kv_groups\": 8,                # Key-Value groups for grouped-query attention\n",
+    "#     \"rope_base\": 500_000.0,          # The base in RoPE's \"theta\"\n",
+    "#     \"dtype\": torch.bfloat16,         # Lower-precision dtype to reduce memory usage\n",
+    "#     \"rope_freq\": {                   # RoPE frequency scaling\n",
     "#         \"factor\": 32.0,\n",
     "#         \"low_freq_factor\": 1.0,\n",
     "#         \"high_freq_factor\": 4.0,\n",
@@ -470,54 +458,9 @@
     "LLAMA_SIZE_STR = \"1B\" if LLAMA32_CONFIG[\"emb_dim\"] == 2048 else \"3B\""
    ]
   },
-  {
-   "cell_type": "markdown",
-   "id": "34535172-797e-4dd0-84fb-65bc75ad5b06",
-   "metadata": {
-    "id": "34535172-797e-4dd0-84fb-65bc75ad5b06"
-   },
-   "source": [
-    "- Reduce the context length so the model would work fine on a MacBook Air (if you have more RAM, feel free to comment out the lines below):"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 10,
-   "id": "a8bc2370-39d2-4bfe-b4c1-6bdd75fe101c",
-   "metadata": {
-    "id": "a8bc2370-39d2-4bfe-b4c1-6bdd75fe101c"
-   },
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "New RoPE theta: 31250.0\n"
-     ]
-    }
-   ],
-   "source": [
-    "old_context_length = LLAMA32_CONFIG[\"context_length\"]\n",
-    "LLAMA32_CONFIG[\"context_length\"] = 8192\n",
-    "\n",
-    "\n",
-    "def rescale_theta(theta_old, context_length_old, context_length_new):\n",
-    "    scaling_factor = context_length_new / context_length_old\n",
-    "    theta_new = theta_old * scaling_factor\n",
-    "    return theta_new\n",
-    "\n",
-    "LLAMA32_CONFIG[\"rope_base\"] = rescale_theta(\n",
-    "    LLAMA32_CONFIG[\"rope_base\"],\n",
-    "    old_context_length,\n",
-    "    LLAMA32_CONFIG[\"context_length\"]\n",
-    ")\n",
-    "\n",
-    "print(\"New RoPE theta:\", LLAMA32_CONFIG[\"rope_base\"])"
-   ]
-  },
   {
    "cell_type": "code",
-   "execution_count": 11,
+   "execution_count": 9,
    "id": "156253fe-aacd-4da2-8f13-705f05c4b11e",
    "metadata": {
     "id": "156253fe-aacd-4da2-8f13-705f05c4b11e"
@@ -539,36 +482,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 12,
-   "id": "0e95db6d-2712-41a5-a5e0-86c49897f4cf",
-   "metadata": {
-    "colab": {
-     "base_uri": "https://localhost:8080/"
-    },
-    "id": "0e95db6d-2712-41a5-a5e0-86c49897f4cf",
-    "outputId": "8efc4937-e616-40d0-cd59-670d7eb3e841"
-   },
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "True\n",
-      "True\n",
-      "True\n"
-     ]
-    }
-   ],
-   "source": [
-    "# Check buffers\n",
-    "print(model.trf_blocks[0].att.mask is model.trf_blocks[-1].att.mask)\n",
-    "print(model.trf_blocks[0].att.cos is model.trf_blocks[-1].att.cos)\n",
-    "print(model.trf_blocks[0].att.sin is model.trf_blocks[-1].att.sin)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 13,
+   "execution_count": 10,
    "id": "364e76ca-52f8-4fa5-af37-c4069f9694bc",
    "metadata": {
     "colab": {
@@ -599,7 +513,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 14,
+   "execution_count": 11,
    "id": "fd5efb03-5a07-46e8-8607-93ed47549d2b",
    "metadata": {
     "colab": {
@@ -613,8 +527,8 @@
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "float32 (PyTorch default): 11.42 GB\n",
-      "bfloat16: 5.71 GB\n"
+      "float32 (PyTorch default): 11.23 GB\n",
+      "bfloat16: 5.61 GB\n"
      ]
     }
    ],
@@ -649,7 +563,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 15,
+   "execution_count": 12,
    "id": "31f12baf-f79b-499f-85c0-51328a6a20f5",
    "metadata": {
     "id": "31f12baf-f79b-499f-85c0-51328a6a20f5"
@@ -679,7 +593,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 16,
+   "execution_count": 13,
    "id": "9482b01c-49f9-48e4-ab2c-4a4c75240e77",
    "metadata": {
     "id": "9482b01c-49f9-48e4-ab2c-4a4c75240e77"
@@ -693,73 +607,86 @@
     "from tiktoken.load import load_tiktoken_bpe\n",
     "\n",
     "\n",
+    "\n",
     "class Tokenizer:\n",
+    "    \"\"\"Thin wrapper around tiktoken that keeps track of Llama-3 special IDs.\"\"\"\n",
     "    def __init__(self, model_path):\n",
-    "        assert os.path.isfile(model_path), f\"Model file {model_path} not found\"\n",
-    "        mergeable_ranks = load_tiktoken_bpe(model_path)\n",
+    "        if not os.path.isfile(model_path):\n",
+    "            raise FileNotFoundError(model_path)\n",
     "\n",
-    "        self.special_tokens = {\n",
+    "        mergeable = load_tiktoken_bpe(model_path)\n",
+    "\n",
+    "        # hard-coded from Meta's tokenizer.json\n",
+    "        self.special = {\n",
     "            \"<|begin_of_text|>\": 128000,\n",
     "            \"<|end_of_text|>\": 128001,\n",
     "            \"<|start_header_id|>\": 128006,\n",
     "            \"<|end_header_id|>\": 128007,\n",
     "            \"<|eot_id|>\": 128009,\n",
     "        }\n",
-    "        self.special_tokens.update({\n",
-    "            f\"<|reserved_{i}|>\": 128002 + i for i in range(256) if (128002 + i) not in self.special_tokens.values()\n",
-    "        })\n",
+    "        self.special.update({f\"<|reserved_{i}|>\": 128002 + i\n",
+    "                             for i in range(256)\n",
+    "                             if 128002 + i not in self.special.values()})\n",
     "\n",
     "        self.model = tiktoken.Encoding(\n",
     "            name=Path(model_path).name,\n",
-    "            pat_str=r\"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+\",\n",
-    "            mergeable_ranks=mergeable_ranks,\n",
-    "            special_tokens=self.special_tokens\n",
+    "            pat_str=r\"(?i:'s|'t|'re|'ve|'m|'ll|'d)\"\n",
+    "                    r\"|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+\"\n",
+    "                    r\"|\\p{N}{1,3}\"\n",
+    "                    r\"| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*\"\n",
+    "                    r\"|\\s*[\\r\\n]+\"\n",
+    "                    r\"|\\s+(?!\\S)\"\n",
+    "                    r\"|\\s+\",\n",
+    "            mergeable_ranks=mergeable,\n",
+    "            special_tokens=self.special,\n",
     "        )\n",
     "\n",
-    "\n",
-    "    def encode(self, text, bos=False, eos=False, allowed_special=set(), disallowed_special=()):\n",
-    "        if bos:\n",
-    "            tokens = [self.special_tokens[\"<|begin_of_text|>\"]]\n",
-    "        else:\n",
-    "            tokens = []\n",
-    "\n",
-    "        tokens += self.model.encode(text, allowed_special=allowed_special, disallowed_special=disallowed_special)\n",
-    "\n",
+    "    def encode(self, text, bos=False, eos=False):\n",
+    "        ids = ([self.special[\"<|begin_of_text|>\"]] if bos else []) \\\n",
+    "              + self.model.encode(text)\n",
     "        if eos:\n",
-    "            tokens.append(self.special_tokens[\"<|end_of_text|>\"])\n",
-    "        return tokens\n",
+    "            ids.append(self.special[\"<|end_of_text|>\"])\n",
+    "        return ids\n",
     "\n",
-    "    def decode(self, tokens):\n",
-    "        return self.model.decode(tokens)\n",
+    "    def decode(self, ids):\n",
+    "        return self.model.decode(ids)\n",
     "\n",
     "\n",
     "class ChatFormat:\n",
-    "    def __init__(self, tokenizer):\n",
-    "        self.tokenizer = tokenizer\n",
-    "\n",
-    "    def encode_header(self, message):\n",
-    "        tokens = []\n",
-    "        tokens.append(self.tokenizer.special_tokens[\"<|start_header_id|>\"])\n",
-    "        tokens.extend(self.tokenizer.encode(message[\"role\"], bos=False, eos=False))\n",
-    "        tokens.append(self.tokenizer.special_tokens[\"<|end_header_id|>\"])\n",
-    "        tokens.extend(self.tokenizer.encode(\"\\n\\n\", bos=False, eos=False))\n",
-    "        return tokens\n",
-    "\n",
-    "    def encode(self, text):\n",
-    "        message = {\n",
-    "            \"role\": \"user\",\n",
-    "            \"content\": text\n",
-    "        }\n",
     "\n",
-    "        tokens = self.encode_header(message)\n",
-    "        tokens.extend(\n",
-    "            self.tokenizer.encode(message[\"content\"].strip(), bos=False, eos=False)\n",
+    "    def __init__(self, tokenizer: Tokenizer, *,\n",
+    "                 default_system=\"You are a helpful assistant.\"):\n",
+    "        self.tok = tokenizer\n",
+    "        self.default_system = default_system\n",
+    "\n",
+    "    def _header(self, role):\n",
+    "        \"\"\"Encode <|start_header_id|>role<|end_header_id|>\\n\\n\"\"\"\n",
+    "        return (\n",
+    "            [self.tok.special[\"<|start_header_id|>\"]]\n",
+    "            + self.tok.encode(role)\n",
+    "            + [self.tok.special[\"<|end_header_id|>\"]]\n",
+    "            + self.tok.encode(\"\\n\\n\")\n",
     "        )\n",
-    "        tokens.append(self.tokenizer.special_tokens[\"<|eot_id|>\"])\n",
-    "        return tokens\n",
     "\n",
-    "    def decode(self, token_ids):\n",
-    "        return self.tokenizer.decode(token_ids)"
+    "    def encode(self, user_message, system_message=None):\n",
+    "        sys_msg = system_message if system_message is not None else self.default_system\n",
+    "\n",
+    "        ids = [self.tok.special[\"<|begin_of_text|>\"]]\n",
+    "\n",
+    "        # system\n",
+    "        ids += self._header(\"system\")\n",
+    "        ids += self.tok.encode(sys_msg)\n",
+    "        ids += [self.tok.special[\"<|eot_id|>\"]]\n",
+    "\n",
+    "        # user\n",
+    "        ids += self._header(\"user\")\n",
+    "        ids += self.tok.encode(user_message)\n",
+    "        ids += [self.tok.special[\"<|eot_id|>\"]]\n",
+    "\n",
+    "        # assistant header (no content yet)\n",
+    "        ids += self._header(\"assistant\")\n",
+    "\n",
+    "        return ids"
    ]
   },
   {
@@ -782,7 +709,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 17,
+   "execution_count": 14,
    "id": "e9d96dc8-603a-4cb5-8c3e-4d2ca56862ed",
    "metadata": {
     "colab": {
@@ -793,25 +720,24 @@
    },
    "outputs": [
     {
-     "name": "stdout",
+     "name": "stderr",
      "output_type": "stream",
      "text": [
-      "The token has not been saved to the git credentials helper. Pass `add_to_git_credential=True` in this function directly or `--add-to-git-credential` if using via `huggingface-cli` if you want to set the git credential as well.\n",
-      "Token is valid (permission: read).\n",
-      "Your token has been saved to /teamspace/studios/this_studio/.cache/huggingface/token\n",
-      "Login successful\n"
+      "/Users/sebastian/Developer/LLMs-from-scratch/.venv/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
+      "  from .autonotebook import tqdm as notebook_tqdm\n"
      ]
     }
    ],
    "source": [
-    "from huggingface_hub import login\n",
+    "# Uncomment and run the following code if you are executing the notebook for the first time\n",
     "\n",
-    "login()"
+    "# from huggingface_hub import login\n",
+    "# login()"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 18,
+   "execution_count": 15,
    "id": "986bc1a0-804f-4154-80f8-44cefbee1368",
    "metadata": {
     "colab": {
@@ -847,7 +773,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 19,
+   "execution_count": 16,
    "id": "_gBhxDtU_nxo",
    "metadata": {
     "id": "_gBhxDtU_nxo"
@@ -871,7 +797,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 20,
+   "execution_count": 17,
    "id": "75166128-5899-4995-9b88-9672e135650e",
    "metadata": {
     "id": "75166128-5899-4995-9b88-9672e135650e"
@@ -954,7 +880,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 21,
+   "execution_count": 18,
    "id": "699cb1b8-a67d-49fb-80a6-0dad9d81f392",
    "metadata": {
     "colab": {
@@ -1018,7 +944,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 22,
+   "execution_count": 19,
    "id": "7f9f7ccc-70cb-41ff-9c25-44336042fc37",
    "metadata": {
     "id": "7f9f7ccc-70cb-41ff-9c25-44336042fc37"
@@ -1049,7 +975,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 23,
+   "execution_count": 20,
    "id": "7b8401c6-e244-4cb7-9849-2ba71ce758d5",
    "metadata": {
     "id": "7b8401c6-e244-4cb7-9849-2ba71ce758d5"
@@ -1108,7 +1034,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 24,
+   "execution_count": 21,
    "id": "1c7a04fa-6aac-416b-8f63-f1e19227633d",
    "metadata": {
     "id": "1c7a04fa-6aac-416b-8f63-f1e19227633d"
@@ -1118,23 +1044,31 @@
      "name": "stdout",
      "output_type": "stream",
      "text": [
+      "Time: 18.20 sec\n",
+      "\n",
+      "\n",
       "Output text:\n",
-      " Llamas are herbivores, which means they primarily eat plants. Their diet consists mainly of:\n",
       "\n",
-      "1. Grasses: Llamas love to graze on various types of grasses, including tall grasses and grassy meadows.\n",
-      "2. Hay: Llamas also eat hay, which is a dry, compressed form of grass or other plants.\n",
-      "3. Alfalfa: Alfalfa is a legume that is commonly fed to llamas. It is high in protein and fiber.\n",
-      "4. Other plants: Llamas will also eat other plants, such as wild grasses, shrubs, and trees.\n",
+      " Llamas are herbivores, which means they primarily eat plants and plant-based foods. Their diet typically consists of:\n",
       "\n",
-      "It's worth noting that the diet of llamas can vary depending on the region, climate,\n"
+      "1. Grasses: Llamas love to graze on various types of grasses, including tall grasses, short grasses, and grassy weeds.\n",
+      "2. Hay: They also enjoy munching on hay, which is a dry, compressed form of grass or other plant material.\n",
+      "3. Leaves: Llamas will eat leaves from trees and shrubs, including leaves from plants like clover, alfalfa, and grasses.\n",
+      "4. Fruits and vegetables: In the wild, llamas will eat fruits and vegetables like berries, apples, and carrots.\n",
+      "5. Browse: Llamas will also\n"
      ]
     }
    ],
    "source": [
+    "import time\n",
+    "\n",
+    "\n",
     "PROMPT = \"What do llamas eat?\"\n",
     "\n",
     "torch.manual_seed(123)\n",
     "\n",
+    "start = time.time()\n",
+    "\n",
     "token_ids = generate(\n",
     "    model=model,\n",
     "    idx=text_to_token_ids(PROMPT, chat_tokenizer).to(device),\n",
@@ -1144,6 +1078,13 @@
     "    temperature=0.\n",
     ")\n",
     "\n",
+    "print(f\"Time: {time.time() - start:.2f} sec\")\n",
+    "\n",
+    "if torch.cuda.is_available():\n",
+    "    max_mem_bytes = torch.cuda.max_memory_allocated()\n",
+    "    max_mem_gb = max_mem_bytes / (1024 ** 3)\n",
+    "    print(f\"Max memory allocated: {max_mem_gb:.2f} GB\")\n",
+    "\n",
     "output_text = token_ids_to_text(token_ids, tokenizer)\n",
     "\n",
     "\n",
@@ -1158,7 +1099,7 @@
     "        # If the token is not found, return the original text\n",
     "        return text\n",
     "\n",
-    "print(\"Output text:\\n\", clean_text(output_text))"
+    "print(\"\\n\\nOutput text:\\n\\n\", clean_text(output_text))"
    ]
   },
   {

+ 10 - 1
pkg/llms_from_scratch/README.md

@@ -110,12 +110,21 @@ from llms_from_scratch.appendix_a import NeuralNetwork, ToyDataset
 
 from llms_from_scratch.appendix_d import find_highest_gradient, train_model
 
+```
+
+
+
+### Llama  3 (Bonus material)
+
+```python
 from llms_from_scratch.llama3 import (
     Llama3Model,
+    Llama3ModelFast,
     Llama3Tokenizer,
     ChatFormat,
     clean_text
 )
 ```
 
-(For the `llms_from_scratch.llama3` usage information, please see [this bonus section](../../ch05/07_gpt_to_llama/README.md).
+
+For the `llms_from_scratch.llama3` usage information, please see [this bonus section](../../ch05/07_gpt_to_llama/README.md).

+ 75 - 75
pkg/llms_from_scratch/llama3.py

@@ -15,8 +15,7 @@ from tiktoken.load import load_tiktoken_bpe
 
 LLAMA32_CONFIG_1B = {
     "vocab_size": 128_256,           # Vocabulary size
-    "context_length": 8192,          # Maximum context length to use (reduced to save memory)
-    "orig_context_length": 131_072,  # Context length that was used to train the model
+    "context_length": 131_072,       # Context length that was used to train the model
     "emb_dim": 2048,                 # Embedding dimension
     "n_heads": 32,                   # Number of attention heads
     "n_layers": 16,                  # Number of layers
@@ -34,8 +33,7 @@ LLAMA32_CONFIG_1B = {
 
 LLAMA32_CONFIG_3B = {
     "vocab_size": 128_256,           # Vocabulary size
-    "context_length": 8192,          # Maximum context length to use (reduced to save memory)
-    "orig_context_length": 131_072,  # Context length that was used to train the model
+    "context_length": 131_072,       # Context length that was used to train the model
     "emb_dim": 3072,                 # Embedding dimension
     "n_heads": 24,                   # Number of attention heads
     "n_layers": 28,                  # Number of layers
@@ -67,17 +65,6 @@ class Llama3Model(nn.Module):
         self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False, dtype=cfg["dtype"])
 
         # Reusuable utilities
-        self.register_buffer(
-            "mask", torch.triu(torch.ones(cfg["context_length"], cfg["context_length"]), diagonal=1).bool(),
-            persistent=False
-        )
-
-        if cfg["orig_context_length"] != cfg["context_length"]:
-            cfg["rope_base"] = rescale_theta(
-                            cfg["rope_base"],
-                            cfg["orig_context_length"],
-                            cfg["context_length"]
-                        )
         cos, sin = compute_rope_params(
             head_dim=cfg["emb_dim"] // cfg["n_heads"],
             theta_base=cfg["rope_base"],
@@ -92,8 +79,11 @@ class Llama3Model(nn.Module):
         tok_embeds = self.tok_emb(in_idx)
         x = tok_embeds
 
+        num_tokens = x.shape[1]
+        mask = torch.triu(torch.ones(num_tokens, num_tokens, device=x.device, dtype=torch.bool), diagonal=1)
+
         for block in self.trf_blocks:
-            x = block(x, self.mask, self.cos, self.sin)
+            x = block(x, mask, self.cos, self.sin)
         x = self.final_norm(x)
         logits = self.out_head(x.to(self.cfg["dtype"]))
         return logits
@@ -281,88 +271,104 @@ def apply_rope(x, cos, sin):
     return x_rotated.to(dtype=x.dtype)
 
 
-def rescale_theta(theta_old, context_length_old, context_length_new):
-    scaling_factor = context_length_new / context_length_old
-    theta_new = theta_old * scaling_factor
-    return theta_new
-
-
 ##########################################
 # Tokenizer
 ##########################################
 
 
 class Llama3Tokenizer:
+    """Thin wrapper around tiktoken that keeps track of Llama-3 special IDs."""
     def __init__(self, model_path):
-        assert os.path.isfile(model_path), f"Model file {model_path} not found"
-        mergeable_ranks = load_tiktoken_bpe(model_path)
+        if not os.path.isfile(model_path):
+            raise FileNotFoundError(model_path)
 
-        self.special_tokens = {
+        mergeable = load_tiktoken_bpe(model_path)
+
+        # hard-coded from Meta's tokenizer.json
+        self.special = {
             "<|begin_of_text|>": 128000,
             "<|end_of_text|>": 128001,
             "<|start_header_id|>": 128006,
             "<|end_header_id|>": 128007,
             "<|eot_id|>": 128009,
         }
-        self.special_tokens.update({
-            f"<|reserved_{i}|>": 128002 + i for i in range(256) if (128002 + i) not in self.special_tokens.values()
-        })
+        self.special.update({f"<|reserved_{i}|>": 128002 + i
+                             for i in range(256)
+                             if 128002 + i not in self.special.values()})
 
         self.model = tiktoken.Encoding(
             name=Path(model_path).name,
-            pat_str=r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+",
-            mergeable_ranks=mergeable_ranks,
-            special_tokens=self.special_tokens
+            pat_str=r"(?i:'s|'t|'re|'ve|'m|'ll|'d)"
+                    r"|[^\r\n\p{L}\p{N}]?\p{L}+"
+                    r"|\p{N}{1,3}"
+                    r"| ?[^\s\p{L}\p{N}]+[\r\n]*"
+                    r"|\s*[\r\n]+"
+                    r"|\s+(?!\S)"
+                    r"|\s+",
+            mergeable_ranks=mergeable,
+            special_tokens=self.special,
         )
 
-    def encode(self, text, bos=False, eos=False, allowed_special=set(), disallowed_special=()):
-        if bos:
-            tokens = [self.special_tokens["<|begin_of_text|>"]]
-        else:
-            tokens = []
+    def encode(self, text, bos=False, eos=False, allowed_special=set()):
+        ids: list[int] = []
 
-        tokens += self.model.encode(text, allowed_special=allowed_special, disallowed_special=disallowed_special)
+        if bos:
+            ids.append(self.special_tokens["<|begin_of_text|>"])
 
+        # delegate to underlying tiktoken.Encoding.encode
+        ids.extend(
+            self.model.encode(
+                text,
+                allowed_special=allowed_special,
+            )
+        )
         if eos:
-            tokens.append(self.special_tokens["<|end_of_text|>"])
-        return tokens
+            ids.append(self.special_tokens["<|end_of_text|>"])
 
-    def decode(self, tokens):
-        return self.model.decode(tokens)
+        return ids
+
+    def decode(self, ids):
+        return self.model.decode(ids)
 
 
 class ChatFormat:
-    def __init__(self, tokenizer):
-        self.tokenizer = tokenizer
-
-    def encode_header(self, message):
-        tokens = []
-        tokens.append(self.tokenizer.special_tokens["<|start_header_id|>"])
-        tokens.extend(self.tokenizer.encode(message["role"], bos=False, eos=False))
-        tokens.append(self.tokenizer.special_tokens["<|end_header_id|>"])
-        tokens.extend(self.tokenizer.encode("\n\n", bos=False, eos=False))
-        return tokens
-
-    def encode(self, text, allowed_special=None):
-        message = {
-            "role": "user",
-            "content": text
-        }
 
-        tokens = self.encode_header(message)
-        tokens.extend(
-            self.tokenizer.encode(
-                message["content"].strip(),
-                bos=False,
-                eos=False,
-                allowed_special=allowed_special
-            )
+    def __init__(self, tokenizer: Llama3Tokenizer, *,
+                 default_system="You are a helpful assistant."):
+        self.tok = tokenizer
+        self.default_system = default_system
+
+    def _header(self, role):
+        """Encode <|start_header_id|>role<|end_header_id|>\n\n"""
+        return (
+            [self.tok.special["<|start_header_id|>"]]
+            + self.tok.encode(role)
+            + [self.tok.special["<|end_header_id|>"]]
+            + self.tok.encode("\n\n")
         )
-        tokens.append(self.tokenizer.special_tokens["<|eot_id|>"])
-        return tokens
 
-    def decode(self, token_ids):
-        return self.tokenizer.decode(token_ids)
+    def encode(self, user_message, system_message=None, allowed_special=None):
+        sys_msg = system_message if system_message is not None else self.default_system
+
+        ids = [self.tok.special["<|begin_of_text|>"]]
+
+        # system
+        ids += self._header("system")
+        ids += self.tok.encode(sys_msg, allowed_special=allowed_special)
+        ids += [self.tok.special["<|eot_id|>"]]
+
+        # user
+        ids += self._header("user")
+        ids += self.tok.encode(user_message)
+        ids += [self.tok.special["<|eot_id|>"]]
+
+        # assistant header (no content yet)
+        ids += self._header("assistant")
+
+        return ids
+
+    def decode(self, ids):
+        return self.tok.decode(ids)
 
 
 def clean_text(text, header_end="assistant<|end_header_id|>\n\n"):
@@ -483,12 +489,6 @@ class Llama3ModelFast(nn.Module):
         self.final_norm = nn.RMSNorm(cfg["emb_dim"], eps=1e-5, dtype=cfg["dtype"])
         self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False, dtype=cfg["dtype"])
 
-        if cfg["orig_context_length"] != cfg["context_length"]:
-            cfg["rope_base"] = rescale_theta(
-                            cfg["rope_base"],
-                            cfg["orig_context_length"],
-                            cfg["context_length"]
-                        )
         cos, sin = compute_rope_params(
             head_dim=cfg["emb_dim"] // cfg["n_heads"],
             theta_base=cfg["rope_base"],

+ 1 - 19
pkg/llms_from_scratch/tests/test_llama3.py

@@ -7,7 +7,6 @@ from llms_from_scratch.ch04 import generate_text_simple
 from llms_from_scratch.llama3 import (
     compute_rope_params,
     apply_rope,
-    rescale_theta,
     LLAMA32_CONFIG_1B,
     GroupedQueryAttention,
     GroupedQueryAttentionFast,
@@ -102,23 +101,6 @@ GPT_CONFIG_124M = {
 }
 
 
-def test_rescale():
-
-    new_theta = rescale_theta(
-        theta_old=500_000.,
-        context_length_old=131_072,
-        context_length_new=8192
-    )
-    assert new_theta == 31250.
-
-    old_theta = rescale_theta(
-        theta_old=new_theta,
-        context_length_old=8192,
-        context_length_new=131_072
-    )
-    assert old_theta == 500_000.
-
-
 def test_grouped_query_attention_equivalence():
     torch.manual_seed(42)
     b, t, d_in, d_out, num_heads, num_kv_groups = 2, 8, 32, 64, 4, 2
@@ -194,6 +176,6 @@ def test_gpt_model_variants(ModelClass, llama3_weights_path):
     )
     print("Encoded output text:", out)
     expect = torch.tensor([
-        [43,   2543,    292,   4483, 100383,   8113,  21197,  33804,  54419]
+        [43, 2543, 292, 4483, 100383, 8113, 76873, 42175, 72641]
     ])
     assert torch.equal(expect, out)

+ 1 - 1
pyproject.toml

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "llms-from-scratch"
-version = "1.0.6"
+version = "1.0.7"
 description = "Implement a ChatGPT-like LLM in PyTorch from scratch, step by step"
 readme = "README.md"
 requires-python = ">=3.10"