Browse Source

add GPT2TokenizerFast to BPE comparison (#498)

* added HF BPE Fast

* update benchmarks

* add note about performance

* revert accidental changes

---------

Co-authored-by: rasbt <mail@sebastianraschka.com>
Daniel Kleine 10 months ago
parent
commit
dce46038da
1 changed files with 92 additions and 16 deletions
  1. 92 16
      ch02/02_bonus_bytepair-encoder/compare-bpe-tiktoken.ipynb

+ 92 - 16
ch02/02_bonus_bytepair-encoder/compare-bpe-tiktoken.ipynb

@@ -180,8 +180,8 @@
      "name": "stderr",
      "name": "stderr",
      "output_type": "stream",
      "output_type": "stream",
      "text": [
      "text": [
-      "Fetching encoder.json: 1.04Mit [00:00, 3.47Mit/s]                                                   \n",
-      "Fetching vocab.bpe: 457kit [00:00, 2.07Mit/s]                                                       \n"
+      "Fetching encoder.json: 1.04Mit [00:00, 4.13Mit/s]                                                   \n",
+      "Fetching vocab.bpe: 457kit [00:00, 2.56Mit/s]                                                       \n"
      ]
      ]
     }
     }
    ],
    ],
@@ -306,6 +306,39 @@
     "hf_tokenizer(strings)[\"input_ids\"]"
     "hf_tokenizer(strings)[\"input_ids\"]"
    ]
    ]
   },
   },
+  {
+   "cell_type": "code",
+   "execution_count": 15,
+   "id": "a6233552",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from transformers import GPT2TokenizerFast\n",
+    "\n",
+    "hf_tokenizer_fast = GPT2TokenizerFast.from_pretrained(\"gpt2\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 16,
+   "id": "fa5ca643",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "[15496, 11, 995, 13, 1148, 428, 438, 257, 1332, 30]"
+      ]
+     },
+     "execution_count": 16,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "hf_tokenizer_fast(strings)[\"input_ids\"]"
+   ]
+  },
   {
   {
    "cell_type": "markdown",
    "cell_type": "markdown",
    "id": "9d0f2e95-8ae8-4606-a8e0-b0fce91cfac9",
    "id": "9d0f2e95-8ae8-4606-a8e0-b0fce91cfac9",
@@ -319,7 +352,7 @@
   },
   },
   {
   {
    "cell_type": "code",
    "cell_type": "code",
-   "execution_count": 15,
+   "execution_count": 17,
    "id": "b6e6b1a5-9dc0-4b20-9a8b-c02aa0e3191c",
    "id": "b6e6b1a5-9dc0-4b20-9a8b-c02aa0e3191c",
    "metadata": {},
    "metadata": {},
    "outputs": [],
    "outputs": [],
@@ -365,7 +398,7 @@
   },
   },
   {
   {
    "cell_type": "code",
    "cell_type": "code",
-   "execution_count": 16,
+   "execution_count": 18,
    "id": "04fbd764-ec98-44f1-9b0a-e9db9a3bb91e",
    "id": "04fbd764-ec98-44f1-9b0a-e9db9a3bb91e",
    "metadata": {},
    "metadata": {},
    "outputs": [],
    "outputs": [],
@@ -382,7 +415,7 @@
   },
   },
   {
   {
    "cell_type": "code",
    "cell_type": "code",
-   "execution_count": 17,
+   "execution_count": 19,
    "id": "5a5def88-1d2c-4550-a5e8-ee82b72b92d7",
    "id": "5a5def88-1d2c-4550-a5e8-ee82b72b92d7",
    "metadata": {},
    "metadata": {},
    "outputs": [
    "outputs": [
@@ -413,7 +446,7 @@
   },
   },
   {
   {
    "cell_type": "code",
    "cell_type": "code",
-   "execution_count": 18,
+   "execution_count": 20,
    "id": "a61bb445-b151-4a2f-8180-d4004c503754",
    "id": "a61bb445-b151-4a2f-8180-d4004c503754",
    "metadata": {},
    "metadata": {},
    "outputs": [],
    "outputs": [],
@@ -432,7 +465,7 @@
   },
   },
   {
   {
    "cell_type": "code",
    "cell_type": "code",
-   "execution_count": 19,
+   "execution_count": 21,
    "id": "57f7c0a3-c1fd-4313-af34-68e78eb33653",
    "id": "57f7c0a3-c1fd-4313-af34-68e78eb33653",
    "metadata": {},
    "metadata": {},
    "outputs": [
    "outputs": [
@@ -440,7 +473,7 @@
      "name": "stdout",
      "name": "stdout",
      "output_type": "stream",
      "output_type": "stream",
      "text": [
      "text": [
-      "3.44 ms ± 54 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
+      "3.39 ms ± 21.9 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
      ]
      ]
     }
     }
    ],
    ],
@@ -458,7 +491,7 @@
   },
   },
   {
   {
    "cell_type": "code",
    "cell_type": "code",
-   "execution_count": 20,
+   "execution_count": 22,
    "id": "036dd628-3591-46c9-a5ce-b20b105a8062",
    "id": "036dd628-3591-46c9-a5ce-b20b105a8062",
    "metadata": {},
    "metadata": {},
    "outputs": [
    "outputs": [
@@ -466,7 +499,7 @@
      "name": "stdout",
      "name": "stdout",
      "output_type": "stream",
      "output_type": "stream",
      "text": [
      "text": [
-      "1.08 ms ± 4.69 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)\n"
+      "1.08 ms ± 5.99 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)\n"
      ]
      ]
     }
     }
    ],
    ],
