# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
# Source for "Build a Large Language Model From Scratch"
# - https://www.manning.com/books/build-a-large-language-model-from-scratch
# Code: https://github.com/rasbt/LLMs-from-scratch

from .ch04 import generate_text_simple

import json
import os
import urllib.request
import urllib.error  # for HTTPError/URLError handling in download_file below

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator
import torch
from tqdm import tqdm


def generate(model, idx, max_new_tokens, context_size, temperature=0.0, top_k=None, eos_id=None):

    # For-loop is the same as before: Get logits, and only focus on last time step
    for _ in range(max_new_tokens):
        idx_cond = idx[:, -context_size:]
        with torch.no_grad():
            logits = model(idx_cond)
        logits = logits[:, -1, :]

        # New: Filter logits with top_k sampling
        if top_k is not None:
            # Keep only top_k values
            top_logits, _ = torch.topk(logits, top_k)
            min_val = top_logits[:, -1]
            logits = torch.where(logits < min_val, torch.tensor(float('-inf')).to(logits.device), logits)

        # New: Apply temperature scaling
        if temperature > 0.0:
            logits = logits / temperature

            # Apply softmax to get probabilities
            probs = torch.softmax(logits, dim=-1)  # (batch_size, context_len)

            # Sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1)  # (batch_size, 1)

        # Otherwise same as before: get idx of the vocab entry with the highest logits value
        else:
            idx_next = torch.argmax(logits, dim=-1, keepdim=True)  # (batch_size, 1)

        if idx_next == eos_id:  # Stop generating early if end-of-sequence token is encountered and eos_id is specified
            break

        # Same as before: append sampled index to the running sequence
        idx = torch.cat((idx, idx_next), dim=1)  # (batch_size, num_tokens+1)

    return idx
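
# Example usage (a minimal sketch, not part of the original module): assuming a
# ch04-style GPTModel instance `model` and a tiktoken tokenizer are available, the
# sampler can be called as below; the temperature/top_k values are illustrative only.
#
#   import tiktoken
#   tokenizer = tiktoken.get_encoding("gpt2")
#   token_ids = generate(
#       model=model,
#       idx=text_to_token_ids("Every effort moves you", tokenizer),
#       max_new_tokens=15,
#       context_size=1024,  # assumed to match the model's context length
#       top_k=25,
#       temperature=1.4
#   )
#   print(token_ids_to_text(token_ids, tokenizer))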


def train_model_simple(model, train_loader, val_loader, optimizer, device, num_epochs,
                       eval_freq, eval_iter, start_context, tokenizer):
    # Initialize lists to track losses and tokens seen
    train_losses, val_losses, track_tokens_seen = [], [], []
    tokens_seen, global_step = 0, -1

    # Main training loop
    for epoch in range(num_epochs):
        model.train()  # Set model to training mode

        for input_batch, target_batch in train_loader:
            optimizer.zero_grad()  # Reset loss gradients from previous batch iteration
            loss = calc_loss_batch(input_batch, target_batch, model, device)
            loss.backward()  # Calculate loss gradients
            optimizer.step()  # Update model weights using loss gradients
            tokens_seen += input_batch.numel()
            global_step += 1

            # Optional evaluation step
            if global_step % eval_freq == 0:
                train_loss, val_loss = evaluate_model(
                    model, train_loader, val_loader, device, eval_iter)
                train_losses.append(train_loss)
                val_losses.append(val_loss)
                track_tokens_seen.append(tokens_seen)
                print(f"Ep {epoch+1} (Step {global_step:06d}): "
                      f"Train loss {train_loss:.3f}, Val loss {val_loss:.3f}")

        # Print a sample text after each epoch
        generate_and_print_sample(
            model, tokenizer, device, start_context
        )

    return train_losses, val_losses, track_tokens_seen
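
# Example usage (a sketch with assumed hyperparameters): given data loaders that yield
# (input_batch, target_batch) token-ID tensors, training can be wired up as follows;
# the optimizer settings and epoch count here are illustrative, not prescriptive.
#
#   device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#   optimizer = torch.optim.AdamW(model.parameters(), lr=0.0004, weight_decay=0.1)
#   train_losses, val_losses, tokens_seen = train_model_simple(
#       model, train_loader, val_loader, optimizer, device,
#       num_epochs=10, eval_freq=5, eval_iter=5,
#       start_context="Every effort moves you", tokenizer=tokenizer
#   )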


def evaluate_model(model, train_loader, val_loader, device, eval_iter):
    model.eval()
    with torch.no_grad():
        train_loss = calc_loss_loader(train_loader, model, device, num_batches=eval_iter)
        val_loss = calc_loss_loader(val_loader, model, device, num_batches=eval_iter)
    model.train()
    return train_loss, val_loss


def generate_and_print_sample(model, tokenizer, device, start_context):
    model.eval()
    context_size = model.pos_emb.weight.shape[0]
    encoded = text_to_token_ids(start_context, tokenizer).to(device)
    with torch.no_grad():
        token_ids = generate_text_simple(
            model=model, idx=encoded,
            max_new_tokens=50, context_size=context_size
        )
    decoded_text = token_ids_to_text(token_ids, tokenizer)
    print(decoded_text.replace("\n", " "))  # Compact print format
    model.train()


def assign(left, right):
    if left.shape != right.shape:
        raise ValueError(f"Shape mismatch. Left: {left.shape}, Right: {right.shape}")
    return torch.nn.Parameter(torch.tensor(right))


def load_weights_into_gpt(gpt, params):
    gpt.pos_emb.weight = assign(gpt.pos_emb.weight, params['wpe'])
    gpt.tok_emb.weight = assign(gpt.tok_emb.weight, params['wte'])

    for b in range(len(params["blocks"])):
        q_w, k_w, v_w = np.split(
            (params["blocks"][b]["attn"]["c_attn"])["w"], 3, axis=-1)
        gpt.trf_blocks[b].att.W_query.weight = assign(
            gpt.trf_blocks[b].att.W_query.weight, q_w.T)
        gpt.trf_blocks[b].att.W_key.weight = assign(
            gpt.trf_blocks[b].att.W_key.weight, k_w.T)
        gpt.trf_blocks[b].att.W_value.weight = assign(
            gpt.trf_blocks[b].att.W_value.weight, v_w.T)

        q_b, k_b, v_b = np.split(
            (params["blocks"][b]["attn"]["c_attn"])["b"], 3, axis=-1)
        gpt.trf_blocks[b].att.W_query.bias = assign(
            gpt.trf_blocks[b].att.W_query.bias, q_b)
        gpt.trf_blocks[b].att.W_key.bias = assign(
            gpt.trf_blocks[b].att.W_key.bias, k_b)
        gpt.trf_blocks[b].att.W_value.bias = assign(
            gpt.trf_blocks[b].att.W_value.bias, v_b)

        gpt.trf_blocks[b].att.out_proj.weight = assign(
            gpt.trf_blocks[b].att.out_proj.weight,
            params["blocks"][b]["attn"]["c_proj"]["w"].T)
        gpt.trf_blocks[b].att.out_proj.bias = assign(
            gpt.trf_blocks[b].att.out_proj.bias,
            params["blocks"][b]["attn"]["c_proj"]["b"])

        gpt.trf_blocks[b].ff.layers[0].weight = assign(
            gpt.trf_blocks[b].ff.layers[0].weight,
            params["blocks"][b]["mlp"]["c_fc"]["w"].T)
        gpt.trf_blocks[b].ff.layers[0].bias = assign(
            gpt.trf_blocks[b].ff.layers[0].bias,
            params["blocks"][b]["mlp"]["c_fc"]["b"])
        gpt.trf_blocks[b].ff.layers[2].weight = assign(
            gpt.trf_blocks[b].ff.layers[2].weight,
            params["blocks"][b]["mlp"]["c_proj"]["w"].T)
        gpt.trf_blocks[b].ff.layers[2].bias = assign(
            gpt.trf_blocks[b].ff.layers[2].bias,
            params["blocks"][b]["mlp"]["c_proj"]["b"])

        gpt.trf_blocks[b].norm1.scale = assign(
            gpt.trf_blocks[b].norm1.scale,
            params["blocks"][b]["ln_1"]["g"])
        gpt.trf_blocks[b].norm1.shift = assign(
            gpt.trf_blocks[b].norm1.shift,
            params["blocks"][b]["ln_1"]["b"])
        gpt.trf_blocks[b].norm2.scale = assign(
            gpt.trf_blocks[b].norm2.scale,
            params["blocks"][b]["ln_2"]["g"])
        gpt.trf_blocks[b].norm2.shift = assign(
            gpt.trf_blocks[b].norm2.shift,
            params["blocks"][b]["ln_2"]["b"])

    gpt.final_norm.scale = assign(gpt.final_norm.scale, params["g"])
    gpt.final_norm.shift = assign(gpt.final_norm.shift, params["b"])
    gpt.out_head.weight = assign(gpt.out_head.weight, params["wte"])


def text_to_token_ids(text, tokenizer):
    encoded = tokenizer.encode(text, allowed_special={"<|endoftext|>"})
    encoded_tensor = torch.tensor(encoded).unsqueeze(0)  # add batch dimension
    return encoded_tensor


def token_ids_to_text(token_ids, tokenizer):
    flat = token_ids.squeeze(0)  # remove batch dimension
    return tokenizer.decode(flat.tolist())
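
# Example usage (sketch): the two helpers above round-trip between text and a batched
# tensor of token IDs using a BPE tokenizer such as tiktoken's "gpt2" encoding.
#
#   import tiktoken
#   tokenizer = tiktoken.get_encoding("gpt2")
#   ids = text_to_token_ids("Hello, world", tokenizer)   # shape: (1, num_tokens)
#   text = token_ids_to_text(ids, tokenizer)             # "Hello, world"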


def calc_loss_batch(input_batch, target_batch, model, device):
    input_batch, target_batch = input_batch.to(device), target_batch.to(device)
    logits = model(input_batch)
    loss = torch.nn.functional.cross_entropy(logits.flatten(0, 1), target_batch.flatten())
    return loss


def calc_loss_loader(data_loader, model, device, num_batches=None):
    total_loss = 0.
    if len(data_loader) == 0:
        return float("nan")
    elif num_batches is None:
        num_batches = len(data_loader)
    else:
        # Reduce the number of batches to match the total number of batches in the data loader
        # if num_batches exceeds the number of batches in the data loader
        num_batches = min(num_batches, len(data_loader))
    for i, (input_batch, target_batch) in enumerate(data_loader):
        if i < num_batches:
            loss = calc_loss_batch(input_batch, target_batch, model, device)
            total_loss += loss.item()
        else:
            break
    return total_loss / num_batches


def plot_losses(epochs_seen, tokens_seen, train_losses, val_losses):
    fig, ax1 = plt.subplots(figsize=(5, 3))

    # Plot training and validation loss against epochs
    ax1.plot(epochs_seen, train_losses, label="Training loss")
    ax1.plot(epochs_seen, val_losses, linestyle="-.", label="Validation loss")
    ax1.set_xlabel("Epochs")
    ax1.set_ylabel("Loss")
    ax1.legend(loc="upper right")
    ax1.xaxis.set_major_locator(MaxNLocator(integer=True))  # only show integer labels on x-axis

    # Create a second x-axis for tokens seen
    ax2 = ax1.twiny()  # Create a second x-axis that shares the same y-axis
    ax2.plot(tokens_seen, train_losses, alpha=0)  # Invisible plot for aligning ticks
    ax2.set_xlabel("Tokens seen")

    fig.tight_layout()  # Adjust layout to make room
    plt.savefig("loss-plot.pdf")
    plt.show()
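
# Example usage (sketch): the epochs axis can be derived from the number of recorded
# evaluation points, e.g. via torch.linspace, and paired with the tracked token counts
# returned by train_model_simple.
#
#   epochs_tensor = torch.linspace(0, num_epochs, len(train_losses))
#   plot_losses(epochs_tensor, tokens_seen, train_losses, val_losses)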


def download_and_load_gpt2(model_size, models_dir):
    import tensorflow as tf

    # Validate model size
    allowed_sizes = ("124M", "355M", "774M", "1558M")
    if model_size not in allowed_sizes:
        raise ValueError(f"Model size not in {allowed_sizes}")

    # Define paths
    model_dir = os.path.join(models_dir, model_size)
    base_url = "https://openaipublic.blob.core.windows.net/gpt-2/models"
    backup_base_url = "https://f001.backblazeb2.com/file/LLMs-from-scratch/gpt2"
    filenames = [
        "checkpoint", "encoder.json", "hparams.json",
        "model.ckpt.data-00000-of-00001", "model.ckpt.index",
        "model.ckpt.meta", "vocab.bpe"
    ]

    # Download files
    os.makedirs(model_dir, exist_ok=True)
    for filename in filenames:
        file_url = os.path.join(base_url, model_size, filename)
        backup_url = os.path.join(backup_base_url, model_size, filename)
        file_path = os.path.join(model_dir, filename)
        download_file(file_url, file_path, backup_url)

    # Load settings and params
    tf_ckpt_path = tf.train.latest_checkpoint(model_dir)
    settings = json.load(open(os.path.join(model_dir, "hparams.json"), "r", encoding="utf-8"))
    params = load_gpt2_params_from_tf_ckpt(tf_ckpt_path, settings)

    return settings, params


def download_file(url, destination, backup_url=None):
    def _attempt_download(download_url):
        with urllib.request.urlopen(download_url) as response:
            # Get the total file size from headers, defaulting to 0 if not present
            file_size = int(response.headers.get("Content-Length", 0))

            # Check if file exists and has the same size
            if os.path.exists(destination):
                file_size_local = os.path.getsize(destination)
                if file_size == file_size_local:
                    print(f"File already exists and is up-to-date: {destination}")
                    return True  # Indicate success without re-downloading

            block_size = 1024  # 1 Kilobyte

            # Initialize the progress bar with total file size
            progress_bar_description = os.path.basename(download_url)
            with tqdm(total=file_size, unit="iB", unit_scale=True, desc=progress_bar_description) as progress_bar:
                with open(destination, "wb") as file:
                    while True:
                        chunk = response.read(block_size)
                        if not chunk:
                            break
                        file.write(chunk)
                        progress_bar.update(len(chunk))
            return True

    try:
        if _attempt_download(url):
            return
    except (urllib.error.HTTPError, urllib.error.URLError):
        if backup_url is not None:
            print(f"Primary URL ({url}) failed. Attempting backup URL: {backup_url}")
            try:
                if _attempt_download(backup_url):
                    return
            except urllib.error.HTTPError:
                pass

        # If we reach here, both attempts have failed
        error_message = (
            f"Failed to download from both primary URL ({url})"
            f"{' and backup URL (' + backup_url + ')' if backup_url else ''}."
            "\nCheck your internet connection or the file availability.\n"
            "For help, visit: https://github.com/rasbt/LLMs-from-scratch/discussions/273"
        )
        print(error_message)
    except Exception as e:
        print(f"An unexpected error occurred: {e}")


def load_gpt2_params_from_tf_ckpt(ckpt_path, settings):
    import tensorflow as tf

    # Initialize parameters dictionary with empty blocks for each layer
    params = {"blocks": [{} for _ in range(settings["n_layer"])]}

    # Iterate over each variable in the checkpoint
    for name, _ in tf.train.list_variables(ckpt_path):
        # Load the variable and remove singleton dimensions
        variable_array = np.squeeze(tf.train.load_variable(ckpt_path, name))

        # Process the variable name to extract relevant parts
        variable_name_parts = name.split("/")[1:]  # Skip the 'model/' prefix

        # Identify the target dictionary for the variable
        target_dict = params
        if variable_name_parts[0].startswith("h"):
            layer_number = int(variable_name_parts[0][1:])
            target_dict = params["blocks"][layer_number]

        # Recursively access or create nested dictionaries
        for key in variable_name_parts[1:-1]:
            target_dict = target_dict.setdefault(key, {})

        # Assign the variable array to the last key
        last_key = variable_name_parts[-1]
        target_dict[last_key] = variable_array

    return params
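

# The block below is an optional usage sketch, not part of the original module: it
# downloads the pretrained GPT-2 124M weights and loads them into a ch04-style GPTModel.
# The GPTModel import and the configuration keys are assumptions based on the chapter 4
# implementation; adjust them if your model class differs.
if __name__ == "__main__":
    from .ch04 import GPTModel  # assumed to be defined in the chapter 4 module
    import tiktoken

    BASE_CONFIG = {
        "vocab_size": 50257,     # GPT-2 BPE vocabulary size
        "context_length": 1024,  # context length of the pretrained GPT-2 weights
        "emb_dim": 768,          # embedding dimension of the 124M model
        "n_heads": 12,           # number of attention heads
        "n_layers": 12,          # number of transformer blocks
        "drop_rate": 0.0,        # no dropout needed for inference
        "qkv_bias": True         # GPT-2 uses bias terms in the query/key/value projections
    }

    # Download the OpenAI checkpoint files and load them into a nested parameter dict
    settings, params = download_and_load_gpt2(model_size="124M", models_dir="gpt2")

    # Build the model and copy the pretrained weights into it
    model = GPTModel(BASE_CONFIG)
    load_weights_into_gpt(model, params)
    model.eval()

    # Generate a short sample with the loaded weights
    tokenizer = tiktoken.get_encoding("gpt2")
    token_ids = generate(
        model=model,
        idx=text_to_token_ids("Every effort moves you", tokenizer),
        max_new_tokens=25,
        context_size=BASE_CONFIG["context_length"],
        top_k=50,
        temperature=1.0
    )
    print(token_ids_to_text(token_ids, tokenizer))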