# File: previous_chapters.py (9.6 KB)
# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
# Source for "Build a Large Language Model From Scratch"
#   - https://www.manning.com/books/build-a-large-language-model-from-scratch
# Code: https://github.com/rasbt/LLMs-from-scratch
# This file collects all the relevant code that we covered thus far
# throughout Chapters 2-4.
# This file can be run as a standalone script.
  8. import tiktoken
  9. import torch
  10. import torch.nn as nn
  11. from torch.utils.data import Dataset, DataLoader
#####################################
# Chapter 2
#####################################
  15. class GPTDatasetV1(Dataset):
  16. def __init__(self, txt, tokenizer, max_length, stride):
  17. self.tokenizer = tokenizer
  18. self.input_ids = []
  19. self.target_ids = []
  20. # Tokenize the entire text
  21. token_ids = tokenizer.encode(txt)
  22. # Use a sliding window to chunk the book into overlapping sequences of max_length
  23. for i in range(0, len(token_ids) - max_length, stride):
  24. input_chunk = token_ids[i:i + max_length]
  25. target_chunk = token_ids[i + 1: i + max_length + 1]
  26. self.input_ids.append(torch.tensor(input_chunk))
  27. self.target_ids.append(torch.tensor(target_chunk))
  28. def __len__(self):
  29. return len(self.input_ids)
  30. def __getitem__(self, idx):
  31. return self.input_ids[idx], self.target_ids[idx]
  32. def create_dataloader_v1(txt, batch_size=4, max_length=256,
  33. stride=128, shuffle=True, drop_last=True):
  34. # Initialize the tokenizer
  35. tokenizer = tiktoken.get_encoding("gpt2")
  36. # Create dataset
  37. dataset = GPTDatasetV1(txt, tokenizer, max_length, stride)
  38. # Create dataloader
  39. dataloader = DataLoader(
  40. dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last)
  41. return dataloader
#####################################
# Chapter 3
#####################################
  45. class MultiHeadAttention(nn.Module):
  46. def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
  47. super().__init__()
  48. assert d_out % num_heads == 0, "d_out must be divisible by num_heads"
  49. self.d_out = d_out
  50. self.num_heads = num_heads
  51. self.head_dim = d_out // num_heads # Reduce the projection dim to match desired output dim
  52. self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
  53. self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
  54. self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
  55. self.out_proj = nn.Linear(d_out, d_out) # Linear layer to combine head outputs
  56. self.dropout = nn.Dropout(dropout)
  57. self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1))
  58. def forward(self, x):
  59. b, num_tokens, d_in = x.shape
  60. keys = self.W_key(x) # Shape: (b, num_tokens, d_out)
  61. queries = self.W_query(x)
  62. values = self.W_value(x)
  63. # We implicitly split the matrix by adding a `num_heads` dimension
  64. # Unroll last dim: (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)
  65. keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
  66. values = values.view(b, num_tokens, self.num_heads, self.head_dim)
  67. queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)
  68. # Transpose: (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim)
  69. keys = keys.transpose(1, 2)
  70. queries = queries.transpose(1, 2)
  71. values = values.transpose(1, 2)
  72. # Compute scaled dot-product attention (aka self-attention) with a causal mask
  73. attn_scores = queries @ keys.transpose(2, 3) # Dot product for each head
  74. # Original mask truncated to the number of tokens and converted to boolean
  75. mask_bool = self.mask.bool()[:num_tokens, :num_tokens]
  76. # Use the mask to fill attention scores
  77. attn_scores.masked_fill_(mask_bool, -torch.inf)
  78. attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
  79. attn_weights = self.dropout(attn_weights)
  80. # Shape: (b, num_tokens, num_heads, head_dim)
  81. context_vec = (attn_weights @ values).transpose(1, 2)
  82. # Combine heads, where self.d_out = self.num_heads * self.head_dim
  83. context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)
  84. context_vec = self.out_proj(context_vec) # optional projection
  85. return context_vec
