Explorar el Código

Interleaved Q and K for RoPE in Llama 2 (#750)

Sebastian Raschka hace 4 meses
padre
commit
b12dbf6c68
Se han modificado 1 ficheros con 33 adiciones y 39 borrados
  1. 33 39
      ch05/07_gpt_to_llama/converting-gpt-to-llama2.ipynb

+ 33 - 39
ch05/07_gpt_to_llama/converting-gpt-to-llama2.ipynb

@@ -83,7 +83,7 @@
      "name": "stdout",
      "name": "stdout",
      "output_type": "stream",
      "output_type": "stream",
      "text": [
      "text": [
-      "huggingface_hub version: 0.33.0\n",
+      "huggingface_hub version: 0.33.2\n",
       "sentencepiece version: 0.2.0\n",
       "sentencepiece version: 0.2.0\n",
       "torch version: 2.6.0\n"
       "torch version: 2.6.0\n"
      ]
      ]
@@ -1306,22 +1306,7 @@
     "id": "5fa9c06c-7a53-4b4d-9ce4-acc027322ee4",
     "id": "5fa9c06c-7a53-4b4d-9ce4-acc027322ee4",
     "outputId": "0d8942cc-e5e2-4e77-ec41-1ac7bec7d94f"
     "outputId": "0d8942cc-e5e2-4e77-ec41-1ac7bec7d94f"
    },
    },
