فهرست منبع

fix: preserve newline tokens in BPE encoder (#495)

* fix: preserve newline tokens in BPE encoder

* further fixes

* more fixes

---------

Co-authored-by: rasbt <mail@sebastianraschka.com>
Austin Welch 10 ماه پیش
والد
کامیت
0f35e370ed
1فایلهای تغییر یافته به همراه126 افزوده شده و 81 حذف شده
  1. 126 81
      ch02/05_bpe-from-scratch/bpe-from-scratch.ipynb

+ 126 - 81
ch02/05_bpe-from-scratch/bpe-from-scratch.ipynb

@@ -39,7 +39,7 @@
     "- The BPE algorithm was originally described in 1994: \"[A New Algorithm for Data Compression](http://www.pennelynn.com/Documents/CUJ/HTML/94HTML/19940045.HTM)\" by Philip Gage\n",
     "- Most projects, including Llama 3, nowadays use OpenAI's open-source [tiktoken library](https://github.com/openai/tiktoken) due to its computational performance; it allows loading pretrained GPT-2 and GPT-4 tokenizers, for example (the Llama 3 models were trained using the GPT-4 tokenizer as well)\n",
     "- The difference between the implementations above and my implementation in this notebook, besides it being is that it also includes a function for training the tokenizer (for educational purposes)\n",
-    "- There's also an implementation called [minBPE](https://github.com/karpathy/minbpe) with training support, which is maybe more performant (my implementation here is focused on educational purposes); in contrast to `minbpe` my implementation additionally allows loading the original OpenAI tokenizer vocabulary and merges"
+    "- There's also an implementation called [minBPE](https://github.com/karpathy/minbpe) with training support, which is maybe more performant (my implementation here is focused on educational purposes); in contrast to `minbpe` my implementation additionally allows loading the original OpenAI tokenizer vocabulary and BPE \"merges\" (additionally, Hugging Face tokenizers are also capable of training and loading various tokenizers; see [this GitHub discussion](https://github.com/rasbt/LLMs-from-scratch/discussions/485) by a reader who trained a BPE tokenizer on the Nepali language for more info)"
    ]
   },
   {
@@ -382,7 +382,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 77,
+   "execution_count": 4,
    "id": "3e4a15ec-2667-4f56-b7c1-34e8071b621d",
    "metadata": {},
    "outputs": [],
@@ -431,8 +431,8 @@
     "        unique_chars.extend(char for char in sorted(set(processed_text)) if char not in unique_chars)\n",
     "\n",
     "        # Optionally, ensure 'Ġ' is included if it is relevant to your text processing\n",
-    "        if 'Ġ' not in unique_chars:\n",
-    "            unique_chars.append('Ġ')\n",
+    "        if \"Ġ\" not in unique_chars:\n",
+    "            unique_chars.append(\"Ġ\")\n",
     "\n",
     "        # Now create the vocab and inverse vocab dictionaries\n",
     "        self.vocab = {i: char for i, char in enumerate(unique_chars)}\n",
@@ -474,9 +474,23 @@
     "        # Load vocabulary\n",
     "        with open(vocab_path, \"r\", encoding=\"utf-8\") as file:\n",
     "            loaded_vocab = json.load(file)\n",
-    "            # loaded_vocab maps token_str to token_id\n",
-    "            self.vocab = {int(v): k for k, v in loaded_vocab.items()}  # token_id: token_str\n",
-    "            self.inverse_vocab = {k: int(v) for k, v in loaded_vocab.items()}  # token_str: token_id\n",
+    "            # Convert loaded vocabulary to correct format\n",
+    "            self.vocab = {int(v): k for k, v in loaded_vocab.items()}\n",
+    "            self.inverse_vocab = {k: int(v) for k, v in loaded_vocab.items()}\n",
+    "\n",
+    "        # Handle newline character without adding a new token\n",
+    "        if \"\\n\" not in self.inverse_vocab:\n",
+    "            # Use an existing token ID as a placeholder for '\\n'\n",
+    "            # Preferentially use \"<|endoftext|>\" if available\n",
+    "            fallback_token = next((token for token in [\"<|endoftext|>\", \"Ġ\", \"\"] if token in self.inverse_vocab), None)\n",
+    "            if fallback_token is not None:\n",
+    "                newline_token_id = self.inverse_vocab[fallback_token]\n",
+    "            else:\n",
+    "                # If no fallback token is available, raise an error\n",
+    "                raise KeyError(\"No suitable token found in vocabulary to map '\\\\n'.\")\n",
+    "\n",
+    "            self.inverse_vocab[\"\\n\"] = newline_token_id\n",
+    "            self.vocab[newline_token_id] = \"\\n\"\n",
     "\n",
     "        # Load BPE merges\n",
     "        with open(bpe_merges_path, \"r\", encoding=\"utf-8\") as file:\n",
@@ -487,17 +501,15 @@
     "\n",
     "            for rank, line in enumerate(lines):\n",
     "                pair = tuple(line.strip().split())\n",
-    "                if len(pair) != 2:\n",
-    "                    print(f\"Line {rank+1} has more than 2 entries: {line.strip()}\")\n",
-    "                    continue\n",
-    "                token1, token2 = pair\n",
-    "                if token1 in self.inverse_vocab and token2 in self.inverse_vocab:\n",
-    "                    token_id1 = self.inverse_vocab[token1]\n",
-    "                    token_id2 = self.inverse_vocab[token2]\n",
-    "                    merged_token = token1 + token2\n",
-    "                    if merged_token in self.inverse_vocab:\n",
-    "                        merged_token_id = self.inverse_vocab[merged_token]\n",
-    "                        self.bpe_merges[(token_id1, token_id2)] = merged_token_id\n",
+    "                if len(pair) == 2:\n",
+    "                    token1, token2 = pair\n",
+    "                    if token1 in self.inverse_vocab and token2 in self.inverse_vocab:\n",
+    "                        token_id1 = self.inverse_vocab[token1]\n",
+    "                        token_id2 = self.inverse_vocab[token2]\n",
+    "                        merged_token = token1 + token2\n",
+    "                        if merged_token in self.inverse_vocab:\n",
+    "                            merged_token_id = self.inverse_vocab[merged_token]\n",
+    "                            self.bpe_merges[(token_id1, token_id2)] = merged_token_id\n",
     "                        # print(f\"Loaded merge: '{token1}' + '{token2}' -> '{merged_token}' (ID: {merged_token_id})\")\n",
     "                    else:\n",
     "                        print(f\"Merged token '{merged_token}' not found in vocab. Skipping.\")\n",
@@ -515,21 +527,27 @@
     "            List[int]: The list of token IDs.\n",
     "        \"\"\"\n",
     "        tokens = []\n",
-    "        # Split text into tokens, keeping newlines intact\n",
-    "        words = text.replace(\"\\n\", \" \\n \").split()  # Ensure '\\n' is treated as a separate token\n",
-    "\n",
-    "        for i, word in enumerate(words):\n",
-    "            if i > 0 and not word.startswith(\"\\n\"):\n",
-    "                tokens.append(\"Ġ\" + word)  # Add 'Ġ' to words that follow a space or newline\n",
-    "            else:\n",
-    "                tokens.append(word)  # Handle first word or standalone '\\n'\n",
+    "        # First split on newlines to preserve them\n",
+    "        lines = text.split(\"\\n\")\n",
+    "        for i, line in enumerate(lines):\n",
+    "            if i > 0:\n",
+    "                tokens.append(\"\\n\")  # Add newline token separately\n",
+    "            words = line.split()\n",
+    "            for j, word in enumerate(words):\n",
+    "                if j == 0:\n",
+    "                    if i > 0:  # Start of a new line but not the first line\n",
+    "                        tokens.append(\"Ġ\" + word)  # Ensure it's marked as a new segment\n",
+    "                    else:\n",
+    "                        tokens.append(word)\n",
+    "                else:\n",
+    "                    # Prefix words in the middle of a line with 'Ġ'\n",
+    "                    tokens.append(\"Ġ\" + word)\n",
     "\n",
     "        token_ids = []\n",
     "        for token in tokens:\n",
     "            if token in self.inverse_vocab:\n",
     "                # token is contained in the vocabulary as is\n",
-    "                token_id = self.inverse_vocab[token]\n",
-    "                token_ids.append(token_id)\n",
+    "                token_ids.append(self.inverse_vocab[token])\n",
     "            else:\n",
     "                # Attempt to handle subword tokenization via BPE\n",
     "                sub_token_ids = self.tokenize_with_bpe(token)\n",
@@ -587,12 +605,15 @@
     "            str: The decoded string.\n",
     "        \"\"\"\n",
     "        decoded_string = \"\"\n",
-    "        for token_id in token_ids:\n",
+    "        for i, token_id in enumerate(token_ids):\n",
     "            if token_id not in self.vocab:\n",
     "                raise ValueError(f\"Token ID {token_id} not found in vocab.\")\n",
     "            token = self.vocab[token_id]\n",
-    "            if token.startswith(\"Ġ\"):\n",
-    "                # Replace 'Ġ' with a space\n",
+    "            if token == \"\\n\":\n",
+    "                if decoded_string and not decoded_string.endswith(\" \"):\n",
+    "                    decoded_string += \" \"  # Add space if not present before a newline\n",
+    "                decoded_string += token\n",
+    "            elif token.startswith(\"Ġ\"):\n",
     "                decoded_string += \" \" + token[1:]\n",
     "            else:\n",
     "                decoded_string += token\n",
@@ -634,8 +655,8 @@
     "        with open(bpe_merges_path, \"r\", encoding=\"utf-8\") as file:\n",
     "            merges_list = json.load(file)\n",
     "            for merge in merges_list:\n",
-    "                pair = tuple(merge['pair'])\n",
-    "                new_id = merge['new_id']\n",
+    "                pair = tuple(merge[\"pair\"])\n",
+    "                new_id = merge[\"new_id\"]\n",
     "                self.bpe_merges[pair] = new_id\n",
     "\n",
     "    @lru_cache(maxsize=None)\n",
@@ -714,7 +735,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 78,
+   "execution_count": 5,
    "id": "4d197cad-ed10-4a42-b01c-a763859781fb",
    "metadata": {},
    "outputs": [],
@@ -745,7 +766,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 79,
+   "execution_count": 6,
    "id": "027348fd-d52f-4396-93dd-38eed142df9b",
    "metadata": {},
    "outputs": [],
@@ -764,7 +785,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 80,
+   "execution_count": 7,
    "id": "f705a283-355e-4460-b940-06bbc2ae4e61",
    "metadata": {},
    "outputs": [
@@ -791,7 +812,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 81,
+   "execution_count": 8,
    "id": "3da42d1c-f75c-4ba7-a6c5-4cb8543d4a44",
    "metadata": {},
    "outputs": [
@@ -825,7 +846,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 82,
+   "execution_count": 9,
    "id": "e1db5cce-e015-412b-ad56-060b8b638078",
    "metadata": {},
    "outputs": [
@@ -845,7 +866,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 83,
+   "execution_count": 10,
    "id": "1ed1b344-f7d4-4e9e-ac34-2a04b5c5b7a8",
    "metadata": {},
    "outputs": [
@@ -881,7 +902,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 84,
+   "execution_count": 11,
    "id": "da0e1faf-1933-43d9-b681-916c282a8f86",
    "metadata": {},
    "outputs": [
@@ -899,7 +920,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 85,
+   "execution_count": 12,
    "id": "8b690e83-5d6b-409a-804e-321c287c24a4",
    "metadata": {},
    "outputs": [
@@ -925,7 +946,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 86,
+   "execution_count": 13,
    "id": "2b9e6289-92cb-4d88-b3c8-e836d7c8095f",
    "metadata": {},
    "outputs": [
@@ -979,7 +1000,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 87,
+   "execution_count": 14,
    "id": "c7056cb1-a9a3-4cf6-8364-29fb493ae240",
    "metadata": {},
    "outputs": [
@@ -989,13 +1010,38 @@
        "'This is some text.'"
       ]
      },
-     "execution_count": 87,
+     "execution_count": 14,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "tokenizer.decode(\n",
+    "    tokenizer.encode(\"This is some text.\")\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 15,
+   "id": "37bc6753-8f35-4ec7-b23e-df4a12103cb4",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "'This is some text with \\n newline characters.'"
+      ]
+     },
+     "execution_count": 15,
      "metadata": {},
      "output_type": "execute_result"
     }
    ],
    "source": [
-    "tokenizer.decode(tokenizer.encode(\"This is some text.\"))"
+    "tokenizer.decode(\n",
+    "    tokenizer.encode(\"This is some text with \\n newline characters.\")\n",
+    ")"
    ]
   },
   {
@@ -1016,7 +1062,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 88,
+   "execution_count": 16,
    "id": "955181cb-0910-4c6a-9c22-d8292a3ec1fc",
    "metadata": {},
    "outputs": [],
@@ -1027,7 +1073,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 89,
+   "execution_count": 17,
    "id": "6e5ccfe7-ac67-42f3-b727-87886a8867f1",
    "metadata": {},
    "outputs": [],
@@ -1047,7 +1093,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 90,
+   "execution_count": 18,
    "id": "00d9bf8f-756f-48bf-81b8-b890e2c2ef13",
    "metadata": {},
    "outputs": [
@@ -1063,6 +1109,29 @@
     "print(tokenizer2.decode(token_ids))"
    ]
   },
+  {
+   "cell_type": "code",
+   "execution_count": 19,
+   "id": "e7addb64-2892-4e1c-85dd-4f5152740099",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "'This is some text with \\n newline characters.'"
+      ]
+     },
+     "execution_count": 19,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "tokenizer2.decode(\n",
+    "    tokenizer2.encode(\"This is some text with \\n newline characters.\")\n",
+    ")"
+   ]
+  },
   {
    "cell_type": "markdown",
    "id": "b24d10b2-1ab8-44ee-b51a-14248e30d662",
@@ -1082,7 +1151,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 91,
+   "execution_count": 20,
    "id": "b45b4366-2c2b-4309-9a14-febf3add8512",
    "metadata": {},
    "outputs": [
@@ -1090,8 +1159,8 @@
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "vocab.bpe already exists\n",
-      "encoder.json already exists\n"
+      "vocab.bpe already exists in ../02_bonus_bytepair-encoder/gpt2_model/vocab.bpe\n",
+      "encoder.json already exists in ../02_bonus_bytepair-encoder/gpt2_model/encoder.json\n"
      ]
     }
    ],