#####################################
# Chapter 4
#####################################
  89. class LayerNorm(nn.Module):
  90. def __init__(self, emb_dim):
  91. super().__init__()
  92. self.eps = 1e-5
  93. self.scale = nn.Parameter(torch.ones(emb_dim))
  94. self.shift = nn.Parameter(torch.zeros(emb_dim))
  95. def forward(self, x):
  96. mean = x.mean(dim=-1, keepdim=True)
  97. var = x.var(dim=-1, keepdim=True, unbiased=False)
  98. norm_x = (x - mean) / torch.sqrt(var + self.eps)
  99. return self.scale * norm_x + self.shift
  100. class GELU(nn.Module):
  101. def __init__(self):
  102. super().__init__()
  103. def forward(self, x):
  104. return 0.5 * x * (1 + torch.tanh(
  105. torch.sqrt(torch.tensor(2.0 / torch.pi)) *
  106. (x + 0.044715 * torch.pow(x, 3))
  107. ))
  108. class FeedForward(nn.Module):
  109. def __init__(self, cfg):
  110. super().__init__()
  111. self.layers = nn.Sequential(
  112. nn.Linear(cfg["emb_dim"], 4 * cfg["emb_dim"]),
  113. GELU(),
  114. nn.Linear(4 * cfg["emb_dim"], cfg["emb_dim"]),
  115. )
  116. def forward(self, x):
  117. return self.layers(x)
  118. class TransformerBlock(nn.Module):
  119. def __init__(self, cfg):
  120. super().__init__()
  121. self.att = MultiHeadAttention(
  122. d_in=cfg["emb_dim"],
  123. d_out=cfg["emb_dim"],
  124. context_length=cfg["context_length"],
  125. num_heads=cfg["n_heads"],
  126. dropout=cfg["drop_rate"],
  127. qkv_bias=cfg["qkv_bias"])
  128. self.ff = FeedForward(cfg)
  129. self.norm1 = LayerNorm(cfg["emb_dim"])
  130. self.norm2 = LayerNorm(cfg["emb_dim"])
  131. self.drop_resid = nn.Dropout(cfg["drop_rate"])
  132. def forward(self, x):
  133. # Shortcut connection for attention block
  134. shortcut = x
  135. x = self.norm1(x)
  136. x = self.att(x) # Shape [batch_size, num_tokens, emb_size]
  137. x = self.drop_resid(x)
  138. x = x + shortcut # Add the original input back
  139. # Shortcut connection for feed-forward block
  140. shortcut = x
  141. x = self.norm2(x)
  142. x = self.ff(x)
  143. x = self.drop_resid(x)
  144. x = x + shortcut # Add the original input back
  145. return x
  146. class GPTModel(nn.Module):
  147. def __init__(self, cfg):
  148. super().__init__()
  149. self.tok_emb = nn.Embedding(cfg["vocab_size"], cfg["emb_dim"])
  150. self.pos_emb = nn.Embedding(cfg["context_length"], cfg["emb_dim"])
  151. self.drop_emb = nn.Dropout(cfg["drop_rate"])
  152. self.trf_blocks = nn.Sequential(
  153. *[TransformerBlock(cfg) for _ in range(cfg["n_layers"])])
  154. self.final_norm = LayerNorm(cfg["emb_dim"])
  155. self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False)
  156. def forward(self, in_idx):
  157. batch_size, seq_len = in_idx.shape
  158. tok_embeds = self.tok_emb(in_idx)
  159. pos_embeds = self.pos_emb(torch.arange(seq_len, device=in_idx.device))
  160. x = tok_embeds + pos_embeds # Shape [batch_size, num_tokens, emb_size]
  161. x = self.drop_emb(x)
  162. x = self.trf_blocks(x)
  163. x = self.final_norm(x)
  164. logits = self.out_head(x)
  165. return logits
  166. def generate_text_simple(model, idx, max_new_tokens, context_size):
  167. # idx is (B, T) array of indices in the current context
  168. for _ in range(max_new_tokens):
  169. # Crop current context if it exceeds the supported context size
  170. # E.g., if LLM supports only 5 tokens, and the context size is 10
  171. # then only the last 5 tokens are used as context
  172. idx_cond = idx[:, -context_size:]
  173. # Get the predictions
  174. with torch.no_grad():
  175. logits = model(idx_cond)
  176. # Focus only on the last time step
  177. # (batch, n_token, vocab_size) becomes (batch, vocab_size)
  178. logits = logits[:, -1, :]
  179. # Get the idx of the vocab entry with the highest logits value
  180. idx_next = torch.argmax(logits, dim=-1, keepdim=True) # (batch, 1)
  181. # Append sampled index to the running sequence
  182. idx = torch.cat((idx, idx_next), dim=1) # (batch, n_tokens+1)
  183. return idx
  184. if __name__ == "__main__":
  185. GPT_CONFIG_124M = {
  186. "vocab_size": 50257, # Vocabulary size
  187. "context_length": 1024, # Context length
  188. "emb_dim": 768, # Embedding dimension
  189. "n_heads": 12, # Number of attention heads
  190. "n_layers": 12, # Number of layers
  191. "drop_rate": 0.1, # Dropout rate
  192. "qkv_bias": False # Query-Key-Value bias
  193. }
  194. torch.manual_seed(123)
  195. model = GPTModel(GPT_CONFIG_124M)
  196. model.eval() # disable dropout
  197. start_context = "Hello, I am"
  198. tokenizer = tiktoken.get_encoding("gpt2")
  199. encoded = tokenizer.encode(start_context)
  200. encoded_tensor = torch.tensor(encoded).unsqueeze(0)
  201. print(f"\n{50*'='}\n{22*' '}IN\n{50*'='}")
  202. print("\nInput text:", start_context)
  203. print("Encoded input text:", encoded)
  204. print("encoded_tensor.shape:", encoded_tensor.shape)
  205. out = generate_text_simple(
  206. model=model,
  207. idx=encoded_tensor,
  208. max_new_tokens=10,
  209. context_size=GPT_CONFIG_124M["context_length"]
  210. )
  211. decoded_text = tokenizer.decode(out.squeeze(0).tolist())
  212. print(f"\n\n{50*'='}\n{22*' '}OUT\n{50*'='}")
  213. print("\nOutput:", out)
  214. print("Output length:", len(out[0]))
  215. print("Output text:", decoded_text)