|
|
@@ -0,0 +1,1240 @@
|
|
|
+{
|
|
|
+ "cells": [
|
|
|
+ {
|
|
|
+ "cell_type": "markdown",
|
|
|
+ "id": "e1b280ab-b61f-4d1a-bf7e-44e5f9ed3a5c",
|
|
|
+ "metadata": {
|
|
|
+ "id": "e1b280ab-b61f-4d1a-bf7e-44e5f9ed3a5c"
|
|
|
+ },
|
|
|
+ "source": [
|
|
|
+ "<table style=\"width:100%\">\n",
|
|
|
+ "<tr>\n",
|
|
|
+ "<td style=\"vertical-align:middle; text-align:left;\">\n",
|
|
|
+ "<font size=\"2\">\n",
|
|
|
+ "Supplementary code for the <a href=\"http://mng.bz/orYv\">Build a Large Language Model From Scratch</a> book by <a href=\"https://sebastianraschka.com\">Sebastian Raschka</a><br>\n",
|
|
|
+ "<br>Code repository: <a href=\"https://github.com/rasbt/LLMs-from-scratch\">https://github.com/rasbt/LLMs-from-scratch</a>\n",
|
|
|
+ "</font>\n",
|
|
|
+ "</td>\n",
|
|
|
+ "<td style=\"vertical-align:middle; text-align:left;\">\n",
|
|
|
+ "<a href=\"http://mng.bz/orYv\"><img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/cover-small.webp\" width=\"100px\"></a>\n",
|
|
|
+ "</td>\n",
|
|
|
+ "</tr>\n",
|
|
|
+ "</table>"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "markdown",
|
|
|
+ "id": "efde77f2-6af3-4781-8597-89ecd3f41a52",
|
|
|
+ "metadata": {
|
|
|
+ "id": "efde77f2-6af3-4781-8597-89ecd3f41a52"
|
|
|
+ },
|
|
|
+ "source": [
|
|
|
+ "# Qwen3 Mixture-of-Experts From Scratch (A Standalone Notebook)"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "markdown",
|
|
|
+ "id": "55cdef4d-de59-4a65-89f9-fa2a8ef3471d",
|
|
|
+ "metadata": {
|
|
|
+ "id": "55cdef4d-de59-4a65-89f9-fa2a8ef3471d"
|
|
|
+ },
|
|
|
+ "source": [
|
|
|
+ "- This notebook is purposefully minimal and focuses on the code to implement the Qwen3-30B-A3B model (with support for the **Coder**, **Instruct**, and **Thinking** variants); for more information about this model, please see the original blog post, technical report, and model hub pages:\n",
|
|
|
+ " - [Qwen3: Think Deeper, Act Faster](https://qwenlm.github.io/blog/qwen3/)\n",
|
|
|
+ " - [Qwen3 Technical Report](https://arxiv.org/abs/2505.09388)\n",
|
|
|
+ " - https://huggingface.co/Qwen/Qwen3-Coder-30B-A3B-Instruct (Qwen3 Coder Flash)\n",
|
|
|
+ " - https://huggingface.co/Qwen/Qwen3-30B-A3B-Thinking-2507 (new thinking model)\n",
|
|
|
+ " - https://huggingface.co/Qwen/Qwen3-235B-A22B-Instruct-2507 (new instruct model)\n",
|
|
|
+ " - https://huggingface.co/Qwen/Qwen3-30B-A3B (original Instruct/Thinking hybrid model)\n",
|
|
|
+ "- Many architectural components in Qwen3 are similar to Llama 3; for a step-by-step guide that explains the individual components and the relationship between GPT and the components used here, you may like the GPT-to-Llama conversion notebooks:\n",
|
|
|
+ " - [Converting a From-Scratch GPT Architecture to Llama 2](../07_gpt_to_llama/converting-gpt-to-llama2.ipynb)\n",
|
|
|
+ " - [Converting Llama 2 to Llama 3.2 From Scratch](../07_gpt_to_llama/converting-llama2-to-llama3.ipynb)\n",
|
|
|
+ " \n",
|
|
|
+ "\n",
|
|
|
+ "**By default, this notebook runs Qwen3-Coder-30B-A3B-Instruct (aka Qwen3 Coder Flash) and requires 80 GB of VRAM (e.g., a single A100 or H100)**\n",
|
|
|
+ "\n",
|
|
|
+ "<br>\n",
|
|
|
+ "\n",
|
|
|
+ "<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/bonus/qwen/qwen3-coder-flash-overview.webp?123\" width=\"600px\">\n",
|
|
|
+ "\n",
|
|
|
+ "<br>\n",
|
|
|
+ " \n",
|
|
|
+ "- About the code:\n",
|
|
|
+ " - all code is my own code, mapping the Qwen3 architecture onto the model code implemented in my [Build A Large Language Model (From Scratch)](http://mng.bz/orYv) book; the code is released under a permissive open-source Apache 2.0 license (see [LICENSE.txt](https://github.com/rasbt/LLMs-from-scratch/blob/main/LICENSE.txt))"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 1,
|
|
|
+ "id": "7c201adb-747e-437b-9a62-442802941e01",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "# pip install -r https://raw.githubusercontent.com/rasbt/LLMs-from-scratch/refs/heads/main/ch05/07_gpt_to_llama/requirements-extra.txt"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 2,
|
|
|
+ "id": "dd1b65a8-4301-444a-bd7c-a6f2bd1df9df",
|
|
|
+ "metadata": {
|
|
|
+ "colab": {
|
|
|
+ "base_uri": "https://localhost:8080/"
|
|
|
+ },
|
|
|
+ "id": "dd1b65a8-4301-444a-bd7c-a6f2bd1df9df",
|
|
|
+ "outputId": "4f762354-e0a3-4cc2-e5d4-e61a227a202c"
|
|
|
+ },
|
|
|
+ "outputs": [
|
|
|
+ {
|
|
|
+ "name": "stdout",
|
|
|
+ "output_type": "stream",
|
|
|
+ "text": [
|
|
|
+ "huggingface_hub version: 0.34.3\n",
|
|
|
+ "tokenizers version: 0.21.4\n",
|
|
|
+ "torch version: 2.7.1+cu128\n"
|
|
|
+ ]
|
|
|
+ }
|
|
|
+ ],
|
|
|
+ "source": [
|
|
|
+ "from importlib.metadata import version\n",
|
|
|
+ "\n",
|
|
|
+ "pkgs = [\n",
|
|
|
+ " \"huggingface_hub\", # to download pretrained weights\n",
|
|
|
+ " \"tokenizers\", # to implement the tokenizer\n",
|
|
|
+ " \"torch\", # to implement the model\n",
|
|
|
+ "]\n",
|
|
|
+ "for p in pkgs:\n",
|
|
|
+ " print(f\"{p} version: {version(p)}\")"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "markdown",
|
|
|
+ "id": "653410a6-dd2b-4eb2-a722-23d9782e726d",
|
|
|
+ "metadata": {
|
|
|
+ "id": "653410a6-dd2b-4eb2-a722-23d9782e726d"
|
|
|
+ },
|
|
|
+ "source": [
|
|
|
+ " \n",
|
|
|
+ "# 1. Architecture code"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 3,
|
|
|
+ "id": "82076c21-9331-4dcd-b017-42b046cf1a60",
|
|
|
+ "metadata": {
|
|
|
+ "id": "82076c21-9331-4dcd-b017-42b046cf1a60"
|
|
|
+ },
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "import torch\n",
|
|
|
+ "import torch.nn as nn\n",
|
|
|
+ "\n",
|
|
|
+ "\n",
|
|
|
+ "class FeedForward(nn.Module):\n",
|
|
|
+ " def __init__(self, cfg):\n",
|
|
|
+ " super().__init__()\n",
|
|
|
+ " self.fc1 = nn.Linear(cfg[\"emb_dim\"], cfg[\"hidden_dim\"], dtype=cfg[\"dtype\"], bias=False)\n",
|
|
|
+ " self.fc2 = nn.Linear(cfg[\"emb_dim\"], cfg[\"hidden_dim\"], dtype=cfg[\"dtype\"], bias=False)\n",
|
|
|
+ " self.fc3 = nn.Linear(cfg[\"hidden_dim\"], cfg[\"emb_dim\"], dtype=cfg[\"dtype\"], bias=False)\n",
|
|
|
+ "\n",
|
|
|
+ " def forward(self, x):\n",
|
|
|
+ " x_fc1 = self.fc1(x)\n",
|
|
|
+ " x_fc2 = self.fc2(x)\n",
|
|
|
+ " x = nn.functional.silu(x_fc1) * x_fc2\n",
|
|
|
+ " return self.fc3(x)\n",
|
|
|
+ "\n",
|
|
|
+ "\n",
|
|
|
+ "class MoEFeedForward(nn.Module):\n",
|
|
|
+ " def __init__(self, cfg):\n",
|
|
|
+ " super().__init__()\n",
|
|
|
+ " self.num_experts_per_tok = cfg[\"num_experts_per_tok\"]\n",
|
|
|
+ " self.num_experts = cfg[\"num_experts\"]\n",
|
|
|
+ " self.gate = nn.Linear(cfg[\"emb_dim\"], cfg[\"num_experts\"], bias=False, dtype=cfg[\"dtype\"])\n",
|
|
|
+ "\n",
|
|
|
+ " meta_device = torch.device(\"meta\") # instantiate expert layers on the meta device so no memory is allocated up front; their weights are only materialized later during weight loading (reduces peak memory pressure)\n",
|
|
|
+ " self.fc1 = nn.ModuleList([nn.Linear(cfg[\"emb_dim\"], cfg[\"moe_intermediate_size\"], bias=False, dtype=cfg[\"dtype\"], device=meta_device)\n",
|
|
|
+ " for _ in range(cfg[\"num_experts\"])])\n",
|
|
|
+ " self.fc2 = nn.ModuleList([nn.Linear(cfg[\"emb_dim\"], cfg[\"moe_intermediate_size\"], bias=False, dtype=cfg[\"dtype\"], device=meta_device)\n",
|
|
|
+ " for _ in range(cfg[\"num_experts\"])])\n",
|
|
|
+ " self.fc3 = nn.ModuleList([nn.Linear(cfg[\"moe_intermediate_size\"], cfg[\"emb_dim\"], bias=False, dtype=cfg[\"dtype\"], device=meta_device)\n",
|
|
|
+ " for _ in range(cfg[\"num_experts\"])])\n",
|
|
|
+ "\n",
|
|
|
+ " def forward(self, x):\n",
|
|
|
+ " b, seq_len, embed_dim = x.shape\n",
|
|
|
+ " scores = self.gate(x) # (b, seq_len, num_experts)\n",
|
|
|
+ " topk_scores, topk_indices = torch.topk(scores, self.num_experts_per_tok, dim=-1)\n",
|
|
|
+ " topk_probs = torch.softmax(topk_scores, dim=-1)\n",
|
|
|
+ " \n",
|
|
|
+ " expert_outputs = []\n",
|
|
|
+ " for e in range(self.num_experts):\n",
|
|
|
+ " hidden = torch.nn.functional.silu(self.fc1[e](x)) * self.fc2[e](x)\n",
|
|
|
+ " out = self.fc3[e](hidden)\n",
|
|
|
+ " expert_outputs.append(out.unsqueeze(-2))\n",
|
|
|
+ " expert_outputs = torch.cat(expert_outputs, dim=-2) # (b, t, num_experts, emb_dim)\n",
|
|
|
+ "\n",
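+ " # Scatter the top-k probabilities into a dense (b, seq_len, num_experts) tensor; experts that were not selected keep a weight of 0\n",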
|
|
|
+ " gating_probs = torch.zeros_like(scores)\n",
|
|
|
+ "\n",
|
|
|
+ " for i in range(self.num_experts_per_tok):\n",
|
|
|
+ " indices = topk_indices[..., i:i+1]\n",
|
|
|
+ " prob = topk_probs[..., i:i+1]\n",
|
|
|
+ " gating_probs.scatter_(dim=-1, index=indices, src=prob)\n",
|
|
|
+ " gating_probs = gating_probs.unsqueeze(-1) # (b, t, num_experts, 1)\n",
|
|
|
+ " \n",
|
|
|
+ " # Weighted sum over experts\n",
|
|
|
+ " y = (gating_probs * expert_outputs).sum(dim=-2)\n",
|
|
|
+ " return y\n",
|
|
|
+ "\n",
|
|
|
+ "\n",
|
|
|
+ " # For some reason, the version below is slower than the naive version\n",
|
|
|
+ " # above that computes all experts, even the unused ones\n",
|
|
|
+ "\n",
|
|
|
+ " # def forward(self, x):\n",
|
|
|
+ " # scores = self.gate(x) # (b, seq_len, num_experts)\n",
|
|
|
+ " # topk_scores, topk_indices = torch.topk(scores, self.num_experts_per_tok, dim=-1)\n",
|
|
|
+ " # topk_probs = torch.softmax(topk_scores, dim=-1)\n",
|
|
|
+ " # y = torch.zeros_like(x)\n",
|
|
|
+ "\n",
|
|
|
+ " # for i in range(self.num_experts_per_tok):\n",
|
|
|
+ " # # expert_indices is (b, seq_len) with values in [0, num_experts)\n",
|
|
|
+ " # expert_indices = topk_indices[..., i]\n",
|
|
|
+ " # prob = topk_probs[..., i].unsqueeze(-1) # (b, seq_len, 1)\n",
|
|
|
+ "\n",
|
|
|
+ " # # For each expert, process only the tokens assigned to it\n",
|
|
|
+ " # for e in range(self.num_experts):\n",
|
|
|
+ " # mask = (expert_indices == e) # (b, seq_len) boolean mask\n",
|
|
|
+ " # if mask.any():\n",
|
|
|
+ " # selected = x[mask] # (num_tokens_e, emb_dim)\n",
|
|
|
+ " # # Compute FF for expert e\n",
|
|
|
+ " # out = self.fc3[e](torch.nn.functional.silu(self.fc1[e](selected)) * self.fc2[e](selected))\n",
|
|
|
+ " # # Scale by gating prob and scatter back\n",
|
|
|
+ " # y[mask] += prob[mask] * out\n",
|
|
|
+ " # return y"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 4,
|
|
|
+ "id": "56715760-37e1-433e-89da-04864c139a9e",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "class RMSNorm(nn.Module):\n",
|
|
|
+ " def __init__(self, emb_dim, eps=1e-6, bias=False, qwen3_compatible=True):\n",
|
|
|
+ " super().__init__()\n",
|
|
|
+ " self.eps = eps\n",
|
|
|
+ " self.qwen3_compatible = qwen3_compatible\n",
|
|
|
+ " self.scale = nn.Parameter(torch.ones(emb_dim))\n",
|
|
|
+ " self.shift = nn.Parameter(torch.zeros(emb_dim)) if bias else None\n",
|
|
|
+ "\n",
|
|
|
+ " def forward(self, x):\n",
|
|
|
+ " input_dtype = x.dtype\n",
|
|
|
+ "\n",
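+ " # Qwen3 computes RMSNorm in float32 for numerical stability and casts back to the input dtype at the end\n",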
|
|
|
+ " if self.qwen3_compatible:\n",
|
|
|
+ " x = x.to(torch.float32)\n",
|
|
|
+ "\n",
|
|
|
+ " variance = x.pow(2).mean(dim=-1, keepdim=True)\n",
|
|
|
+ " norm_x = x * torch.rsqrt(variance + self.eps)\n",
|
|
|
+ " norm_x = norm_x * self.scale\n",
|
|
|
+ "\n",
|
|
|
+ " if self.shift is not None:\n",
|
|
|
+ " norm_x = norm_x + self.shift\n",
|
|
|
+ "\n",
|
|
|
+ " return norm_x.to(input_dtype)"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 5,
|
|
|
+ "id": "4b9a346f-5826-4083-9162-abd56afc03f0",
|
|
|
+ "metadata": {
|
|
|
+ "id": "4b9a346f-5826-4083-9162-abd56afc03f0"
|
|
|
+ },
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "def compute_rope_params(head_dim, theta_base=10_000, context_length=4096, dtype=torch.float32):\n",
|
|
|
+ " assert head_dim % 2 == 0, \"Embedding dimension must be even\"\n",
|
|
|
+ "\n",
|
|
|
+ " # Compute the inverse frequencies\n",
|
|
|
+ " inv_freq = 1.0 / (theta_base ** (torch.arange(0, head_dim, 2, dtype=dtype)[: (head_dim // 2)].float() / head_dim))\n",
|
|
|
+ "\n",
|
|
|
+ " # Generate position indices\n",
|
|
|
+ " positions = torch.arange(context_length, dtype=dtype)\n",
|
|
|
+ "\n",
|
|
|
+ " # Compute the angles\n",
|
|
|
+ " angles = positions[:, None] * inv_freq[None, :] # Shape: (context_length, head_dim // 2)\n",
|
|
|
+ "\n",
|
|
|
+ " # Expand angles to match the head_dim\n",
|
|
|
+ " angles = torch.cat([angles, angles], dim=1) # Shape: (context_length, head_dim)\n",
|
|
|
+ "\n",
|
|
|
+ " # Precompute sine and cosine\n",
|
|
|
+ " cos = torch.cos(angles)\n",
|
|
|
+ " sin = torch.sin(angles)\n",
|
|
|
+ "\n",
|
|
|
+ " return cos, sin\n",
|
|
|
+ "\n",
|
|
|
+ "\n",
|
|
|
+ "def apply_rope(x, cos, sin, offset=0):\n",
|
|
|
+ " # x: (batch_size, num_heads, seq_len, head_dim)\n",
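+ " # offset shifts the position indices, which is needed when generating with a KV cache\n",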
|
|
|
+ " batch_size, num_heads, seq_len, head_dim = x.shape\n",
|
|
|
+ " assert head_dim % 2 == 0, \"Head dimension must be even\"\n",
|
|
|
+ "\n",
|
|
|
+ " # Split x into first half and second half\n",
|
|
|
+ " x1 = x[..., : head_dim // 2] # First half\n",
|
|
|
+ " x2 = x[..., head_dim // 2:] # Second half\n",
|
|
|
+ "\n",
|
|
|
+ " # Adjust sin and cos shapes\n",
|
|
|
+ " cos = cos[offset:offset + seq_len, :].unsqueeze(0).unsqueeze(0) # Shape: (1, 1, seq_len, head_dim)\n",
|
|
|
+ " sin = sin[offset:offset + seq_len, :].unsqueeze(0).unsqueeze(0)\n",
|
|
|
+ "\n",
|
|
|
+ " # Apply the rotary transformation\n",
|
|
|
+ " rotated = torch.cat((-x2, x1), dim=-1)\n",
|
|
|
+ " x_rotated = (x * cos) + (rotated * sin)\n",
|
|
|
+ "\n",
|
|
|
+ " # It's ok to use lower-precision after applying cos and sin rotation\n",
|
|
|
+ " return x_rotated.to(dtype=x.dtype)"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 6,
|
|
|
+ "id": "e8169ab5-f976-4222-a2e1-eb1cabf267cb",
|
|
|
+ "metadata": {
|
|
|
+ "id": "e8169ab5-f976-4222-a2e1-eb1cabf267cb"
|
|
|
+ },
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "class GroupedQueryAttention(nn.Module):\n",
|
|
|
+ " def __init__(\n",
|
|
|
+ " self, d_in, num_heads, num_kv_groups, head_dim=None, qk_norm=False, dtype=None\n",
|
|
|
+ " ):\n",
|
|
|
+ " super().__init__()\n",
|
|
|
+ " assert num_heads % num_kv_groups == 0, \"num_heads must be divisible by num_kv_groups\"\n",
|
|
|
+ "\n",
|
|
|
+ " self.num_heads = num_heads\n",
|
|
|
+ " self.num_kv_groups = num_kv_groups\n",
|
|
|
+ " self.group_size = num_heads // num_kv_groups\n",
|
|
|
+ "\n",
|
|
|
+ " if head_dim is None:\n",
|
|
|
+ " assert d_in % num_heads == 0, \"`d_in` must be divisible by `num_heads` if `head_dim` is not set\"\n",
|
|
|
+ " head_dim = d_in // num_heads\n",
|
|
|
+ "\n",
|
|
|
+ " self.head_dim = head_dim\n",
|
|
|
+ " self.d_out = num_heads * head_dim\n",
|
|
|
+ "\n",
|
|
|
+ " self.W_query = nn.Linear(d_in, self.d_out, bias=False, dtype=dtype)\n",
|
|
|
+ " self.W_key = nn.Linear(d_in, num_kv_groups * head_dim, bias=False, dtype=dtype)\n",
|
|
|
+ " self.W_value = nn.Linear(d_in, num_kv_groups * head_dim, bias=False, dtype=dtype)\n",
|
|
|
+ "\n",
|
|
|
+ " self.out_proj = nn.Linear(self.d_out, d_in, bias=False, dtype=dtype)\n",
|
|
|
+ "\n",
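+ " # Optional RMSNorm applied to the per-head queries and keys (used by Qwen3) before RoPE\n",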
|
|
|
+ " if qk_norm:\n",
|
|
|
+ " self.q_norm = RMSNorm(head_dim, eps=1e-6)\n",
|
|
|
+ " self.k_norm = RMSNorm(head_dim, eps=1e-6)\n",
|
|
|
+ " else:\n",
|
|
|
+ " self.q_norm = self.k_norm = None\n",
|
|
|
+ "\n",
|
|
|
+ " def forward(self, x, mask, cos, sin, start_pos=0, cache=None):\n",
|
|
|
+ " b, num_tokens, _ = x.shape\n",
|
|
|
+ "\n",
|
|
|
+ " # Apply projections\n",
|
|
|
+ " queries = self.W_query(x) # (b, num_tokens, num_heads * head_dim)\n",
|
|
|
+ " keys = self.W_key(x) # (b, num_tokens, num_kv_groups * head_dim)\n",
|
|
|
+ " values = self.W_value(x) # (b, num_tokens, num_kv_groups * head_dim)\n",
|
|
|
+ "\n",
|
|
|
+ " # Reshape\n",
|
|
|
+ " queries = queries.view(b, num_tokens, self.num_heads, self.head_dim).transpose(1, 2)\n",
|
|
|
+ " keys_new = keys.view(b, num_tokens, self.num_kv_groups, self.head_dim).transpose(1, 2)\n",
|
|
|
+ " values_new = values.view(b, num_tokens, self.num_kv_groups, self.head_dim).transpose(1, 2)\n",
|
|
|
+ "\n",
|
|
|
+ " # Optional normalization\n",
|
|
|
+ " if self.q_norm:\n",
|
|
|
+ " queries = self.q_norm(queries)\n",
|
|
|
+ " if self.k_norm:\n",
|
|
|
+ " keys_new = self.k_norm(keys_new)\n",
|
|
|
+ "\n",
|
|
|
+ " # Apply RoPE\n",
|
|
|
+ " queries = apply_rope(queries, cos, sin, offset=start_pos)\n",
|
|
|
+ " keys_new = apply_rope(keys_new, cos, sin, offset=start_pos)\n",
|
|
|
+ "\n",
|
|
|
+ " if cache is not None:\n",
|
|
|
+ " prev_k, prev_v = cache\n",
|
|
|
+ " keys = torch.cat([prev_k, keys_new], dim=2)\n",
|
|
|
+ " values = torch.cat([prev_v, values_new], dim=2)\n",
|
|
|
+ " next_cache = (keys, values)\n",
|
|
|
+ " else:\n",
|
|
|
+ " start_pos = 0 # reset RoPE\n",
|
|
|
+ " keys, values = keys_new, values_new\n",
|
|
|
+ " next_cache = (keys, values)\n",
|
|
|
+ "\n",
|
|
|
+ " # Expand K and V to match number of heads\n",
|
|
|
+ " keys = keys.repeat_interleave(self.group_size, dim=1)\n",
|
|
|
+ " values = values.repeat_interleave(self.group_size, dim=1)\n",
|
|
|
+ "\n",
|
|
|
+ " # Attention\n",
|
|
|
+ " attn_scores = queries @ keys.transpose(2, 3)\n",
|
|
|
+ " attn_scores = attn_scores.masked_fill(mask, -torch.inf)\n",
|
|
|
+ " attn_weights = torch.softmax(attn_scores / self.head_dim**0.5, dim=-1)\n",
|
|
|
+ "\n",
|
|
|
+ " context = (attn_weights @ values).transpose(1, 2).reshape(b, num_tokens, self.d_out)\n",
|
|
|
+ " return self.out_proj(context), next_cache"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 7,
|
|
|
+ "id": "457cb2f8-50c1-4045-8a74-f181bfb5fea9",
|
|
|
+ "metadata": {
|
|
|
+ "id": "457cb2f8-50c1-4045-8a74-f181bfb5fea9"
|
|
|
+ },
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "class TransformerBlock(nn.Module):\n",
|
|
|
+ " def __init__(self, cfg):\n",
|
|
|
+ " super().__init__()\n",
|
|
|
+ " self.att = GroupedQueryAttention(\n",
|
|
|
+ " d_in=cfg[\"emb_dim\"],\n",
|
|
|
+ " num_heads=cfg[\"n_heads\"],\n",
|
|
|
+ " head_dim=cfg[\"head_dim\"],\n",
|
|
|
+ " num_kv_groups=cfg[\"n_kv_groups\"],\n",
|
|
|
+ " qk_norm=cfg[\"qk_norm\"],\n",
|
|
|
+ " dtype=cfg[\"dtype\"]\n",
|
|
|
+ " )\n",
|
|
|
+ " if cfg[\"num_experts\"] > 0:\n",
|
|
|
+ " self.ff = MoEFeedForward(cfg)\n",
|
|
|
+ " else:\n",
|
|
|
+ " self.ff = FeedForward(cfg)\n",
|
|
|
+ " self.norm1 = RMSNorm(cfg[\"emb_dim\"], eps=1e-6)\n",
|
|
|
+ " self.norm2 = RMSNorm(cfg[\"emb_dim\"], eps=1e-6)\n",
|
|
|
+ "\n",
|
|
|
+ " def forward(self, x, mask, cos, sin, start_pos=0, cache=None):\n",
|
|
|
+ " # Shortcut connection for attention block\n",
|
|
|
+ " shortcut = x\n",
|
|
|
+ " x = self.norm1(x)\n",
|
|
|
+ " x, next_cache = self.att(x, mask, cos, sin, start_pos=start_pos, cache=cache) # Shape [batch_size, num_tokens, emb_size]\n",
|
|
|
+ " x = x + shortcut # Add the original input back\n",
|
|
|
+ "\n",
|
|
|
+ " # Shortcut connection for feed-forward block\n",
|
|
|
+ " shortcut = x\n",
|
|
|
+ " x = self.norm2(x)\n",
|
|
|
+ " x = self.ff(x)\n",
|
|
|
+ " x = x + shortcut # Add the original input back\n",
|
|
|
+ "\n",
|
|
|
+ " return x, next_cache\n"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 8,
|
|
|
+ "id": "e88de3e3-9f07-42cc-816b-28dbd46e96c4",
|
|
|
+ "metadata": {
|
|
|
+ "id": "e88de3e3-9f07-42cc-816b-28dbd46e96c4"
|
|
|
+ },
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "class Qwen3Model(nn.Module):\n",
|
|
|
+ " def __init__(self, cfg):\n",
|
|
|
+ " super().__init__()\n",
|
|
|
+ "\n",
|
|
|
+ " # Main model parameters\n",
|
|
|
+ " self.tok_emb = nn.Embedding(cfg[\"vocab_size\"], cfg[\"emb_dim\"], dtype=cfg[\"dtype\"])\n",
|
|
|
+ "\n",
|
|
|
+ " self.trf_blocks = nn.ModuleList( # ModuleList since Sequential can only accept one input, and we need `x, mask, cos, sin`\n",
|
|
|
+ " [TransformerBlock(cfg) for _ in range(cfg[\"n_layers\"])]\n",
|
|
|
+ " )\n",
|
|
|
+ "\n",
|
|
|
+ " self.final_norm = RMSNorm(cfg[\"emb_dim\"])\n",
|
|
|
+ " self.out_head = nn.Linear(cfg[\"emb_dim\"], cfg[\"vocab_size\"], bias=False, dtype=cfg[\"dtype\"])\n",
|
|
|
+ "\n",
|
|
|
+ " # Reusable utilities\n",
|
|
|
+ " if cfg[\"head_dim\"] is None:\n",
|
|
|
+ " head_dim = cfg[\"emb_dim\"] // cfg[\"n_heads\"]\n",
|
|
|
+ " else:\n",
|
|
|
+ " head_dim = cfg[\"head_dim\"]\n",
|
|
|
+ " cos, sin = compute_rope_params(\n",
|
|
|
+ " head_dim=head_dim,\n",
|
|
|
+ " theta_base=cfg[\"rope_base\"],\n",
|
|
|
+ " context_length=cfg[\"context_length\"]\n",
|
|
|
+ " )\n",
|
|
|
+ " self.register_buffer(\"cos\", cos, persistent=False)\n",
|
|
|
+ " self.register_buffer(\"sin\", sin, persistent=False)\n",
|
|
|
+ " self.cfg = cfg\n",
|
|
|
+ " self.current_pos = 0 # Track current position in KV cache\n",
|
|
|
+ "\n",
|
|
|
+ "\n",
|
|
|
+ " def forward(self, in_idx, cache=None):\n",
|
|
|
+ " # Forward pass\n",
|
|
|
+ " tok_embeds = self.tok_emb(in_idx)\n",
|
|
|
+ " x = tok_embeds\n",
|
|
|
+ "\n",
|
|
|
+ " num_tokens = x.shape[1]\n",
|
|
|
+ " if cache is not None:\n",
|
|
|
+ " pos_start = self.current_pos\n",
|
|
|
+ " pos_end = pos_start + num_tokens\n",
|
|
|
+ " self.current_pos = pos_end\n",
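+ " # Slice the causal mask so the new tokens attend to all cached positions plus themselves\n",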
|
|
|
+ " mask = torch.triu(\n",
|
|
|
+ " torch.ones(pos_end, pos_end, device=x.device, dtype=torch.bool), diagonal=1\n",
|
|
|
+ " )[pos_start:pos_end, :pos_end]\n",
|
|
|
+ " else:\n",
|
|
|
+ " pos_start = 0 # Not strictly necessary but helps torch.compile\n",
|
|
|
+ " mask = torch.triu(\n",
|
|
|
+ " torch.ones(num_tokens, num_tokens, device=x.device, dtype=torch.bool), diagonal=1\n",
|
|
|
+ " )\n",
|
|
|
+ " # Shape (1, 1, num_tokens, num_tokens) to broadcast across batch and heads\n",
|
|
|
+ " mask = mask[None, None, :, :]\n",
|
|
|
+ "\n",
|
|
|
+ " next_cache = []\n",
|
|
|
+ " for i, block in enumerate(self.trf_blocks):\n",
|
|
|
+ " blk_cache = cache.get(i) if cache else None\n",
|
|
|
+ " x, new_blk_cache = block(x, mask, self.cos, self.sin,\n",
|
|
|
+ " start_pos=pos_start,\n",
|
|
|
+ " cache=blk_cache)\n",
|
|
|
+ " if cache is not None:\n",
|
|
|
+ " cache.update(i, new_blk_cache)\n",
|
|
|
+ " next_cache.append(new_blk_cache)\n",
|
|
|
+ "\n",
|
|
|
+ " x = self.final_norm(x)\n",
|
|
|
+ " logits = self.out_head(x.to(self.cfg[\"dtype\"]))\n",
|
|
|
+ " return logits\n",
|
|
|
+ "\n",
|
|
|
+ " def reset_kv_cache(self):\n",
|
|
|
+ " self.current_pos = 0"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 9,
|
|
|
+ "id": "bc04d120",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "class KVCache:\n",
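+ " # Minimal key/value cache: stores one (keys, values) tuple per transformer block\n",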
|
|
|
+ " def __init__(self, n_layers):\n",
|
|
|
+ " self.cache = [None] * n_layers\n",
|
|
|
+ "\n",
|
|
|
+ " def get(self, layer_idx):\n",
|
|
|
+ " return self.cache[layer_idx]\n",
|
|
|
+ "\n",
|
|
|
+ " def update(self, layer_idx, value):\n",
|
|
|
+ " self.cache[layer_idx] = value\n",
|
|
|
+ "\n",
|
|
|
+ " def get_all(self):\n",
|
|
|
+ " return self.cache\n",
|
|
|
+ "\n",
|
|
|
+ " def reset(self):\n",
|
|
|
+ " for i in range(len(self.cache)):\n",
|
|
|
+ " self.cache[i] = None"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "markdown",
|
|
|
+ "id": "be2d201f-74ad-4d63-ab9c-601b00674a48",
|
|
|
+ "metadata": {
|
|
|
+ "id": "be2d201f-74ad-4d63-ab9c-601b00674a48"
|
|
|
+ },
|
|
|
+ "source": [
|
|
|
+ " \n",
|
|
|
+ "# 2. Initialize model"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 10,
|
|
|
+ "id": "caa142fa-b375-4e78-b392-2072ced666f3",
|
|
|
+ "metadata": {
|
|
|
+ "id": "caa142fa-b375-4e78-b392-2072ced666f3"
|
|
|
+ },
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "# Same config for\n",
|
|
|
+ "\n",
|
|
|
+ "# https://huggingface.co/Qwen/Qwen3-Coder-30B-A3B-Instruct (Qwen3 Coder Flash)\n",
|
|
|
+ "# https://huggingface.co/Qwen/Qwen3-30B-A3B-Thinking-2507\n",
|
|
|
+ "# https://huggingface.co/Qwen/Qwen3-235B-A22B-Instruct-2507\n",
|
|
|
+ "# https://huggingface.co/Qwen/Qwen3-30B-A3B (original Instruct/Thinking hybrid model)\n",
|
|
|
+ "\n",
|
|
|
+ "QWEN3_CONFIG = {\n",
|
|
|
+ " \"vocab_size\": 151_936,\n",
|
|
|
+ " \"context_length\": 262_144,\n",
|
|
|
+ " \"emb_dim\": 2048,\n",
|
|
|
+ " \"n_heads\": 32,\n",
|
|
|
+ " \"n_layers\": 48,\n",
|
|
|
+ " \"head_dim\": 128,\n",
|
|
|
+ " \"qk_norm\": True,\n",
|
|
|
+ " \"n_kv_groups\": 4,\n",
|
|
|
+ " \"rope_base\": 10_000_000.0,\n",
|
|
|
+ " \"dtype\": torch.bfloat16,\n",
|
|
|
+ " \"num_experts\": 128,\n",
|
|
|
+ " \"num_experts_per_tok\": 8,\n",
|
|
|
+ " \"moe_intermediate_size\": 768,\n",
|
|
|
+ "}"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 11,
|
|
|
+ "id": "313effd0",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [
|
|
|
+ {
|
|
|
+ "name": "stdout",
|
|
|
+ "output_type": "stream",
|
|
|
+ "text": [
|
|
|
+ "cuda\n"
|
|
|
+ ]
|
|
|
+ }
|
|
|
+ ],
|
|
|
+ "source": [
|
|
|
+ "if torch.cuda.is_available():\n",
|
|
|
+ " device = torch.device(\"cuda\")\n",
|
|
|
+ "elif torch.backends.mps.is_available():\n",
|
|
|
+ " device = torch.device(\"mps\")\n",
|
|
|
+ "else:\n",
|
|
|
+ " device = torch.device(\"cpu\")\n",
|
|
|
+ "\n",
|
|
|
+ "print(device)"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 12,
|
|
|
+ "id": "156253fe-aacd-4da2-8f13-705f05c4b11e",
|
|
|
+ "metadata": {
|
|
|
+ "id": "156253fe-aacd-4da2-8f13-705f05c4b11e"
|
|
|
+ },
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "torch.manual_seed(123)\n",
|
|
|
+ "\n",
|
|
|
+ "with device:\n",
|
|
|
+ " model = Qwen3Model(QWEN3_CONFIG)\n",
|
|
|
+ "\n",
|
|
|
+ "#model.to(device)"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "markdown",
|
|
|
+ "id": "90aca91d-4bee-45ce-993a-4ec5393abe2b",
|
|
|
+ "metadata": {},
|
|
|
+ "source": [
|
|
|
+ "- A quick check that the forward pass works before continuing (NaN values are expected for now since the expert layers were instantiated on the \"meta\" device to save memory):"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 13,
|
|
|
+ "id": "adf0a6b7-b688-42c9-966e-c223d34db99f",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [
|
|
|
+ {
|
|
|
+ "data": {
|
|
|
+ "text/plain": [
|
|
|
+ "tensor([[[nan, nan, nan, ..., nan, nan, nan],\n",
|
|
|
+ " [nan, nan, nan, ..., nan, nan, nan],\n",
|
|
|
+ " [nan, nan, nan, ..., nan, nan, nan]]], device='cuda:0',\n",
|
|
|
+ " dtype=torch.bfloat16, grad_fn=<UnsafeViewBackward0>)"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ "execution_count": 13,
|
|
|
+ "metadata": {},
|
|
|
+ "output_type": "execute_result"
|
|
|
+ }
|
|
|
+ ],
|
|
|
+ "source": [
|
|
|
+ "model(torch.tensor([1, 2, 3]).unsqueeze(0).to(device))"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": null,
|
|
|
+ "id": "364e76ca-52f8-4fa5-af37-c4069f9694bc",
|
|
|
+ "metadata": {
|
|
|
+ "colab": {
|
|
|
+ "base_uri": "https://localhost:8080/"
|
|
|
+ },
|
|
|
+ "id": "364e76ca-52f8-4fa5-af37-c4069f9694bc",
|
|
|
+ "outputId": "00d7e983-262e-4c65-f322-f4d999311988"
|
|
|
+ },
|
|
|
+ "outputs": [
|
|
|
+ {
|
|
|
+ "name": "stdout",
|
|
|
+ "output_type": "stream",
|
|
|
+ "text": [
|
|
|
+ "Total number of parameters: 30,532,122,624\n",
|
|
|
+ "\n",
|
|
|
+ "Total number of unique parameters: 30,220,957,696\n"
|
|
|
+ ]
|
|
|
+ }
|
|
|
+ ],
|
|
|
+ "source": [
|
|
|
+ "total_params = sum(p.numel() for p in model.parameters())\n",
|
|
|
+ "print(f\"Total number of parameters: {total_params:,}\")\n",
|
|
|
+ "\n",
|
|
|
+ "# Account for weight tying\n",
|
|
|
+ "total_params_normalized = total_params - model.tok_emb.weight.numel()\n",
|
|
|
+ "print(f\"\\nTotal number of unique parameters: {total_params_normalized:,}\")"
|
|
|
+ ]
|
|
|
+ },
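+ {
+ "cell_type": "markdown",
+ "id": "active-params-note",
+ "metadata": {},
+ "source": [
+ "- As a rough, optional sanity check (a sketch that is not part of the book code), we can also estimate how many parameters are *active* per token: in each MoE layer, only `num_experts_per_tok` of the `num_experts` expert feed-forward networks are used for a given token, which is what the \"A3B\" (roughly 3 billion active parameters) in the model name refers to; the cell below counts the expert parameters (this also works while they are on the \"meta\" device) and scales them by the routing ratio"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "active-params-estimate",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Rough estimate of the number of parameters active per token (a sketch, not from the book)\n",
+ "expert_params = sum(\n",
+ " p.numel()\n",
+ " for block in model.trf_blocks\n",
+ " for expert_layer in (block.ff.fc1, block.ff.fc2, block.ff.fc3)\n",
+ " for p in expert_layer.parameters()\n",
+ ")\n",
+ "\n",
+ "# Non-expert parameters (embeddings, attention, norms, routers, output head) are always active\n",
+ "other_params = total_params - expert_params\n",
+ "\n",
+ "# Only `num_experts_per_tok` of the `num_experts` expert FFNs are used for each token\n",
+ "active_ratio = QWEN3_CONFIG[\"num_experts_per_tok\"] / QWEN3_CONFIG[\"num_experts\"]\n",
+ "active_params = other_params + int(expert_params * active_ratio)\n",
+ "print(f\"Approximate number of active parameters per token: {active_params:,}\")"
+ ]
+ },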
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 15,
|
|
|
+ "id": "fd5efb03-5a07-46e8-8607-93ed47549d2b",
|
|
|
+ "metadata": {
|
|
|
+ "colab": {
|
|
|
+ "base_uri": "https://localhost:8080/"
|
|
|
+ },
|
|
|
+ "id": "fd5efb03-5a07-46e8-8607-93ed47549d2b",
|
|
|
+ "outputId": "65c1a95e-b502-4150-9e2e-da619d9053d5"
|
|
|
+ },
|
|
|
+ "outputs": [
|
|
|
+ {
|
|
|
+ "name": "stdout",
|
|
|
+ "output_type": "stream",
|
|
|
+ "text": [
|
|
|
+ "float32 (PyTorch default): 227.73 GB\n",
|
|
|
+ "bfloat16: 113.87 GB\n"
|
|
|
+ ]
|
|
|
+ }
|
|
|
+ ],
|
|
|
+ "source": [
|
|
|
+ "def model_memory_size(model, input_dtype=torch.float32):\n",
|
|
|
+ " total_params = 0\n",
|
|
|
+ " total_grads = 0\n",
|
|
|
+ " for param in model.parameters():\n",
|
|
|
+ " # Calculate total number of elements per parameter\n",
|
|
|
+ " param_size = param.numel()\n",
|
|
|
+ " total_params += param_size\n",
|
|
|
+ " # Check if gradients are stored for this parameter\n",
|
|
|
+ " if param.requires_grad:\n",
|
|
|
+ " total_grads += param_size\n",
|
|
|
+ "\n",
|
|
|
+ " # Calculate buffer size (non-parameters that require memory)\n",
|
|
|
+ " total_buffers = sum(buf.numel() for buf in model.buffers())\n",
|
|
|
+ "\n",
|
|
|
+ " # Size in bytes = (Number of elements) * (Size of each element in bytes)\n",
|
|
|
+ " # We assume parameters and gradients are stored in the same type as input dtype\n",
|
|
|
+ " element_size = torch.tensor(0, dtype=input_dtype).element_size()\n",
|
|
|
+ " total_memory_bytes = (total_params + total_grads + total_buffers) * element_size\n",
|
|
|
+ "\n",
|
|
|
+ " # Convert bytes to gigabytes\n",
|
|
|
+ " total_memory_gb = total_memory_bytes / (1024**3)\n",
|
|
|
+ "\n",
|
|
|
+ " return total_memory_gb\n",
|
|
|
+ "\n",
|
|
|
+ "print(f\"float32 (PyTorch default): {model_memory_size(model, input_dtype=torch.float32):.2f} GB\")\n",
|
|
|
+ "print(f\"bfloat16: {model_memory_size(model, input_dtype=torch.bfloat16):.2f} GB\")"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "markdown",
|
|
|
+ "id": "4686eeb7-281f-4c5c-b37a-ed21d0a10427",
|
|
|
+ "metadata": {},
|
|
|
+ "source": [
|
|
|
+ "- Don't be concerned; the model runs fine on an A100 card with 80 GB of VRAM because some of the layers are offloaded to CPU RAM"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "markdown",
|
|
|
+ "id": "c172f89f-d301-439f-b809-46169e5f5945",
|
|
|
+ "metadata": {
|
|
|
+ "id": "c172f89f-d301-439f-b809-46169e5f5945"
|
|
|
+ },
|
|
|
+ "source": [
|
|
|
+ " \n",
|
|
|
+ "# 3. Load pretrained weights"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 16,
|
|
|
+ "id": "75166128-5899-4995-9b88-9672e135650e",
|
|
|
+ "metadata": {
|
|
|
+ "id": "75166128-5899-4995-9b88-9672e135650e"
|
|
|
+ },
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "def load_weights_into_qwen(model, param_config, params):\n",
|
|
|
+ " def assign(left, right, tensor_name=\"unknown\"):\n",
|
|
|
+ " if left.shape != right.shape:\n",
|
|
|
+ " raise ValueError(f\"Shape mismatch in tensor '{tensor_name}'. Left: {left.shape}, Right: {right.shape}\")\n",
|
|
|
+ " return torch.nn.Parameter(right.clone().detach() if isinstance(right, torch.Tensor) else torch.tensor(right))\n",
|
|
|
+ "\n",
|
|
|
+ " model.tok_emb.weight = assign(model.tok_emb.weight, params[\"model.embed_tokens.weight\"], \"model.embed_tokens.weight\")\n",
|
|
|
+ "\n",
|
|
|
+ " for l in range(param_config[\"n_layers\"]):\n",
|
|
|
+ " block = model.trf_blocks[l]\n",
|
|
|
+ " att = block.att\n",
|
|
|
+ "\n",
|
|
|
+ " # Q, K, V projections\n",
|
|
|
+ " att.W_query.weight = assign(\n",
|
|
|
+ " att.W_query.weight,\n",
|
|
|
+ " params[f\"model.layers.{l}.self_attn.q_proj.weight\"],\n",
|
|
|
+ " f\"model.layers.{l}.self_attn.q_proj.weight\"\n",
|
|
|
+ " )\n",
|
|
|
+ " att.W_key.weight = assign(\n",
|
|
|
+ " att.W_key.weight,\n",
|
|
|
+ " params[f\"model.layers.{l}.self_attn.k_proj.weight\"],\n",
|
|
|
+ " f\"model.layers.{l}.self_attn.k_proj.weight\"\n",
|
|
|
+ " )\n",
|
|
|
+ " att.W_value.weight = assign(\n",
|
|
|
+ " att.W_value.weight,\n",
|
|
|
+ " params[f\"model.layers.{l}.self_attn.v_proj.weight\"],\n",
|
|
|
+ " f\"model.layers.{l}.self_attn.v_proj.weight\"\n",
|
|
|
+ " )\n",
|
|
|
+ "\n",
|
|
|
+ " # Output projection\n",
|
|
|
+ " att.out_proj.weight = assign(\n",
|
|
|
+ " att.out_proj.weight,\n",
|
|
|
+ " params[f\"model.layers.{l}.self_attn.o_proj.weight\"],\n",
|
|
|
+ " f\"model.layers.{l}.self_attn.o_proj.weight\"\n",
|
|
|
+ " )\n",
|
|
|
+ "\n",
|
|
|
+ " # QK norms\n",
|
|
|
+ " if hasattr(att, \"q_norm\") and att.q_norm is not None:\n",
|
|
|
+ " att.q_norm.scale = assign(\n",
|
|
|
+ " att.q_norm.scale,\n",
|
|
|
+ " params[f\"model.layers.{l}.self_attn.q_norm.weight\"],\n",
|
|
|
+ " f\"model.layers.{l}.self_attn.q_norm.weight\"\n",
|
|
|
+ " )\n",
|
|
|
+ " if hasattr(att, \"k_norm\") and att.k_norm is not None:\n",
|
|
|
+ " att.k_norm.scale = assign(\n",
|
|
|
+ " att.k_norm.scale,\n",
|
|
|
+ " params[f\"model.layers.{l}.self_attn.k_norm.weight\"],\n",
|
|
|
+ " f\"model.layers.{l}.self_attn.k_norm.weight\"\n",
|
|
|
+ " )\n",
|
|
|
+ "\n",
|
|
|
+ " # Attention layernorm\n",
|
|
|
+ " block.norm1.scale = assign(\n",
|
|
|
+ " block.norm1.scale,\n",
|
|
|
+ " params[f\"model.layers.{l}.input_layernorm.weight\"],\n",
|
|
|
+ " f\"model.layers.{l}.input_layernorm.weight\"\n",
|
|
|
+ " )\n",
|
|
|
+ "\n",
|
|
|
+ " # Feedforward weights\n",
|
|
|
+ " if \"num_experts\" in param_config:\n",
|
|
|
+ " # Load router (gating) weights\n",
|
|
|
+ " block.ff.gate.weight = assign(\n",
|
|
|
+ " block.ff.gate.weight,\n",
|
|
|
+ " params[f\"model.layers.{l}.mlp.gate.weight\"],\n",
|
|
|
+ " f\"model.layers.{l}.mlp.gate.weight\"\n",
|
|
|
+ " )\n",
|
|
|
+ " # Load expert weights\n",
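+ " # fc1, fc2, and fc3 correspond to gate_proj, up_proj, and down_proj in the Hugging Face checkpoint\n",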
|
|
|
+ " for e in range(param_config[\"num_experts\"]):\n",
|
|
|
+ " prefix = f\"model.layers.{l}.mlp.experts.{e}\"\n",
|
|
|
+ " block.ff.fc1[e].weight = assign(\n",
|
|
|
+ " block.ff.fc1[e].weight,\n",
|
|
|
+ " params[f\"{prefix}.gate_proj.weight\"],\n",
|
|
|
+ " f\"{prefix}.gate_proj.weight\"\n",
|
|
|
+ " )\n",
|
|
|
+ " block.ff.fc2[e].weight = assign(\n",
|
|
|
+ " block.ff.fc2[e].weight,\n",
|
|
|
+ " params[f\"{prefix}.up_proj.weight\"],\n",
|
|
|
+ " f\"{prefix}.up_proj.weight\"\n",
|
|
|
+ " )\n",
|
|
|
+ " block.ff.fc3[e].weight = assign(\n",
|
|
|
+ " block.ff.fc3[e].weight,\n",
|
|
|
+ " params[f\"{prefix}.down_proj.weight\"],\n",
|
|
|
+ " f\"{prefix}.down_proj.weight\"\n",
|
|
|
+ " )\n",
|
|
|
+ " # After assigning weights, move the expert layers from meta to CPU\n",
|
|
|
+ " block.ff.fc1[e] = block.ff.fc1[e].to(\"cpu\")\n",
|
|
|
+ " block.ff.fc2[e] = block.ff.fc2[e].to(\"cpu\")\n",
|
|
|
+ " block.ff.fc3[e] = block.ff.fc3[e].to(\"cpu\")\n",
|
|
|
+ "\n",
|
|
|
+ " else:\n",
|
|
|
+ " block.ff.fc1.weight = assign(\n",
|
|
|
+ " block.ff.fc1.weight,\n",
|
|
|
+ " params[f\"model.layers.{l}.mlp.gate_proj.weight\"],\n",
|
|
|
+ " f\"model.layers.{l}.mlp.gate_proj.weight\"\n",
|
|
|
+ " )\n",
|
|
|
+ " block.ff.fc2.weight = assign(\n",
|
|
|
+ " block.ff.fc2.weight,\n",
|
|
|
+ " params[f\"model.layers.{l}.mlp.up_proj.weight\"],\n",
|
|
|
+ " f\"model.layers.{l}.mlp.up_proj.weight\"\n",
|
|
|
+ " )\n",
|
|
|
+ " block.ff.fc3.weight = assign(\n",
|
|
|
+ " block.ff.fc3.weight,\n",
|
|
|
+ " params[f\"model.layers.{l}.mlp.down_proj.weight\"],\n",
|
|
|
+ " f\"model.layers.{l}.mlp.down_proj.weight\"\n",
|
|
|
+ " )\n",
|
|
|
+ "\n",
|
|
|
+ " block.norm2.scale = assign(\n",
|
|
|
+ " block.norm2.scale,\n",
|
|
|
+ " params[f\"model.layers.{l}.post_attention_layernorm.weight\"],\n",
|
|
|
+ " f\"model.layers.{l}.post_attention_layernorm.weight\"\n",
|
|
|
+ " )\n",
|
|
|
+ "\n",
|
|
|
+ " # Final normalization and output head\n",
|
|
|
+ " model.final_norm.scale = assign(model.final_norm.scale, params[\"model.norm.weight\"], \"model.norm.weight\")\n",
|
|
|
+ "\n",
|
|
|
+ " if \"lm_head.weight\" in params:\n",
|
|
|
+ " model.out_head.weight = assign(model.out_head.weight, params[\"lm_head.weight\"], \"lm_head.weight\")\n",
|
|
|
+ " else:\n",
|
|
|
+ " # Model uses weight tying, hence we reuse the embedding layer weights here\n",
|
|
|
+ " print(\"Model uses weight tying.\")\n",
|
|
|
+ " model.out_head.weight = assign(model.out_head.weight, params[\"model.embed_tokens.weight\"], \"model.embed_tokens.weight\")"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 17,
|
|
|
+ "id": "699cb1b8-a67d-49fb-80a6-0dad9d81f392",
|
|
|
+ "metadata": {
|
|
|
+ "colab": {
|
|
|
+ "base_uri": "https://localhost:8080/",
|
|
|
+ "height": 17,
|
|
|
+ "referenced_widgets": [
|
|
|
+ "9881b6995c3f49dc89e6992fd9ab660b",
|
|
|
+ "17a3174e65c54476b2e0d1faf8f011ca",
|
|
|
+ "1bbf2e62c0754d1593beb4105a7f1ac1",
|
|
|
+ "b82112e1dec645d98aa1c1ba64abcb61",
|
|
|
+ "271e2bd6a35e4a8b92de8697f7c0be5f",
|
|
|
+ "90a79523187446dfa692723b2e5833a7",
|
|
|
+ "431ffb83b8c14bf182f0430e07ea6154",
|
|
|
+ "a8f1b72a33dd4b548de23fbd95e0da18",
|
|
|
+ "25cc36132d384189acfbecc59483134b",
|
|
|
+ "bfd06423ad544218968648016e731a46",
|
|
|
+ "d029630b63ff44cf807ade428d2eb421"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ "id": "699cb1b8-a67d-49fb-80a6-0dad9d81f392",
|
|
|
+ "outputId": "55b2f28c-142f-4698-9d23-d27456d3ed6d"
|
|
|
+ },
|
|
|
+ "outputs": [
|
|
|
+ {
|
|
|
+ "data": {
|
|
|
+ "application/vnd.jupyter.widget-view+json": {
|
|
|
+ "model_id": "acdfb3a707444d7691bc8f1b053224b1",
|
|
|
+ "version_major": 2,
|
|
|
+ "version_minor": 0
|
|
|
+ },
|
|
|
+ "text/plain": [
|
|
|
+ "Fetching 27 files: 0%| | 0/27 [00:00<?, ?it/s]"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ "metadata": {},
|
|
|
+ "output_type": "display_data"
|
|
|
+ }
|
|
|
+ ],
|
|
|
+ "source": [
|
|
|
+ "import json\n",
|
|
|
+ "import os\n",
|
|
|
+ "from pathlib import Path\n",
|
|
|
+ "from safetensors.torch import load_file\n",
|
|
|
+ "from huggingface_hub import snapshot_download\n",
|
|
|
+ "\n",
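+ "# Pick the model to load; the last repo_id assignment below takes effect, so comment out the ones you don't want\n",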
|
|
|
+ "repo_id = \"Qwen/Qwen3-30B-A3B\" # Original Instruct/Thinking hybrid model\n",
|
|
|
+ "repo_id = \"Qwen/Qwen3-235B-A22B-Instruct-2507\" # New instruct model\n",
|
|
|
+ "repo_id = \"Qwen/Qwen3-30B-A3B-Thinking-2507\" # New thinking model\n",
|
|
|
+ "repo_id = \"Qwen/Qwen3-Coder-30B-A3B-Instruct\" # (Qwen3 Coder Flash)\n",
|
|
|
+ "\n",
|
|
|
+ "local_dir = Path(repo_id).parts[-1]\n",
|
|
|
+ "\n",
|
|
|
+ "repo_dir = snapshot_download(repo_id=repo_id, local_dir=local_dir)\n",
|
|
|
+ "index_path = os.path.join(repo_dir, \"model.safetensors.index.json\")\n",
|
|
|
+ "with open(index_path, \"r\") as f:\n",
|
|
|
+ " index = json.load(f)\n",
|
|
|
+ "\n",
|
|
|
+ "weights_dict = {}\n",
|
|
|
+ "for filename in set(index[\"weight_map\"].values()):\n",
|
|
|
+ " shard_path = os.path.join(repo_dir, filename)\n",
|
|
|
+ " shard = load_file(shard_path)\n",
|
|
|
+ " weights_dict.update(shard)\n",
|
|
|
+ "\n",
|
|
|
+ "load_weights_into_qwen(model, QWEN3_CONFIG, weights_dict)\n",
|
|
|
+ "model.to(device);"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "markdown",
|
|
|
+ "id": "6b345491-3510-4397-92d3-cd0a3fa3deee",
|
|
|
+ "metadata": {},
|
|
|
+ "source": [
|
|
|
+ " \n",
|
|
|
+ "# 4. Load tokenizer"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 18,
|
|
|
+ "id": "b68ab489-48e5-471e-a814-56cda2d60f81",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "import re\n",
|
|
|
+ "from tokenizers import Tokenizer\n",
|
|
|
+ "\n",
|
|
|
+ "\n",
|
|
|
+ "class Qwen3Tokenizer:\n",
|
|
|
+ " _SPECIALS = [\n",
|
|
|
+ " \"<|endoftext|>\",\n",
|
|
|
+ " \"<|im_start|>\", \"<|im_end|>\",\n",
|
|
|
+ " \"<|object_ref_start|>\", \"<|object_ref_end|>\",\n",
|
|
|
+ " \"<|box_start|>\", \"<|box_end|>\",\n",
|
|
|
+ " \"<|quad_start|>\", \"<|quad_end|>\",\n",
|
|
|
+ " \"<|vision_start|>\", \"<|vision_end|>\",\n",
|
|
|
+ " \"<|vision_pad|>\", \"<|image_pad|>\", \"<|video_pad|>\",\n",
|
|
|
+ " ]\n",
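+ " # regex that splits text around special tokens such as <|im_start|> so they are encoded as single token ids\n",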
|
|
|
+ " _SPLIT_RE = re.compile(r\"(<\\|[^>]+?\\|>)\")\n",
|
|
|
+ "\n",
|
|
|
+ " def __init__(self, tokenizer_file_path=\"tokenizer.json\", repo_id=None,\n",
|
|
|
+ " apply_chat_template=True, add_generation_prompt=False, add_thinking=False):\n",
|
|
|
+ "\n",
|
|
|
+ " self.apply_chat_template = apply_chat_template\n",
|
|
|
+ " self.add_generation_prompt = add_generation_prompt\n",
|
|
|
+ " self.add_thinking = add_thinking\n",
|
|
|
+ "\n",
|
|
|
+ " tok_file = Path(tokenizer_file_path)\n",
|
|
|
+ " self._tok = Tokenizer.from_file(str(tok_file))\n",
|
|
|
+ " self._special_to_id = {t: self._tok.token_to_id(t) for t in self._SPECIALS}\n",
|
|
|
+ "\n",
|
|
|
+ " self.pad_token_id = self._special_to_id.get(\"<|endoftext|>\")\n",
|
|
|
+ " self.eos_token_id = self.pad_token_id\n",
|
|
|
+ "\n",
|
|
|
+ " if repo_id and \"Base\" not in repo_id:\n",
|
|
|
+ " eos_token = \"<|im_end|>\"\n",
|
|
|
+ " else:\n",
|
|
|
+ " eos_token = \"<|endoftext|>\"\n",
|
|
|
+ " if eos_token in self._special_to_id:\n",
|
|
|
+ " self.eos_token_id = self._special_to_id[eos_token]\n",
|
|
|
+ "\n",
|
|
|
+ " def encode(self, text, chat_wrapped=None):\n",
|
|
|
+ " if chat_wrapped is None:\n",
|
|
|
+ " chat_wrapped = self.apply_chat_template\n",
|
|
|
+ "\n",
|
|
|
+ " stripped = text.strip()\n",
|
|
|
+ " if stripped in self._special_to_id and \"\\n\" not in stripped:\n",
|
|
|
+ " return [self._special_to_id[stripped]]\n",
|
|
|
+ "\n",
|
|
|
+ " if chat_wrapped:\n",
|
|
|
+ " text = self._wrap_chat(text)\n",
|
|
|
+ "\n",
|
|
|
+ " ids = []\n",
|
|
|
+ " for part in filter(None, self._SPLIT_RE.split(text)):\n",
|
|
|
+ " if part in self._special_to_id:\n",
|
|
|
+ " ids.append(self._special_to_id[part])\n",
|
|
|
+ " else:\n",
|
|
|
+ " ids.extend(self._tok.encode(part).ids)\n",
|
|
|
+ " return ids\n",
|
|
|
+ "\n",
|
|
|
+ " def decode(self, ids):\n",
|
|
|
+ " return self._tok.decode(ids, skip_special_tokens=False)\n",
|
|
|
+ "\n",
|
|
|
+ " def _wrap_chat(self, user_msg):\n",
|
|
|
+ " s = f\"<|im_start|>user\\n{user_msg}<|im_end|>\\n\"\n",
|
|
|
+ " if self.add_generation_prompt:\n",
|
|
|
+ " s += \"<|im_start|>assistant\"\n",
|
|
|
+ " if self.add_thinking:\n",
|
|
|
+ " s += \"\\n\"\n",
|
|
|
+ " else:\n",
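+ " # when thinking is disabled, append an empty <think></think> block so the model skips the thinking step\n",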
|
|
|
+ " s += \"\\n<think>\\n\\n</think>\\n\\n\"\n",
|
|
|
+ " return s"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 19,
|
|
|
+ "id": "7b6df8bc-7308-468e-93ce-2d5529ea7866",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "tokenizer_file_path = f\"{Path(repo_id).parts[-1]}/tokenizer.json\"\n",
|
|
|
+ "\n",
|
|
|
+ "tokenizer = Qwen3Tokenizer(\n",
|
|
|
+ " tokenizer_file_path=tokenizer_file_path,\n",
|
|
|
+ " repo_id=repo_id,\n",
|
|
|
+ " add_generation_prompt=True,\n",
|
|
|
+ " add_thinking=True\n",
|
|
|
+ ")"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 21,
|
|
|
+ "id": "1946b534-e3af-431a-a222-391a60bfa892",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [
|
|
|
+ {
|
|
|
+ "data": {
|
|
|
+ "text/plain": [
|
|
|
+ "'<|im_start|>user\\nImplement a binary search function in Python<|im_end|>\\n<|im_start|>assistant\\n'"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ "execution_count": 21,
|
|
|
+ "metadata": {},
|
|
|
+ "output_type": "execute_result"
|
|
|
+ }
|
|
|
+ ],
|
|
|
+ "source": [
|
|
|
+ "# prompt = \"Give me a short introduction to large language models.\"\n",
|
|
|
+ "prompt = \"Implement a binary search function in Python\"\n",
|
|
|
+ "\n",
|
|
|
+ "\n",
|
|
|
+ "input_token_ids = tokenizer.encode(prompt)\n",
|
|
|
+ "text = tokenizer.decode(input_token_ids)\n",
|
|
|
+ "text"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "markdown",
|
|
|
+ "id": "57d07df1-4401-4792-b549-7c4cc5632323",
|
|
|
+ "metadata": {
|
|
|
+ "id": "57d07df1-4401-4792-b549-7c4cc5632323"
|
|
|
+ },
|
|
|
+ "source": [
|
|
|
+ " \n",
|
|
|
+ "# 5. Generate text"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 22,
|
|
|
+ "id": "60b9fc72",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "def generate_text_basic_stream(model, token_ids, max_new_tokens, eos_token_id=None, context_size=None):\n",
|
|
|
+ " model.eval()\n",
|
|
|
+ "\n",
|
|
|
+ " with torch.no_grad():\n",
|
|
|
+ " cache = KVCache(n_layers=model.cfg[\"n_layers\"])\n",
|
|
|
+ " model.reset_kv_cache()\n",
|
|
|
+ "\n",
|
|
|
+ " # Prime the cache with the initial context\n",
|
|
|
+ " logits = model(token_ids, cache=cache)\n",
|
|
|
+ "\n",
|
|
|
+ " for _ in range(max_new_tokens):\n",
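+ " # Greedy decoding: always pick the most probable next token\n",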
|
|
|
+ " next_token = torch.argmax(logits[:, -1], dim=-1, keepdim=True)\n",
|
|
|
+ "\n",
|
|
|
+ " if eos_token_id is not None and torch.all(next_token == eos_token_id):\n",
|
|
|
+ " break\n",
|
|
|
+ "\n",
|
|
|
+ " yield next_token\n",
|
|
|
+ "\n",
|
|
|
+ " token_ids = torch.cat([token_ids, next_token], dim=1)\n",
|
|
|
+ "\n",
|
|
|
+ " # Feed only the new token to the model; cache handles history\n",
|
|
|
+ " logits = model(next_token, cache=cache)"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 23,
|
|
|
+ "id": "a5b30753",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [
|
|
|
+ {
|
|
|
+ "name": "stdout",
|
|
|
+ "output_type": "stream",
|
|
|
+ "text": [
|
|
|
+ "Here's a comprehensive implementation of binary search in Python with both iterative and recursive approaches:\n",
|
|
|
+ "\n",
|
|
|
+ "## Iterative Binary Search\n",
|
|
|
+ "\n",
|
|
|
+ "```python\n",
|
|
|
+ "def binary_search(arr, target):\n",
|
|
|
+ " \"\"\"\n",
|
|
|
+ " Iterative binary search implementation\n",
|
|
|
+ " \n",
|
|
|
+ " Args:\n",
|
|
|
+ " arr: Sorted list of elements\n",
|
|
|
+ " target: Element to search for\n",
|
|
|
+ " \n",
|
|
|
+ " Returns:\n",
|
|
|
+ " int: Index of target if found, -1 if not found\n",
|
|
|
+ " \n",
|
|
|
+ " Time Complexity: O(log n)\n",
|
|
|
+ " Space Complexity: O(1)\n",
|
|
|
+ " \"\"\"\n",
|
|
|
+ " left = 0\n",
|
|
|
+ " right = len(arr) - 1\n",
|
|
|
+ " \n",
|
|
|
+ " while left <= right:\n",
|
|
|
+ " # Calculate middle index (avoiding potential overflow)\n",
|
|
|
+ " mid = left + (right - left) // 2\n",
|
|
|
+ " \n",
|
|
|
+ " if arr[mid] == target:\n",
|
|
|
+ " return mid\n",
|
|
|
+ " elif arr[mid] < target:\n",
|
|
|
+ " left = mid + 1\n",
|
|
|
+ " else:\n",
|
|
|
+ " right = mid - 1\n",
|
|
|
+ " \n",
|
|
|
+ " return -1 # Target not found\n",
|
|
|
+ "```\n",
|
|
|
+ "\n",
|
|
|
+ "## Recursive Binary Search\n",
|
|
|
+ "\n"
|
|
|
+ ]
|
|
|
+ }
|
|
|
+ ],
|
|
|
+ "source": [
|
|
|
+ "input_token_ids_tensor = torch.tensor(input_token_ids, device=device).unsqueeze(0)\n",
|
|
|
+ "\n",
|
|
|
+ "\n",
|
|
|
+ "for token in generate_text_basic_stream(\n",
|
|
|
+ " model=model,\n",
|
|
|
+ " token_ids=input_token_ids_tensor,\n",
|
|
|
+ " max_new_tokens=200,\n",
|
|
|
+ " eos_token_id=tokenizer.eos_token_id\n",
|
|
|
+ "):\n",
|
|
|
+ " token_id = token.squeeze(0).tolist()\n",
|
|
|
+ " print(\n",
|
|
|
+ " tokenizer.decode(token_id),\n",
|
|
|
+ " end=\"\",\n",
|
|
|
+ " flush=True\n",
|
|
|
+ " )"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "markdown",
|
|
|
+ "id": "549324d6-5c71-4147-ae21-2e67675faa3d",
|
|
|
+ "metadata": {
|
|
|
+ "id": "549324d6-5c71-4147-ae21-2e67675faa3d"
|
|
|
+ },
|
|
|
+ "source": [
|
|
|
+ " \n",
|
|
|
+ "# What's next?"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "markdown",
|
|
|
+ "id": "e6edaaae-2de1-406c-8ffa-897cdfa3808c",
|
|
|
+ "metadata": {
|
|
|
+ "id": "e6edaaae-2de1-406c-8ffa-897cdfa3808c"
|
|
|
+ },
|
|
|
+ "source": [
|
|
|
+ "- Check out the [README.md](./README.md) to learn how to use this model via the `llms_from_scratch` package\n",
|
|
|
+ "- For those interested in a comprehensive guide on building a large language model from scratch and gaining a deeper understanding of its mechanics, you might like my [Build a Large Language Model (From Scratch)](http://mng.bz/orYv) book\n",
|
|
|
+ "\n",
|
|
|
+ "<a href=\"http://mng.bz/orYv\"><img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/cover-small.webp\" width=\"100px\"></a>"
|
|
|
+ ]
|
|
|
+ }
|
|
|
+ ],
|
|
|
+ "metadata": {
|
|
|
+ "accelerator": "GPU",
|
|
|
+ "colab": {
|
|
|
+ "gpuType": "A100",
|
|
|
+ "provenance": []
|
|
|
+ },
|
|
|
+ "kernelspec": {
|
|
|
+ "display_name": "Python 3 (ipykernel)",
|
|
|
+ "language": "python",
|
|
|
+ "name": "python3"
|
|
|
+ },
|
|
|
+ "language_info": {
|
|
|
+ "codemirror_mode": {
|
|
|
+ "name": "ipython",
|
|
|
+ "version": 3
|
|
|
+ },
|
|
|
+ "file_extension": ".py",
|
|
|
+ "mimetype": "text/x-python",
|
|
|
+ "name": "python",
|
|
|
+ "nbconvert_exporter": "python",
|
|
|
+ "pygments_lexer": "ipython3",
|
|
|
+ "version": "3.10.16"
|
|
|
+ }
|
|
|
+ },
|
|
|
+ "nbformat": 4,
|
|
|
+ "nbformat_minor": 5
|
|
|
+}
|