@@ -1139,7 +1208,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 92,
+   "execution_count": 21,
    "id": "74306e6c-47d3-45a3-9e0f-93f7303ef601",
    "metadata": {},
    "outputs": [],
@@ -1160,7 +1229,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 93,
+   "execution_count": 22,
    "id": "2bb722b4-dbf5-4a0c-9120-efda3293f132",
    "metadata": {},
    "outputs": [
@@ -1170,7 +1239,7 @@
        "50257"
       ]
      },
-     "execution_count": 93,
+     "execution_count": 22,
      "metadata": {},
      "output_type": "execute_result"
     }
@@ -1189,7 +1258,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 97,
+   "execution_count": 23,
    "id": "e4866de7-fb32-4dd6-a878-469ec734641c",
    "metadata": {},
    "outputs": [
@@ -1209,7 +1278,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 98,
+   "execution_count": 24,
    "id": "3da8d9b2-af55-4b09-95d7-fabd983e919e",
    "metadata": {},
    "outputs": [
@@ -1225,30 +1294,6 @@
     "print(tokenizer_gpt2.decode(token_ids))"
    ]
   },
-  {
-   "cell_type": "code",
-   "execution_count": 99,
-   "id": "460deb85-8de7-40c7-ba18-3c17831fa8ab",
-   "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "[1212, 318, 617, 2420]"
-      ]
-     },
-     "execution_count": 99,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
-   "source": [
-    "import tiktoken\n",
-    "\n",
-    "tik_tokenizer = tiktoken.get_encoding(\"gpt2\")\n",
-    "tik_tokenizer.encode(input_text)"
-   ]
-  },
   {
    "cell_type": "markdown",
    "id": "b3b1e2dc-f69b-4533-87ef-549e6fb9b5a0",
@@ -1303,7 +1348,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.11.4"
+   "version": "3.10.6"
   }
  },
  "nbformat": 4,