@@ -382,7 +382,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 4,
    "id": "3e4a15ec-2667-4f56-b7c1-34e8071b621d",
    "metadata": {},
    "outputs": [],
@@ -401,6 +401,10 @@
     "        # Dictionary of BPE merges: {(token_id1, token_id2): merged_token_id}\n",
     "        self.bpe_merges = {}\n",
     "\n",
+    "        # For the official OpenAI GPT-2 merges, use a rank dict:\n",
+    "        # of form {(string_A, string_B): rank}, where lower rank = higher priority\n",
+    "        self.bpe_ranks = {}\n",
+    "\n",
     "    def train(self, text, vocab_size, allowed_special={\"<|endoftext|>\"}):\n",
     "        \"\"\"\n",
     "        Train the BPE tokenizer from scratch.\n",
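
Note on the hunk above: `self.bpe_ranks` mirrors how OpenAI's GPT-2 merge list is ordered, so the earlier a pair appears in the merges file, the lower its rank and the higher its merge priority. A minimal sketch of the intended structure, using made-up pairs rather than the real GPT-2 merge table:

# Toy illustration of the bpe_ranks layout described in the hunk above
# (hypothetical pairs and ranks; the real table comes from GPT-2's vocab.bpe file)
bpe_ranks = {
    ("Ġ", "t"): 0,    # rank 0 = merged first
    ("h", "e"): 1,
    ("Ġt", "he"): 2,  # later merges can reference already-merged symbols
}

# When several adjacent pairs are candidates, the lowest rank wins
candidates = [("Ġt", "he"), ("h", "e")]
best = min(candidates, key=lambda p: bpe_ranks.get(p, float("inf")))
print(best)  # ('h', 'e')
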
@@ -411,7 +415,7 @@
     "            allowed_special (set): A set of special tokens to include.\n",
     "        \"\"\"\n",
     "\n",
-    "        # Preprocess: Replace spaces with 'Ġ'\n",
+    "        # Preprocess: Replace spaces with \"Ġ\"\n",
     "        # Note that Ġ is a particularity of the GPT-2 BPE implementation\n",
     "        # E.g., \"Hello world\" might be tokenized as [\"Hello\", \"Ġworld\"]\n",
     "        # (GPT-4 BPE would tokenize it as [\"Hello\", \" world\"])\n",
@@ -423,18 +427,16 @@
     "                processed_text.append(char)\n",
     "        processed_text = \"\".join(processed_text)\n",
     "\n",
-    "        # Initialize vocab with unique characters, including 'Ġ' if present\n",
+    "        # Initialize vocab with unique characters, including \"Ġ\" if present\n",
     "        # Start with the first 256 ASCII characters\n",
     "        unique_chars = [chr(i) for i in range(256)]\n",
-    "\n",
-    "        # Extend unique_chars with characters from processed_text that are not already included\n",
-    "        unique_chars.extend(char for char in sorted(set(processed_text)) if char not in unique_chars)\n",
-    "\n",
-    "        # Optionally, ensure 'Ġ' is included if it is relevant to your text processing\n",
+    "        unique_chars.extend(\n",
+    "            char for char in sorted(set(processed_text))\n",
+    "            if char not in unique_chars\n",
+    "        )\n",
     "        if \"Ġ\" not in unique_chars:\n",
     "            unique_chars.append(\"Ġ\")\n",
     "\n",
-    "        # Now create the vocab and inverse vocab dictionaries\n",
     "        self.vocab = {i: char for i, char in enumerate(unique_chars)}\n",
     "        self.inverse_vocab = {char: i for i, char in self.vocab.items()}\n",
     "\n",
@@ -452,7 +454,7 @@
     "        # BPE steps 1-3: Repeatedly find and replace frequent pairs\n",
     "        for new_id in range(len(self.vocab), vocab_size):\n",
     "            pair_id = self.find_freq_pair(token_ids, mode=\"most\")\n",
-    "            if pair_id is None:  # No more pairs to merge. Stopping training.\n",
+    "            if pair_id is None:\n",
     "                break\n",
     "            token_ids = self.replace_pair(token_ids, pair_id, new_id)\n",
     "            self.bpe_merges[pair_id] = new_id\n",
@@ -492,29 +494,24 @@
     "            self.inverse_vocab[\"\\n\"] = newline_token_id\n",
     "            self.vocab[newline_token_id] = \"\\n\"\n",
     "\n",
-    "        # Load BPE merges\n",
+    "        # Load GPT-2 merges and store them with an assigned \"rank\"\n",
+    "        self.bpe_ranks = {}  # reset ranks\n",
     "        with open(bpe_merges_path, \"r\", encoding=\"utf-8\") as file:\n",
     "            lines = file.readlines()\n",
-    "            # Skip header line if present\n",
     "            if lines and lines[0].startswith(\"#\"):\n",
     "                lines = lines[1:]\n",
     "\n",
+    "            rank = 0\n",
     "            for line in lines:\n",
     "                pair = tuple(line.strip().split())\n",
     "                if len(pair) == 2:\n",
     "                    token1, token2 = pair\n",
+    "                    # If token1 or token2 not in vocab, skip\n",
     "                    if token1 in self.inverse_vocab and token2 in self.inverse_vocab:\n",
-    "                        token_id1 = self.inverse_vocab[token1]\n",
-    "                        token_id2 = self.inverse_vocab[token2]\n",
-    "                        merged_token = token1 + token2\n",
-    "                        if merged_token in self.inverse_vocab:\n",
-    "                            merged_token_id = self.inverse_vocab[merged_token]\n",
-    "                            self.bpe_merges[(token_id1, token_id2)] = merged_token_id\n",
-    "                            # print(f\"Loaded merge: '{token1}' + '{token2}' -> '{merged_token}' (ID: {merged_token_id})\")\n",
-    "                        else:\n",
-    "                            print(f\"Merged token '{merged_token}' not found in vocab. Skipping.\")\n",
+    "                        self.bpe_ranks[(token1, token2)] = rank\n",
+    "                        rank += 1\n",
     "                    else:\n",
-    "                        print(f\"Skipping pair {pair} as one of the tokens is not in the vocabulary.\")\n",
+    "                        print(f\"Skipping pair {pair} as one token is not in the vocabulary.\")\n",
     "\n",
     "    def encode(self, text):\n",
     "        \"\"\"\n",
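
For context on the hunk above: GPT-2's vocab.bpe file is a plain-text list with one space-separated pair per line, optionally preceded by a "#version" header, and the line order defines merge priority. The standalone sketch below mirrors the loading logic added above, with a placeholder path; unlike the notebook method, it does not check the pairs against a vocabulary, so the rank here is simply the surviving line index:

# Minimal sketch of rank assignment from a GPT-2-style merges file
# ("vocab.bpe" is a placeholder path; the file is not bundled here)
bpe_ranks = {}
with open("vocab.bpe", "r", encoding="utf-8") as f:
    lines = f.readlines()
if lines and lines[0].startswith("#"):  # skip the "#version: ..." header
    lines = lines[1:]

for rank, line in enumerate(lines):
    pair = tuple(line.strip().split())
    if len(pair) == 2:  # ignore blank or malformed lines
        bpe_ranks[pair] = rank
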
@@ -540,7 +537,7 @@
     "                    else:\n",
     "                        tokens.append(word)\n",
     "                else:\n",
-    "                    # Prefix words in the middle of a line with 'Ġ'\n",
+    "                    # Prefix words in the middle of a line with \"Ġ\"\n",
     "                    tokens.append(\"Ġ\" + word)\n",
     "\n",
     "        token_ids = []\n",
@@ -571,28 +568,74 @@
     "            missing_chars = [char for char, tid in zip(token, token_ids) if tid is None]\n",
     "            raise ValueError(f\"Characters not found in vocab: {missing_chars}\")\n",
     "\n",
-    "        can_merge = True\n",
-    "        while can_merge and len(token_ids) > 1:\n",
-    "            can_merge = False\n",
-    "            new_tokens = []\n",
+    "        # If we haven't loaded OpenAI's GPT-2 merges, use my approach\n",
+    "        if not self.bpe_ranks:\n",
+    "            can_merge = True\n",
+    "            while can_merge and len(token_ids) > 1:\n",
+    "                can_merge = False\n",
+    "                new_tokens = []\n",
+    "                i = 0\n",
+    "                while i < len(token_ids) - 1:\n",
+    "                    pair = (token_ids[i], token_ids[i + 1])\n",
+    "                    if pair in self.bpe_merges:\n",
+    "                        merged_token_id = self.bpe_merges[pair]\n",
+    "                        new_tokens.append(merged_token_id)\n",
+    "                        # Uncomment for educational purposes:\n",
+    "                        # print(f\"Merged pair {pair} -> {merged_token_id} ('{self.vocab[merged_token_id]}')\")\n",
+    "                        i += 2  # Skip the next token as it's merged\n",
+    "                        can_merge = True\n",
+    "                    else:\n",
+    "                        new_tokens.append(token_ids[i])\n",
+    "                        i += 1\n",
+    "                if i < len(token_ids):\n",
+    "                    new_tokens.append(token_ids[i])\n",
+    "                token_ids = new_tokens\n",
+    "            return token_ids\n",
+    "\n",
+    "        # Otherwise, do GPT-2-style merging with the ranks:\n",
+    "        # 1) Convert token_ids back to string \"symbols\" for each ID\n",
+    "        symbols = [self.vocab[id_num] for id_num in token_ids]\n",
+    "\n",
+    "        # Repeatedly merge all occurrences of the lowest-rank pair\n",
+    "        while True:\n",
+    "            # Collect all adjacent pairs\n",
+    "            pairs = set(zip(symbols, symbols[1:]))\n",
+    "            if not pairs:\n",
+    "                break\n",
+    "\n",
+    "            # Find the pair with the best (lowest) rank\n",
+    "            min_rank = 1_000_000_000\n",
+    "            bigram = None\n",
+    "            for p in pairs:\n",
+    "                r = self.bpe_ranks.get(p, 1_000_000_000)\n",
+    "                if r < min_rank:\n",
+    "                    min_rank = r\n",
+    "                    bigram = p\n",
+    "\n",
+    "            # If no valid ranked pair is present, we're done\n",
+    "            if bigram is None or bigram not in self.bpe_ranks:\n",
+    "                break\n",
+    "\n",
+    "            # Merge all occurrences of that pair\n",
+    "            first, second = bigram\n",
+    "            new_symbols = []\n",
     "            i = 0\n",
-    "            while i < len(token_ids) - 1:\n",
-    "                pair = (token_ids[i], token_ids[i + 1])\n",
-    "                if pair in self.bpe_merges:\n",
-    "                    merged_token_id = self.bpe_merges[pair]\n",
-    "                    new_tokens.append(merged_token_id)\n",
-    "                    # Uncomment for educational purposes:\n",
-    "                    # print(f\"Merged pair {pair} -> {merged_token_id} ('{self.vocab[merged_token_id]}')\")\n",
-    "                    i += 2  # Skip the next token as it's merged\n",
-    "                    can_merge = True\n",
+    "            while i < len(symbols):\n",
+    "                # If we see (first, second) at position i, merge them\n",
+    "                if i < len(symbols) - 1 and symbols[i] == first and symbols[i+1] == second:\n",
+    "                    new_symbols.append(first + second)  # merged symbol\n",
+    "                    i += 2\n",
     "                else:\n",
-    "                    new_tokens.append(token_ids[i])\n",
+    "                    new_symbols.append(symbols[i])\n",
     "                    i += 1\n",
-    "            if i < len(token_ids):\n",
-    "                new_tokens.append(token_ids[i])\n",
-    "            token_ids = new_tokens\n",
+    "            symbols = new_symbols\n",
     "\n",
-    "        return token_ids\n",
+    "            if len(symbols) == 1:\n",
+    "                break\n",
+    "\n",
+    "        # Finally, convert merged symbols back to IDs\n",
+    "        merged_ids = [self.inverse_vocab[sym] for sym in symbols]\n",
+    "        return merged_ids\n",
     "\n",
     "    def decode(self, token_ids):\n",
     "        \"\"\"\n",
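
The behavioral difference introduced by this hunk: the original loop merges any known pair in a single left-to-right sweep, while the GPT-2-style branch always applies the globally lowest-ranked (highest-priority) pair across all of its occurrences before looking for the next pair. A self-contained sketch of that rank-driven loop on toy data, independent of the notebook's class and using hypothetical ranks:

# Standalone sketch of GPT-2-style rank-based merging (toy symbols and ranks)
def merge_by_rank(symbols, bpe_ranks):
    while len(symbols) > 1:
        pairs = set(zip(symbols, symbols[1:]))
        # Pick the adjacent pair with the lowest (highest-priority) rank
        bigram = min(pairs, key=lambda p: bpe_ranks.get(p, float("inf")))
        if bigram not in bpe_ranks:
            break  # no mergeable pair left
        first, second = bigram
        merged, i = [], 0
        while i < len(symbols):
            if i < len(symbols) - 1 and (symbols[i], symbols[i + 1]) == bigram:
                merged.append(first + second)  # merge this occurrence
                i += 2
            else:
                merged.append(symbols[i])
                i += 1
        symbols = merged
    return symbols

# ("l", "o") outranks ("lo", "w"), which outranks ("w", "e")
ranks = {("l", "o"): 0, ("lo", "w"): 1, ("w", "e"): 2}
print(merge_by_rank(list("lower"), ranks))  # ['low', 'e', 'r']

Because the ranks are global, ("w", "e") never fires in this toy run even though it is a known merge; ("lo", "w") has the better rank and consumes the "w" first.
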
@@ -738,22 +781,49 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 5,
-   "id": "4d197cad-ed10-4a42-b01c-a763859781fb",
+   "execution_count": 25,
+   "id": "51872c08-e01b-40c3-a8a0-e8d6a773e3df",
    "metadata": {},
-   "outputs": [],
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "the-verdict.txt already exists in ./the-verdict.txt\n"
+     ]
+    }
+   ],
    "source": [
     "import os\n",
     "import urllib.request\n",
     "\n",
-    "if not os.path.exists(\"../01_main-chapter-code/the-verdict.txt\"):\n",
-    "    url = (\"https://raw.githubusercontent.com/rasbt/\"\n",
-    "           \"LLMs-from-scratch/main/ch02/01_main-chapter-code/\"\n",
-    "           \"the-verdict.txt\")\n",
-    "    file_path = \"../01_main-chapter-code/the-verdict.txt\"\n",
-    "    urllib.request.urlretrieve(url, file_path)\n",
+    "def download_file_if_absent(url, filename, search_dirs):\n",
+    "    for directory in search_dirs:\n",
+    "        file_path = os.path.join(directory, filename)\n",
+    "        if os.path.exists(file_path):\n",
+    "            print(f\"{filename} already exists in {file_path}\")\n",
+    "            return file_path\n",
+    "\n",
+    "    target_path = os.path.join(search_dirs[0], filename)\n",
+    "    try:\n",
+    "        with urllib.request.urlopen(url) as response, open(target_path, \"wb\") as out_file:\n",
+    "            out_file.write(response.read())\n",
+    "        print(f\"Downloaded {filename} to {target_path}\")\n",
+    "    except Exception as e:\n",
+    "        print(f\"Failed to download {filename}. Error: {e}\")\n",
+    "    return target_path\n",
     "\n",
-    "with open(\"../01_main-chapter-code/the-verdict.txt\", \"r\", encoding=\"utf-8\") as f: # added ../01_main-chapter-code/\n",
+    "verdict_path = download_file_if_absent(\n",
+    "    url=(\n",
+    "        \"https://raw.githubusercontent.com/rasbt/\"\n",
+    "        \"LLMs-from-scratch/main/ch02/01_main-chapter-code/\"\n",
+    "        \"the-verdict.txt\"\n",
+    "    ),\n",
+    "    filename=\"the-verdict.txt\",\n",
+    "    search_dirs=\".\"\n",
+    ")\n",
+    "\n",
+    "with open(verdict_path, \"r\", encoding=\"utf-8\") as f:\n",
     "    text = f.read()"
    ]
   },
@@ -1168,24 +1238,7 @@
     }
    ],
    "source": [
-    "import os\n",
-    "import urllib.request\n",
-    "\n",
-    "def download_file_if_absent(url, filename, search_dirs):\n",
-    "    for directory in search_dirs:\n",
-    "        file_path = os.path.join(directory, filename)\n",
-    "        if os.path.exists(file_path):\n",
-    "            print(f\"{filename} already exists in {file_path}\")\n",
-    "            return file_path\n",
-    "\n",
-    "    target_path = os.path.join(search_dirs[0], filename)\n",
-    "    try:\n",
-    "        with urllib.request.urlopen(url) as response, open(target_path, \"wb\") as out_file:\n",
-    "            out_file.write(response.read())\n",
-    "        print(f\"Downloaded {filename} to {target_path}\")\n",
-    "    except Exception as e:\n",
-    "        print(f\"Failed to download {filename}. Error: {e}\")\n",
-    "    return target_path\n",
+    "# Download files if not already present in this directory\n",
     "\n",
     "# Define the directories to search and the files to download\n",
     "search_directories = [\".\", \"../02_bonus_bytepair-encoder/gpt2_model/\"]\n",
@@ -1351,7 +1404,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.10.6"
+   "version": "3.10.16"
   }
  },
  "nbformat": 4,