gpt_with_kv_cache.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353
  1. # This file collects all the relevant code that we covered thus far
  2. # throughout Chapters 3-4.
  3. # This file can be run as a standalone script.
  4. import time
  5. import tiktoken
  6. import torch
  7. import torch.nn as nn
  8. #####################################
  9. # Chapter 3
  10. #####################################
  11. class MultiHeadAttention(nn.Module):
  12. def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
  13. super().__init__()
  14. assert d_out % num_heads == 0, "d_out must be divisible by num_heads"
  15. self.d_out = d_out
  16. self.num_heads = num_heads
  17. self.head_dim = d_out // num_heads # Reduce the projection dim to match desired output dim
  18. self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
  19. self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
  20. self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
  21. self.out_proj = nn.Linear(d_out, d_out) # Linear layer to combine head outputs
  22. self.dropout = nn.Dropout(dropout)
  23. self.register_buffer(
  24. "mask",
  25. torch.triu(torch.ones(context_length, context_length),diagonal=1),
  26. persistent=False
  27. )
  28. ####################################################
  29. # NEW
  30. self.register_buffer("cache_k", None, persistent=False)
  31. self.register_buffer("cache_v", None, persistent=False)
  32. ####################################################
  33. def forward(self, x, use_cache=False):
  34. b, num_tokens, d_in = x.shape
  35. keys_new = self.W_key(x) # Shape: (b, num_tokens, d_out)
  36. values_new = self.W_value(x)
  37. queries = self.W_query(x)
  38. # We implicitly split the matrix by adding a `num_heads` dimension
  39. # Unroll last dim: (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)
  40. keys_new = keys_new.view(b, num_tokens, self.num_heads, self.head_dim)
  41. values_new = values_new.view(b, num_tokens, self.num_heads, self.head_dim)
  42. queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)
  43. ####################################################
  44. # NEW
  45. if use_cache:
  46. if self.cache_k is None:
  47. self.cache_k, self.cache_v = keys_new, values_new
  48. else:
  49. self.cache_k = torch.cat([self.cache_k, keys_new], dim=1)
  50. self.cache_v = torch.cat([self.cache_v, values_new], dim=1)
  51. keys, values = self.cache_k, self.cache_v
  52. else:
  53. keys, values = keys_new, values_new
  54. ####################################################
  55. # Transpose: (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim)
  56. keys = keys.transpose(1, 2)
  57. queries = queries.transpose(1, 2)
  58. values = values.transpose(1, 2)
  59. # Compute scaled dot-product attention (aka self-attention) with a causal mask
  60. attn_scores = queries @ keys.transpose(2, 3) # Dot product for each head
  61. # Original mask truncated to the number of tokens and converted to boolean
  62. mask_bool = self.mask.bool()[:num_tokens, :num_tokens]
  63. # Use the mask to fill attention scores
  64. attn_scores.masked_fill_(mask_bool, -torch.inf)
  65. attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
  66. attn_weights = self.dropout(attn_weights)
  67. # Shape: (b, num_tokens, num_heads, head_dim)
  68. context_vec = (attn_weights @ values).transpose(1, 2)
  69. # Combine heads, where self.d_out = self.num_heads * self.head_dim
  70. context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)
  71. context_vec = self.out_proj(context_vec) # optional projection
  72. return context_vec
  73. ####################################################
  74. # NEW
  75. def reset_cache(self):
  76. self.cache_k, self.cache_v = None, None
  77. ####################################################
  78. #####################################
  79. # Chapter 4
  80. #####################################
  81. class LayerNorm(nn.Module):
  82. def __init__(self, emb_dim):
  83. super().__init__()
  84. self.eps = 1e-5
  85. self.scale = nn.Parameter(torch.ones(emb_dim))
  86. self.shift = nn.Parameter(torch.zeros(emb_dim))
  87. def forward(self, x):
  88. mean = x.mean(dim=-1, keepdim=True)
  89. var = x.var(dim=-1, keepdim=True, unbiased=False)
  90. norm_x = (x - mean) / torch.sqrt(var + self.eps)
  91. return self.scale * norm_x + self.shift
  92. class GELU(nn.Module):
  93. def __init__(self):
  94. super().__init__()
  95. def forward(self, x):
  96. return 0.5 * x * (1 + torch.tanh(
  97. torch.sqrt(torch.tensor(2.0 / torch.pi)) *
  98. (x + 0.044715 * torch.pow(x, 3))
  99. ))
  100. class FeedForward(nn.Module):
  101. def __init__(self, cfg):
  102. super().__init__()
  103. self.layers = nn.Sequential(
  104. nn.Linear(cfg["emb_dim"], 4 * cfg["emb_dim"]),
  105. GELU(),
  106. nn.Linear(4 * cfg["emb_dim"], cfg["emb_dim"]),
  107. )
  108. def forward(self, x):
  109. return self.layers(x)
  110. class TransformerBlock(nn.Module):
  111. def __init__(self, cfg):
  112. super().__init__()
  113. self.att = MultiHeadAttention(
  114. d_in=cfg["emb_dim"],
  115. d_out=cfg["emb_dim"],
  116. context_length=cfg["context_length"],
  117. num_heads=cfg["n_heads"],
  118. dropout=cfg["drop_rate"],
  119. qkv_bias=cfg["qkv_bias"])
  120. self.ff = FeedForward(cfg)
  121. self.norm1 = LayerNorm(cfg["emb_dim"])
  122. self.norm2 = LayerNorm(cfg["emb_dim"])
  123. self.drop_shortcut = nn.Dropout(cfg["drop_rate"])
  124. def forward(self, x, use_cache=False):
  125. # Shortcut connection for attention block
  126. shortcut = x
  127. x = self.norm1(x)
  128. # x = self.att(x) # Shape [batch_size, num_tokens, emb_size]
  129. ####################################################
  130. # NEW
  131. x = self.att(x, use_cache=use_cache)
  132. ####################################################
  133. x = self.drop_shortcut(x)
  134. x = x + shortcut # Add the original input back
  135. # Shortcut connection for feed-forward block
  136. shortcut = x
  137. x = self.norm2(x)
  138. x = self.ff(x)
  139. x = self.drop_shortcut(x)
  140. x = x + shortcut # Add the original input back
  141. return x
  142. class GPTModel(nn.Module):
  143. def __init__(self, cfg):
  144. super().__init__()
  145. self.tok_emb = nn.Embedding(cfg["vocab_size"], cfg["emb_dim"])
  146. self.pos_emb = nn.Embedding(cfg["context_length"], cfg["emb_dim"])
  147. self.drop_emb = nn.Dropout(cfg["drop_rate"])
  148. # self.trf_blocks = nn.Sequential(
  149. # *[TransformerBlock(cfg) for _ in range(cfg["n_layers"])])
  150. ####################################################
  151. # NEW
  152. self.trf_blocks = nn.ModuleList(
  153. [TransformerBlock(cfg) for _ in range(cfg["n_layers"])])
  154. self.current_pos = 0
  155. ####################################################
  156. self.final_norm = LayerNorm(cfg["emb_dim"])
  157. self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False)
  158. def forward(self, in_idx, use_cache=False):
  159. batch_size, seq_len = in_idx.shape
  160. tok_embeds = self.tok_emb(in_idx)
  161. # pos_embeds = self.pos_emb(torch.arange(seq_len, device=in_idx.device))
  162. ####################################################
  163. # NEW
  164. if use_cache:
  165. pos_ids = torch.arange(self.current_pos, self.current_pos + seq_len, device=in_idx.device, dtype=torch.long)
  166. self.current_pos += seq_len
  167. else:
  168. pos_ids = torch.arange(0, seq_len, device=in_idx.device, dtype=torch.long)
  169. pos_embeds = self.pos_emb(pos_ids).unsqueeze(0)
  170. ####################################################
  171. x = tok_embeds + pos_embeds # Shape [batch_size, num_tokens, emb_size]
  172. x = self.drop_emb(x)
  173. # x = self.trf_blocks(x)
  174. ####################################################
  175. # NEW
  176. for blk in self.trf_blocks:
  177. x = blk(x, use_cache=use_cache)
  178. ####################################################
  179. x = self.final_norm(x)
  180. logits = self.out_head(x)
  181. return logits
  182. ####################################################
  183. # NEW
  184. def reset_kv_cache(self):
  185. for blk in self.trf_blocks:
  186. blk.att.reset_cache()
  187. self.current_pos = 0
  188. ####################################################
  189. def generate_text_simple(model, idx, max_new_tokens, context_size):
  190. # idx is (B, T) array of indices in the current context
  191. for _ in range(max_new_tokens):
  192. # Crop current context if it exceeds the supported context size
  193. # E.g., if LLM supports only 5 tokens, and the context size is 10
  194. # then only the last 5 tokens are used as context
  195. idx_cond = idx[:, -context_size:]
  196. # Get the predictions
  197. with torch.no_grad():
  198. logits = model(idx_cond)
  199. # Focus only on the last time step
  200. # (batch, n_token, vocab_size) becomes (batch, vocab_size)
  201. logits = logits[:, -1, :]
  202. # Get the idx of the vocab entry with the highest logits value
  203. idx_next = torch.argmax(logits, dim=-1, keepdim=True) # (batch, 1)
  204. # Append sampled index to the running sequence
  205. idx = torch.cat((idx, idx_next), dim=1) # (batch, n_tokens+1)
  206. return idx
  207. ####################################################
  208. # NEW
  209. def generate_text_simple_cached(model, idx, max_new_tokens):
  210. model.eval()
  211. model.reset_kv_cache()
  212. # Init cache with full prompt
  213. logits = model(idx, use_cache=True)
  214. for _ in range(max_new_tokens):
  215. last_logits = logits[:, -1]
  216. next_idx = last_logits.argmax(dim=-1, keepdim=True)
  217. idx = torch.cat([idx, next_idx], dim=1)
  218. logits = model(next_idx, use_cache=True)
  219. return idx
  220. ####################################################
  221. def main():
  222. GPT_CONFIG_124M = {
  223. "vocab_size": 50257, # Vocabulary size
  224. "context_length": 1024, # Context length
  225. "emb_dim": 768, # Embedding dimension
  226. "n_heads": 12, # Number of attention heads
  227. "n_layers": 12, # Number of layers
  228. "drop_rate": 0.1, # Dropout rate
  229. "qkv_bias": False # Query-Key-Value bias
  230. }
  231. torch.manual_seed(123)
  232. model = GPTModel(GPT_CONFIG_124M)
  233. device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  234. model.to(device)
  235. model.eval() # disable dropout
  236. start_context = "Hello, I am"
  237. tokenizer = tiktoken.get_encoding("gpt2")
  238. encoded = tokenizer.encode(start_context)
  239. encoded_tensor = torch.tensor(encoded, device=device).unsqueeze(0)
  240. print(f"\n{50*'='}\n{22*' '}IN\n{50*'='}")
  241. print("\nInput text:", start_context)
  242. print("Encoded input text:", encoded)
  243. print("encoded_tensor.shape:", encoded_tensor.shape)
  244. if torch.cuda.is_available():
  245. torch.cuda.synchronize()
  246. start = time.time()
  247. # token_ids = generate_text_simple(
  248. # model=model,
  249. # idx=encoded_tensor,
  250. # max_new_tokens=200,
  251. # context_size=GPT_CONFIG_124M["context_length"]
  252. # )
  253. ####################################################
  254. # NEW
  255. token_ids = generate_text_simple_cached(
  256. model=model,
  257. idx=encoded_tensor,
  258. max_new_tokens=200,
  259. )
  260. ####################################################
  261. if torch.cuda.is_available():
  262. torch.cuda.synchronize()
  263. total_time = time.time() - start
  264. decoded_text = tokenizer.decode(token_ids.squeeze(0).tolist())
  265. print(f"\n\n{50*'='}\n{22*' '}OUT\n{50*'='}")
  266. print("\nOutput:", token_ids)
  267. print("Output length:", len(token_ids[0]))
  268. print("Output text:", decoded_text)
  269. print(f"\nTime: {total_time:.2f} sec")
  270. print(f"{int(len(token_ids[0])/total_time)} tokens/sec")
  271. if torch.cuda.is_available():
  272. max_mem_bytes = torch.cuda.max_memory_allocated()
  273. max_mem_gb = max_mem_bytes / (1024 ** 3)
  274. print(f"Max memory allocated: {max_mem_gb:.2f} GB")
  275. if __name__ == "__main__":
  276. main()