|
@@ -83,7 +83,7 @@
|
|
|
"name": "stdout",
|
|
"name": "stdout",
|
|
|
"output_type": "stream",
|
|
"output_type": "stream",
|
|
|
"text": [
|
|
"text": [
|
|
|
- "huggingface_hub version: 0.33.0\n",
|
|
|
|
|
|
|
+ "huggingface_hub version: 0.33.2\n",
|
|
|
"sentencepiece version: 0.2.0\n",
|
|
"sentencepiece version: 0.2.0\n",
|
|
|
"torch version: 2.6.0\n"
|
|
"torch version: 2.6.0\n"
|
|
|
]
|
|
]
|
|
@@ -1306,22 +1306,7 @@
|
|
|
"id": "5fa9c06c-7a53-4b4d-9ce4-acc027322ee4",
|
|
"id": "5fa9c06c-7a53-4b4d-9ce4-acc027322ee4",
|
|
|
"outputId": "0d8942cc-e5e2-4e77-ec41-1ac7bec7d94f"
|
|
"outputId": "0d8942cc-e5e2-4e77-ec41-1ac7bec7d94f"
|
|
|
},
|
|
},
|
|
|
- "outputs": [
|
|
|
|
|
- {
|
|
|
|
|
- "data": {
|
|
|
|
|
- "application/vnd.jupyter.widget-view+json": {
|
|
|
|
|
- "model_id": "66e777955e8748df878f118f07f38dab",
|
|
|
|
|
- "version_major": 2,
|
|
|
|
|
- "version_minor": 0
|
|
|
|
|
- },
|
|
|
|
|
- "text/plain": [
|
|
|
|
|
- "consolidated.00.pth: 0%| | 0.00/13.5G [00:00<?, ?B/s]"
|
|
|
|
|
- ]
|
|
|
|
|
- },
|
|
|
|
|
- "metadata": {},
|
|
|
|
|
- "output_type": "display_data"
|
|
|
|
|
- }
|
|
|
|
|
- ],
|
|
|
|
|
|
|
+ "outputs": [],
|
|
|
"source": [
|
|
"source": [
|
|
|
"weights_file = hf_hub_download(\n",
|
|
"weights_file = hf_hub_download(\n",
|
|
|
" repo_id=\"meta-llama/Llama-2-7b\",\n",
|
|
" repo_id=\"meta-llama/Llama-2-7b\",\n",
|
|
@@ -1405,7 +1390,7 @@
|
|
|
},
|
|
},
|
|
|
{
|
|
{
|
|
|
"cell_type": "code",
|
|
"cell_type": "code",
|
|
|
- "execution_count": 29,
|
|
|
|
|
|
|
+ "execution_count": 32,
|
|
|
"id": "3820e2a7-4f26-41bc-953b-f3879b0aff65",
|
|
"id": "3820e2a7-4f26-41bc-953b-f3879b0aff65",
|
|
|
"metadata": {
|
|
"metadata": {
|
|
|
"id": "3820e2a7-4f26-41bc-953b-f3879b0aff65"
|
|
"id": "3820e2a7-4f26-41bc-953b-f3879b0aff65"
|
|
@@ -1422,19 +1407,40 @@
|
|
|
" return torch.nn.Parameter(torch.tensor(right))\n",
|
|
" return torch.nn.Parameter(torch.tensor(right))\n",
|
|
|
"\n",
|
|
"\n",
|
|
|
"\n",
|
|
"\n",
|
|
|
|
|
+ "def permute(w: torch.Tensor, n_heads, out_dim, in_dim):\n",
|
|
|
|
|
+ " return (w.view(n_heads, out_dim // n_heads // 2, 2, in_dim)\n",
|
|
|
|
|
+ " .transpose(1, 2) # put axis 2 next to heads\n",
|
|
|
|
|
+ " .reshape(out_dim, in_dim))\n",
|
|
|
|
|
+ "\n",
|
|
|
|
|
+ "\n",
|
|
|
"def load_weights_into_llama(model, param_config, params):\n",
|
|
"def load_weights_into_llama(model, param_config, params):\n",
|
|
|
|
|
+ "\n",
|
|
|
|
|
+ " cfg = LLAMA2_CONFIG_7B\n",
|
|
|
|
|
+ " \n",
|
|
|
" model.tok_emb.weight = assign(model.tok_emb.weight, params[\"tok_embeddings.weight\"])\n",
|
|
" model.tok_emb.weight = assign(model.tok_emb.weight, params[\"tok_embeddings.weight\"])\n",
|
|
|
"\n",
|
|
"\n",
|
|
|
" for l in range(param_config[\"n_layers\"]):\n",
|
|
" for l in range(param_config[\"n_layers\"]):\n",
|
|
|
"\n",
|
|
"\n",
|
|
|
- " # Load attention weights\n",
|
|
|
|
|
|
|
+ " # The original Meta/Llama checkpoints store Q and K so that the two numbers \n",
|
|
|
|
|
+ " # that form one complex RoPE pair sit next to each other inside the head dimension (\"sliced\" layout).\n",
|
|
|
|
|
+ " # Our RoPE implementation, similar to the one in Hugging Face, expects an interleaved layout\n",
|
|
|
|
|
+ " # For example, with n_heads=2 and head_dim = 8\n",
|
|
|
|
|
+ " # ┌── pair 0 ──┐ ┌── pair 1 ──┐\n",
|
|
|
|
|
+ " # Meta (sliced): [ h0: r0 r1 r2 r3, h1: r0 r1 r2 r3 ]\n",
|
|
|
|
|
+ " # Ours & HF (interleaved): [ h0: r0 r0 r1 r1 r2 r2 r3 r3 , h1: ... ]\n",
|
|
|
|
|
+ " # For more information, please see the discussion in the PR: https://github.com/rasbt/LLMs-from-scratch/pull/747 \n",
|
|
|
|
|
+ " \n",
|
|
|
|
|
+ " # So, below, for q_raw and k_raw, we must re‑order the checkpoint weights using the slices_to_interleave helper\n",
|
|
|
|
|
+ "\n",
|
|
|
|
|
+ " q_raw = params[f\"layers.{l}.attention.wq.weight\"]\n",
|
|
|
" model.trf_blocks[l].att.W_query.weight = assign(\n",
|
|
" model.trf_blocks[l].att.W_query.weight = assign(\n",
|
|
|
" model.trf_blocks[l].att.W_query.weight,\n",
|
|
" model.trf_blocks[l].att.W_query.weight,\n",
|
|
|
- " params[f\"layers.{l}.attention.wq.weight\"]\n",
|
|
|
|
|
|
|
+ " permute(q_raw, cfg[\"n_heads\"], cfg[\"emb_dim\"], cfg[\"emb_dim\"])\n",
|
|
|
" )\n",
|
|
" )\n",
|
|
|
|
|
+ " k_raw = params[f\"layers.{l}.attention.wk.weight\"]\n",
|
|
|
" model.trf_blocks[l].att.W_key.weight = assign(\n",
|
|
" model.trf_blocks[l].att.W_key.weight = assign(\n",
|
|
|
" model.trf_blocks[l].att.W_key.weight,\n",
|
|
" model.trf_blocks[l].att.W_key.weight,\n",
|
|
|
- " params[f\"layers.{l}.attention.wk.weight\"]\n",
|
|
|
|
|
|
|
+ " permute(k_raw, cfg[\"n_heads\"], cfg[\"emb_dim\"], cfg[\"emb_dim\"])\n",
|
|
|
" )\n",
|
|
" )\n",
|
|
|
" model.trf_blocks[l].att.W_value.weight = assign(\n",
|
|
" model.trf_blocks[l].att.W_value.weight = assign(\n",
|
|
|
" model.trf_blocks[l].att.W_value.weight,\n",
|
|
" model.trf_blocks[l].att.W_value.weight,\n",
|
|
@@ -1489,7 +1495,7 @@
|
|
|
},
|
|
},
|
|
|
{
|
|
{
|
|
|
"cell_type": "code",
|
|
"cell_type": "code",
|
|
|
- "execution_count": 30,
|
|
|
|
|
|
|
+ "execution_count": 33,
|
|
|
"id": "240987e8-a023-462e-9376-9edfb27559ec",
|
|
"id": "240987e8-a023-462e-9376-9edfb27559ec",
|
|
|
"metadata": {
|
|
"metadata": {
|
|
|
"colab": {
|
|
"colab": {
|
|
@@ -1504,7 +1510,7 @@
|
|
|
"output_type": "stream",
|
|
"output_type": "stream",
|
|
|
"text": [
|
|
"text": [
|
|
|
"Output text:\n",
|
|
"Output text:\n",
|
|
|
- " Every effort has been made to ensure that the information contained in this website is accurate and up to date and correct at the time of publication\n"
|
|
|
|
|
|
|
+ " Every effort has been made to ensure the accuracy of the information contained in this website. However, the information contained in this website is not\n"
|
|
|
]
|
|
]
|
|
|
}
|
|
}
|
|
|
],
|
|
],
|
|
@@ -1544,7 +1550,7 @@
|
|
|
},
|
|
},
|
|
|
{
|
|
{
|
|
|
"cell_type": "code",
|
|
"cell_type": "code",
|
|
|
- "execution_count": 34,
|
|
|
|
|
|
|
+ "execution_count": 35,
|
|
|
"id": "nbvAV7vaz6yc",
|
|
"id": "nbvAV7vaz6yc",
|
|
|
"metadata": {
|
|
"metadata": {
|
|
|
"colab": {
|
|
"colab": {
|
|
@@ -1568,27 +1574,14 @@
|
|
|
"outputId": "724f5508-d976-4e31-b3d7-95fa65b2c1e8"
|
|
"outputId": "724f5508-d976-4e31-b3d7-95fa65b2c1e8"
|
|
|
},
|
|
},
|
|
|
"outputs": [
|
|
"outputs": [
|
|
|
- {
|
|
|
|
|
- "data": {
|
|
|
|
|
- "application/vnd.jupyter.widget-view+json": {
|
|
|
|
|
- "model_id": "3b2448a60f5f4ba5b2c686037c8ecd78",
|
|
|
|
|
- "version_major": 2,
|
|
|
|
|
- "version_minor": 0
|
|
|
|
|
- },
|
|
|
|
|
- "text/plain": [
|
|
|
|
|
- "consolidated.00.pth: 0%| | 0.00/13.5G [00:00<?, ?B/s]"
|
|
|
|
|
- ]
|
|
|
|
|
- },
|
|
|
|
|
- "metadata": {},
|
|
|
|
|
- "output_type": "display_data"
|
|
|
|
|
- },
|
|
|
|
|
{
|
|
{
|
|
|
"name": "stdout",
|
|
"name": "stdout",
|
|
|
"output_type": "stream",
|
|
"output_type": "stream",
|
|
|
"text": [
|
|
"text": [
|
|
|
"Output text:\n",
|
|
"Output text:\n",
|
|
|
" What do llamas eat?\n",
|
|
" What do llamas eat?\n",
|
|
|
- "Llamas and alpacas are herbivores, which means they eat grasses, leaves, grass\n"
|
|
|
|
|
|
|
+ "\n",
|
|
|
|
|
+ "Llamas are herbivores, which means they eat plants for their food. They feed on a variety\n"
|
|
|
]
|
|
]
|
|
|
}
|
|
}
|
|
|
],
|
|
],
|
|
@@ -1601,6 +1594,7 @@
|
|
|
" local_dir=\"Llama-2-7b-chat\"\n",
|
|
" local_dir=\"Llama-2-7b-chat\"\n",
|
|
|
")\n",
|
|
")\n",
|
|
|
"\n",
|
|
"\n",
|
|
|
|
|
+ "weights = torch.load(weights_file, weights_only=True)\n",
|
|
|
"model = Llama2Model(LLAMA2_CONFIG_7B)\n",
|
|
"model = Llama2Model(LLAMA2_CONFIG_7B)\n",
|
|
|
"load_weights_into_llama(model, LLAMA2_CONFIG_7B, weights)\n",
|
|
"load_weights_into_llama(model, LLAMA2_CONFIG_7B, weights)\n",
|
|
|
"model.to(device);\n",
|
|
"model.to(device);\n",
|