# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
# Source for "Build a Large Language Model From Scratch"
#   - https://www.manning.com/books/build-a-large-language-model-from-scratch
# Code: https://github.com/rasbt/LLMs-from-scratch

from .ch03 import MultiHeadAttention, PyTorchMultiHeadAttention

import torch
import torch.nn as nn


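# Layer normalization over the embedding dimension with learnable
# scale and shift parameters (eps guards against division by zero).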
class LayerNorm(nn.Module):
    def __init__(self, emb_dim):
        super().__init__()
        self.eps = 1e-5
        self.scale = nn.Parameter(torch.ones(emb_dim))
        self.shift = nn.Parameter(torch.zeros(emb_dim))

    def forward(self, x):
        mean = x.mean(dim=-1, keepdim=True)
        var = x.var(dim=-1, keepdim=True, unbiased=False)
        norm_x = (x - mean) / torch.sqrt(var + self.eps)
        return self.scale * norm_x + self.shift


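# Tanh-based approximation of the GELU activation (matches the approximation
# used in the original GPT-2).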
class GELU(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return 0.5 * x * (1 + torch.tanh(
            torch.sqrt(torch.tensor(2.0 / torch.pi)) *
            (x + 0.044715 * torch.pow(x, 3))
        ))


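# Position-wise feed-forward network: expands the embedding dimension by 4x,
# applies GELU, and projects back down to the embedding dimension.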
class FeedForward(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(cfg["emb_dim"], 4 * cfg["emb_dim"]),
            GELU(),
            nn.Linear(4 * cfg["emb_dim"], cfg["emb_dim"]),
        )

    def forward(self, x):
        return self.layers(x)


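# Pre-LayerNorm transformer block: multi-head self-attention followed by the
# feed-forward network, each wrapped in a dropout + residual (shortcut) connection.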
class TransformerBlock(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.att = MultiHeadAttention(
            d_in=cfg["emb_dim"],
            d_out=cfg["emb_dim"],
            context_length=cfg["context_length"],
            num_heads=cfg["n_heads"],
            dropout=cfg["drop_rate"],
            qkv_bias=cfg["qkv_bias"])
        self.ff = FeedForward(cfg)
        self.norm1 = LayerNorm(cfg["emb_dim"])
        self.norm2 = LayerNorm(cfg["emb_dim"])
        self.drop_resid = nn.Dropout(cfg["drop_rate"])

    def forward(self, x):
        # Shortcut connection for attention block
        shortcut = x
        x = self.norm1(x)
        x = self.att(x)   # Shape [batch_size, num_tokens, emb_size]
        x = self.drop_resid(x)
        x = x + shortcut  # Add the original input back

        # Shortcut connection for feed-forward block
        shortcut = x
        x = self.norm2(x)
        x = self.ff(x)
        x = self.drop_resid(x)
        x = x + shortcut  # Add the original input back

        return x


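# The full GPT model: token and absolute positional embeddings, embedding
# dropout, a stack of transformer blocks, a final layer norm, and a linear
# output head that maps hidden states to vocabulary logits.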
class GPTModel(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.tok_emb = nn.Embedding(cfg["vocab_size"], cfg["emb_dim"])
        self.pos_emb = nn.Embedding(cfg["context_length"], cfg["emb_dim"])
        self.drop_emb = nn.Dropout(cfg["drop_rate"])

        self.trf_blocks = nn.Sequential(
            *[TransformerBlock(cfg) for _ in range(cfg["n_layers"])])

        self.final_norm = LayerNorm(cfg["emb_dim"])
        self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False)

    def forward(self, in_idx):
        batch_size, seq_len = in_idx.shape
        tok_embeds = self.tok_emb(in_idx)
        pos_embeds = self.pos_emb(torch.arange(seq_len, device=in_idx.device))
        x = tok_embeds + pos_embeds  # Shape [batch_size, num_tokens, emb_size]
        x = self.drop_emb(x)
        x = self.trf_blocks(x)
        x = self.final_norm(x)
        logits = self.out_head(x)
        return logits


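# Greedy decoding: repeatedly crops the context to the supported length, runs
# the model, and appends the most likely next token, max_new_tokens times.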
def generate_text_simple(model, idx, max_new_tokens, context_size):
    # idx is a (batch, num_tokens) tensor of token indices in the current context
    for _ in range(max_new_tokens):

        # Crop the current context if it exceeds the supported context size
        # E.g., if the LLM supports only 5 tokens and the current context is
        # 10 tokens long, then only the last 5 tokens are used as context
        idx_cond = idx[:, -context_size:]

        # Get the predictions
        with torch.no_grad():
            logits = model(idx_cond)

        # Focus only on the last time step
        # (batch, n_tokens, vocab_size) becomes (batch, vocab_size)
        logits = logits[:, -1, :]

        # Get the index of the vocab entry with the highest logit value
        idx_next = torch.argmax(logits, dim=-1, keepdim=True)  # (batch, 1)

        # Append the chosen index to the running sequence
        idx = torch.cat((idx, idx_next), dim=1)  # (batch, n_tokens+1)

    return idx


######################
# Bonus
######################

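# Same as FeedForward above, but uses PyTorch's built-in tanh-approximate GELU.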
class FeedForwardFast(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(cfg["emb_dim"], 4 * cfg["emb_dim"]),
            nn.GELU(approximate="tanh"),
            nn.Linear(4 * cfg["emb_dim"], cfg["emb_dim"]),
        )

    def forward(self, x):
        return self.layers(x)


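# Same as TransformerBlock above, but uses nn.LayerNorm and the ch03
# PyTorchMultiHeadAttention (based on scaled_dot_product_attention).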
class TransformerBlockFast(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.att = PyTorchMultiHeadAttention(
            d_in=cfg["emb_dim"],
            d_out=cfg["emb_dim"],
            num_heads=cfg["n_heads"],
            dropout=cfg["drop_rate"],
            qkv_bias=cfg["qkv_bias"])
        self.ff = FeedForwardFast(cfg)
        self.norm1 = nn.LayerNorm(cfg["emb_dim"])
        self.norm2 = nn.LayerNorm(cfg["emb_dim"])
        self.drop_shortcut = nn.Dropout(cfg["drop_rate"])

    def forward(self, x):
        # Shortcut connection for attention block
        shortcut = x
        x = self.norm1(x)
        x = self.att(x)   # Shape [batch_size, num_tokens, emb_size]
        x = self.drop_shortcut(x)
        x = x + shortcut  # Add the original input back

        # Shortcut connection for feed-forward block
        shortcut = x
        x = self.norm2(x)
        x = self.ff(x)
        x = self.drop_shortcut(x)
        x = x + shortcut  # Add the original input back

        return x


class GPTModelFast(nn.Module):
    """
    A faster variant of GPTModel optimized for training speed.

    This version is only marginally faster on CPU (~1.02x) but significantly
    faster on GPU (~2.05x) during training, thanks to optimized CUDA kernels
    and FlashAttention support.

    Key differences from the original GPTModel:
    1. Uses PyTorch's built-in LayerNorm instead of a custom implementation.
    2. Uses PyTorch's built-in GELU instead of a custom implementation.
    3. Uses PyTorch's scaled_dot_product_attention instead of a custom
       MultiHeadAttention.
    4. Automatically enables FlashAttention on compatible GPUs.
    """
    def __init__(self, cfg):
        super().__init__()
        self.tok_emb = nn.Embedding(cfg["vocab_size"], cfg["emb_dim"])
        self.pos_emb = nn.Embedding(cfg["context_length"], cfg["emb_dim"])
        self.drop_emb = nn.Dropout(cfg["drop_rate"])

        self.trf_blocks = nn.Sequential(
            *[TransformerBlockFast(cfg) for _ in range(cfg["n_layers"])])

        self.final_norm = nn.LayerNorm(cfg["emb_dim"])
        self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False)

    def forward(self, in_idx):
        batch_size, seq_len = in_idx.shape
        tok_embeds = self.tok_emb(in_idx)
        pos_embeds = self.pos_emb(torch.arange(seq_len, device=in_idx.device))
        x = tok_embeds + pos_embeds
        x = self.drop_emb(x)
        x = self.trf_blocks(x)
        x = self.final_norm(x)
        logits = self.out_head(x)
        return logits
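

# ----------------------------------------------------------------------------
# Example usage: a minimal sketch added for illustration, not part of the
# book's original module. The GPT-2 (124M)-style config values below are
# assumptions; adjust them to match the chapter's settings. Because of the
# relative import at the top, run this as a module from the package root,
# e.g. `python -m llms_from_scratch.ch04` (package name assumed).
# ----------------------------------------------------------------------------
if __name__ == "__main__":
    GPT_CONFIG_124M = {
        "vocab_size": 50257,     # GPT-2 BPE vocabulary size
        "context_length": 1024,  # Maximum number of input tokens
        "emb_dim": 768,          # Embedding dimension
        "n_heads": 12,           # Number of attention heads
        "n_layers": 12,          # Number of transformer blocks
        "drop_rate": 0.1,        # Dropout rate
        "qkv_bias": False,       # Query/key/value projection bias
    }

    torch.manual_seed(123)
    model = GPTModel(GPT_CONFIG_124M)
    model.eval()  # disable dropout for deterministic generation

    # A dummy batch of token IDs; in the book these come from a tiktoken tokenizer
    idx = torch.randint(0, GPT_CONFIG_124M["vocab_size"], (1, 4))

    out = generate_text_simple(
        model, idx,
        max_new_tokens=6,
        context_size=GPT_CONFIG_124M["context_length"],
    )
    print("Input shape: ", idx.shape)   # (1, 4)
    print("Output shape:", out.shape)   # (1, 10) = 4 context + 6 new tokens

    # The "fast" variant accepts the same config and produces same-shaped logits
    fast_model = GPTModelFast(GPT_CONFIG_124M)
    fast_model.eval()
    with torch.no_grad():
        logits = fast_model(idx)
    print("Fast logits shape:", logits.shape)  # (1, 4, 50257)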