@@ -484,7 +517,7 @@
   },
   },
   {
   {
    "cell_type": "code",
    "cell_type": "code",
-   "execution_count": 21,
+   "execution_count": 23,
    "id": "b9c85b58-bfbc-465e-9a7e-477e53d55c90",
    "id": "b9c85b58-bfbc-465e-9a7e-477e53d55c90",
    "metadata": {},
    "metadata": {},
    "outputs": [
    "outputs": [
@@ -499,7 +532,7 @@
      "name": "stdout",
      "name": "stdout",
      "output_type": "stream",
      "output_type": "stream",
      "text": [
      "text": [
-      "10.3 ms ± 180 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
+      "10.2 ms ± 115 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
      ]
      ]
     }
     }
    ],
    ],
@@ -509,7 +542,7 @@
   },
   },
   {
   {
    "cell_type": "code",
    "cell_type": "code",
-   "execution_count": 22,
+   "execution_count": 24,
    "id": "7117107f-22a6-46b4-a442-712d50b3ac7a",
    "id": "7117107f-22a6-46b4-a442-712d50b3ac7a",
    "metadata": {},
    "metadata": {},
    "outputs": [
    "outputs": [
@@ -517,7 +550,7 @@
      "name": "stdout",
      "name": "stdout",
      "output_type": "stream",
      "output_type": "stream",
      "text": [
      "text": [
-      "10.2 ms ± 72.3 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
+      "10 ms ± 36.1 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
      ]
      ]
     }
     }
    ],
    ],
@@ -525,6 +558,49 @@
     "%timeit hf_tokenizer(raw_text, max_length=5145, truncation=True)[\"input_ids\"]"
     "%timeit hf_tokenizer(raw_text, max_length=5145, truncation=True)[\"input_ids\"]"
    ]
    ]
   },
   },
+  {
+   "cell_type": "code",
+   "execution_count": 25,
+   "id": "d6bfc7f0",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "Token indices sequence length is longer than the specified maximum sequence length for this model (5145 > 1024). Running this sequence through the model will result in indexing errors\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "3.79 ms ± 48.2 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
+     ]
+    }
+   ],
+   "source": [
+    "%timeit hf_tokenizer_fast(raw_text)[\"input_ids\"]"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 26,
+   "id": "da57c95a",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "3.83 ms ± 58.8 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
+     ]
+    }
+   ],
+   "source": [
+    "%timeit hf_tokenizer_fast(raw_text, max_length=5145, truncation=True)[\"input_ids\"]"
+   ]
+  },
   {
   {
    "cell_type": "markdown",
    "cell_type": "markdown",
    "id": "91ac2876-f36e-498c-bd75-8597a39f2d4b",
    "id": "91ac2876-f36e-498c-bd75-8597a39f2d4b",
@@ -535,7 +611,7 @@
   },
   },
   {
   {
    "cell_type": "code",
    "cell_type": "code",
-   "execution_count": 23,
+   "execution_count": 27,
    "id": "3b4ff4d5-f2d9-4ea6-a51c-023dbba15429",
    "id": "3b4ff4d5-f2d9-4ea6-a51c-023dbba15429",
    "metadata": {},
    "metadata": {},
    "outputs": [
    "outputs": [
@@ -543,7 +619,7 @@
      "name": "stdout",
      "name": "stdout",
      "output_type": "stream",
      "output_type": "stream",
      "text": [
      "text": [
-      "1.74 ms ± 48.5 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)\n"
+      "1.59 ms ± 11.5 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)\n"
      ]
      ]
     }
     }
    ],
    ],