-   "outputs": [
-    {
-     "data": {
-      "application/vnd.jupyter.widget-view+json": {
-       "model_id": "66e777955e8748df878f118f07f38dab",
-       "version_major": 2,
-       "version_minor": 0
-      },
-      "text/plain": [
-       "consolidated.00.pth:   0%|          | 0.00/13.5G [00:00<?, ?B/s]"
-      ]
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    }
-   ],
+   "outputs": [],
    "source": [
    "source": [
     "weights_file = hf_hub_download(\n",
     "weights_file = hf_hub_download(\n",
     "   repo_id=\"meta-llama/Llama-2-7b\",\n",
     "   repo_id=\"meta-llama/Llama-2-7b\",\n",
@@ -1405,7 +1390,7 @@
   },
   },
   {
   {
    "cell_type": "code",
    "cell_type": "code",
-   "execution_count": 29,
+   "execution_count": 32,
    "id": "3820e2a7-4f26-41bc-953b-f3879b0aff65",
    "id": "3820e2a7-4f26-41bc-953b-f3879b0aff65",
    "metadata": {
    "metadata": {
     "id": "3820e2a7-4f26-41bc-953b-f3879b0aff65"
     "id": "3820e2a7-4f26-41bc-953b-f3879b0aff65"
@@ -1422,19 +1407,40 @@
     "        return torch.nn.Parameter(torch.tensor(right))\n",
     "        return torch.nn.Parameter(torch.tensor(right))\n",
     "\n",
     "\n",
     "\n",
     "\n",
+    "def permute(w: torch.Tensor, n_heads, out_dim, in_dim):\n",
+    "    return (w.view(n_heads, out_dim // n_heads // 2, 2, in_dim)\n",
+    "             .transpose(1, 2)          # put axis 2 next to heads\n",
+    "             .reshape(out_dim, in_dim))\n",
+    "\n",
+    "\n",
     "def load_weights_into_llama(model, param_config, params):\n",
     "def load_weights_into_llama(model, param_config, params):\n",
+    "\n",
+    "    cfg = LLAMA2_CONFIG_7B\n",
+    "    \n",
     "    model.tok_emb.weight = assign(model.tok_emb.weight, params[\"tok_embeddings.weight\"])\n",
     "    model.tok_emb.weight = assign(model.tok_emb.weight, params[\"tok_embeddings.weight\"])\n",
     "\n",
     "\n",
     "    for l in range(param_config[\"n_layers\"]):\n",
     "    for l in range(param_config[\"n_layers\"]):\n",
     "\n",
     "\n",
-    "        # Load attention weights\n",
+    "        # The original Meta/Llama checkpoints store Q and K so that the two numbers \n",
+    "        # that form one complex RoPE pair sit next to each other inside the head dimension (\"sliced\" layout).\n",
+    "        # Our RoPE implementation, similar to the one in Hugging Face, expects an interleaved layout\n",
+    "        # For example, with n_heads=2 and head_dim = 8\n",
+    "        #                         ┌── pair 0 ──┐      ┌── pair 1 ──┐\n",
+    "        # Meta (sliced):    [ h0:  r0 r1 r2 r3,   h1:  r0 r1 r2 r3  ]\n",
+    "        # Ours & HF (interleaved):  [ h0: r0 r0 r1 r1 r2 r2 r3 r3  , h1: ... ]\n",
+    "        # For more information, please see the discussion in the PR: https://github.com/rasbt/LLMs-from-scratch/pull/747 \n",
+    "        \n",
+    "        # So, below, for q_raw and k_raw, we must re‑order the checkpoint weights using the slices_to_interleave helper\n",
+    "\n",
+    "        q_raw = params[f\"layers.{l}.attention.wq.weight\"]\n",
     "        model.trf_blocks[l].att.W_query.weight = assign(\n",
     "        model.trf_blocks[l].att.W_query.weight = assign(\n",
     "            model.trf_blocks[l].att.W_query.weight,\n",
     "            model.trf_blocks[l].att.W_query.weight,\n",
-    "            params[f\"layers.{l}.attention.wq.weight\"]\n",
+    "            permute(q_raw, cfg[\"n_heads\"], cfg[\"emb_dim\"], cfg[\"emb_dim\"])\n",
     "        )\n",
     "        )\n",
+    "        k_raw = params[f\"layers.{l}.attention.wk.weight\"]\n",
     "        model.trf_blocks[l].att.W_key.weight = assign(\n",
     "        model.trf_blocks[l].att.W_key.weight = assign(\n",
     "            model.trf_blocks[l].att.W_key.weight,\n",
     "            model.trf_blocks[l].att.W_key.weight,\n",
-    "            params[f\"layers.{l}.attention.wk.weight\"]\n",
+    "            permute(k_raw, cfg[\"n_heads\"], cfg[\"emb_dim\"], cfg[\"emb_dim\"])\n",
     "        )\n",
     "        )\n",
     "        model.trf_blocks[l].att.W_value.weight = assign(\n",
     "        model.trf_blocks[l].att.W_value.weight = assign(\n",
     "            model.trf_blocks[l].att.W_value.weight,\n",
     "            model.trf_blocks[l].att.W_value.weight,\n",
@@ -1489,7 +1495,7 @@
   },
   },
   {
   {
    "cell_type": "code",
    "cell_type": "code",
-   "execution_count": 30,
+   "execution_count": 33,
    "id": "240987e8-a023-462e-9376-9edfb27559ec",
    "id": "240987e8-a023-462e-9376-9edfb27559ec",
    "metadata": {
    "metadata": {
     "colab": {
     "colab": {
@@ -1504,7 +1510,7 @@
      "output_type": "stream",
      "output_type": "stream",
      "text": [
      "text": [
       "Output text:\n",
       "Output text:\n",
-      " Every effort has been made to ensure that the information contained in this website is accurate and up to date and correct at the time of publication\n"
+      " Every effort has been made to ensure the accuracy of the information contained in this website. However, the information contained in this website is not\n"
      ]
      ]
     }
     }
    ],
    ],
@@ -1544,7 +1550,7 @@
   },
   },
   {
   {
    "cell_type": "code",
    "cell_type": "code",
-   "execution_count": 34,
+   "execution_count": 35,
    "id": "nbvAV7vaz6yc",
    "id": "nbvAV7vaz6yc",
    "metadata": {
    "metadata": {
     "colab": {
     "colab": {
@@ -1568,27 +1574,14 @@
     "outputId": "724f5508-d976-4e31-b3d7-95fa65b2c1e8"
     "outputId": "724f5508-d976-4e31-b3d7-95fa65b2c1e8"
    },
    },
    "outputs": [
    "outputs": [
-    {
-     "data": {
-      "application/vnd.jupyter.widget-view+json": {
-       "model_id": "3b2448a60f5f4ba5b2c686037c8ecd78",
-       "version_major": 2,
-       "version_minor": 0
-      },
-      "text/plain": [
-       "consolidated.00.pth:   0%|          | 0.00/13.5G [00:00<?, ?B/s]"
-      ]
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    },
     {
     {
      "name": "stdout",
      "name": "stdout",
      "output_type": "stream",
      "output_type": "stream",
      "text": [
      "text": [
       "Output text:\n",
       "Output text:\n",
       " What do llamas eat?\n",
       " What do llamas eat?\n",
-      "Llamas and alpacas are herbivores, which means they eat grasses, leaves, grass\n"
+      "\n",
+      "Llamas are herbivores, which means they eat plants for their food. They feed on a variety\n"
      ]
      ]
     }
     }
    ],
    ],
@@ -1601,6 +1594,7 @@
     "   local_dir=\"Llama-2-7b-chat\"\n",
     "   local_dir=\"Llama-2-7b-chat\"\n",
     ")\n",
     ")\n",
     "\n",
     "\n",
+    "weights = torch.load(weights_file, weights_only=True)\n",
     "model = Llama2Model(LLAMA2_CONFIG_7B)\n",
     "model = Llama2Model(LLAMA2_CONFIG_7B)\n",
     "load_weights_into_llama(model, LLAMA2_CONFIG_7B, weights)\n",
     "load_weights_into_llama(model, LLAMA2_CONFIG_7B, weights)\n",
     "model.to(device);\n",
     "model.to(device);\n",