# gpt_with_kv_cache.py
#
# This file collects all the relevant code that we covered thus far
# throughout Chapters 3-4.
# This file can be run as a standalone script.
  4. import time
  5. import tiktoken
  6. import torch
  7. import torch.nn as nn
#####################################
# Chapter 3
#####################################
  11. class MultiHeadAttention(nn.Module):
  12. def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
  13. super().__init__()
  14. assert d_out % num_heads == 0, "d_out must be divisible by num_heads"
  15. self.d_out = d_out
  16. self.num_heads = num_heads
  17. self.head_dim = d_out // num_heads # Reduce the projection dim to match desired output dim
  18. self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
  19. self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
  20. self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
  21. self.out_proj = nn.Linear(d_out, d_out) # Linear layer to combine head outputs
  22. self.dropout = nn.Dropout(dropout)
  23. self.register_buffer(
  24. "mask",
  25. torch.triu(torch.ones(context_length, context_length), diagonal=1),
  26. persistent=False
  27. )
  28. ####################################################
  29. # NEW
  30. self.register_buffer("cache_k", None, persistent=False)
  31. self.register_buffer("cache_v", None, persistent=False)
  32. self.ptr_current_pos = 0
  33. ####################################################
  34. def forward(self, x, use_cache=False):
  35. b, num_tokens, d_in = x.shape
  36. keys_new = self.W_key(x) # Shape: (b, num_tokens, d_out)
  37. values_new = self.W_value(x)
  38. queries = self.W_query(x)
  39. # We implicitly split the matrix by adding a `num_heads` dimension
  40. # Unroll last dim: (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)
  41. keys_new = keys_new.view(b, num_tokens, self.num_heads, self.head_dim)
  42. values_new = values_new.view(b, num_tokens, self.num_heads, self.head_dim)
  43. queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)
  44. ####################################################
  45. # NEW
  46. if use_cache:
  47. if self.cache_k is None:
  48. self.cache_k, self.cache_v = keys_new, values_new
  49. else:
  50. self.cache_k = torch.cat([self.cache_k, keys_new], dim=1)
  51. self.cache_v = torch.cat([self.cache_v, values_new], dim=1)
  52. keys, values = self.cache_k, self.cache_v
  53. else:
  54. keys, values = keys_new, values_new
  55. ####################################################
  56. # Transpose: (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim)
  57. keys = keys.transpose(1, 2)
  58. queries = queries.transpose(1, 2)
  59. values = values.transpose(1, 2)
  60. # Compute scaled dot-product attention (aka self-attention) with a causal mask
  61. attn_scores = queries @ keys.transpose(2, 3) # Dot product for each head
  62. ####################################################
  63. # NEW
  64. num_tokens_Q = queries.shape[-2]
  65. num_tokens_K = keys.shape[-2]
  66. if use_cache:
  67. mask_bool = self.mask.bool()[
  68. self.ptr_current_pos:self.ptr_current_pos + num_tokens_Q, :num_tokens_K
  69. ]
  70. self.ptr_current_pos += num_tokens_Q
  71. ####################################################
  72. # Original mask truncated to the number of tokens and converted to boolean
  73. else:
  74. mask_bool = self.mask.bool()[:num_tokens_Q, :num_tokens_K]
  75. # Use the mask to fill attention scores
  76. attn_scores.masked_fill_(mask_bool, -torch.inf)
  77. attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
  78. attn_weights = self.dropout(attn_weights)
  79. # Shape: (b, num_tokens, num_heads, head_dim)
  80. context_vec = (attn_weights @ values).transpose(1, 2)
  81. # Combine heads, where self.d_out = self.num_heads * self.head_dim
  82. context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)
  83. context_vec = self.out_proj(context_vec) # optional projection
  84. return context_vec
  85. ####################################################
  86. # NEW
  87. def reset_cache(self):
  88. self.cache_k, self.cache_v = None, None
  89. self.ptr_current_pos = 0
  90. ####################################################
#####################################
# Chapter 4
#####################################
  94. class LayerNorm(nn.Module):
  95. def __init__(self, emb_dim):
  96. super().__init__()
  97. self.eps = 1e-5
  98. self.scale = nn.Parameter(torch.ones(emb_dim))
  99. self.shift = nn.Parameter(torch.zeros(emb_dim))
  100. def forward(self, x):
  101. mean = x.mean(dim=-1, keepdim=True)
  102. var = x.var(dim=-1, keepdim=True, unbiased=False)
  103. norm_x = (x - mean) / torch.sqrt(var + self.eps)
  104. return self.scale * norm_x + self.shift
  105. class GELU(nn.Module):
  106. def __init__(self):
  107. super().__init__()
  108. def forward(self, x):
  109. return 0.5 * x * (1 + torch.tanh(
  110. torch.sqrt(torch.tensor(2.0 / torch.pi)) *
  111. (x + 0.044715 * torch.pow(x, 3))
  112. ))
  113. class FeedForward(nn.Module):
  114. def __init__(self, cfg):
  115. super().__init__()
  116. self.layers = nn.Sequential(
  117. nn.Linear(cfg["emb_dim"], 4 * cfg["emb_dim"]),
  118. GELU(),
  119. nn.Linear(4 * cfg["emb_dim"], cfg["emb_dim"]),
  120. )
  121. def forward(self, x):
  122. return self.layers(x)
  123. class TransformerBlock(nn.Module):
  124. def __init__(self, cfg):
  125. super().__init__()
  126. self.att = MultiHeadAttention(
  127. d_in=cfg["emb_dim"],
  128. d_out=cfg["emb_dim"],
  129. context_length=cfg["context_length"],
  130. num_heads=cfg["n_heads"],
  131. dropout=cfg["drop_rate"],
  132. qkv_bias=cfg["qkv_bias"])
  133. self.ff = FeedForward(cfg)
  134. self.norm1 = LayerNorm(cfg["emb_dim"])
  135. self.norm2 = LayerNorm(cfg["emb_dim"])
  136. self.drop_shortcut = nn.Dropout(cfg["drop_rate"])
  137. def forward(self, x, use_cache=False):
  138. # Shortcut connection for attention block
  139. shortcut = x
  140. x = self.norm1(x)
  141. # x = self.att(x) # Shape [batch_size, num_tokens, emb_size]
  142. ####################################################
  143. # NEW
  144. x = self.att(x, use_cache=use_cache)
  145. ####################################################
  146. x = self.drop_shortcut(x)
  147. x = x + shortcut # Add the original input back
  148. # Shortcut connection for feed-forward block
  149. shortcut = x
  150. x = self.norm2(x)
  151. x = self.ff(x)
  152. x = self.drop_shortcut(x)
  153. x = x + shortcut # Add the original input back
  154. return x
  155. class GPTModel(nn.Module):
  156. def __init__(self, cfg):
  157. super().__init__()
  158. self.tok_emb = nn.Embedding(cfg["vocab_size"], cfg["emb_dim"])
  159. self.pos_emb = nn.Embedding(cfg["context_length"], cfg["emb_dim"])
  160. self.drop_emb = nn.Dropout(cfg["drop_rate"])
  161. # self.trf_blocks = nn.Sequential(
  162. # *[TransformerBlock(cfg) for _ in range(cfg["n_layers"])])
  163. ####################################################
  164. # NEW
  165. self.trf_blocks = nn.ModuleList(
  166. [TransformerBlock(cfg) for _ in range(cfg["n_layers"])])
  167. self.current_pos = 0
  168. ####################################################
  169. self.final_norm = LayerNorm(cfg["emb_dim"])
  170. self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False)
  171. def forward(self, in_idx, use_cache=False):
  172. batch_size, seq_len = in_idx.shape
  173. tok_embeds = self.tok_emb(in_idx)
  174. # pos_embeds = self.pos_emb(torch.arange(seq_len, device=in_idx.device))
  175. ####################################################
  176. # NEW
  177. if use_cache:
  178. pos_ids = torch.arange(self.current_pos, self.current_pos + seq_len, device=in_idx.device, dtype=torch.long)
  179. self.current_pos += seq_len
  180. else:
  181. pos_ids = torch.arange(0, seq_len, device=in_idx.device, dtype=torch.long)
  182. pos_embeds = self.pos_emb(pos_ids).unsqueeze(0)
  183. ####################################################
  184. x = tok_embeds + pos_embeds # Shape [batch_size, num_tokens, emb_size]
  185. x = self.drop_emb(x)
  186. # x = self.trf_blocks(x)
  187. ####################################################
  188. # NEW
  189. for blk in self.trf_blocks:
  190. x = blk(x, use_cache=use_cache)
  191. ####################################################
  192. x = self.final_norm(x)
  193. logits = self.out_head(x)
  194. return logits
  195. ####################################################
  196. # NEW
  197. def reset_kv_cache(self):
  198. for blk in self.trf_blocks:
  199. blk.att.reset_cache()
  200. self.current_pos = 0
  201. ####################################################
  202. def generate_text_simple(model, idx, max_new_tokens, context_size):
  203. # idx is (B, T) array of indices in the current context
  204. for _ in range(max_new_tokens):
  205. # Crop current context if it exceeds the supported context size
  206. # E.g., if LLM supports only 5 tokens, and the context size is 10
  207. # then only the last 5 tokens are used as context
  208. idx_cond = idx[:, -context_size:]
  209. # Get the predictions
  210. with torch.no_grad():
  211. logits = model(idx_cond)
  212. # Focus only on the last time step
  213. # (batch, n_token, vocab_size) becomes (batch, vocab_size)
  214. logits = logits[:, -1, :]
  215. # Get the idx of the vocab entry with the highest logits value
  216. idx_next = torch.argmax(logits, dim=-1, keepdim=True) # (batch, 1)
  217. # Append sampled index to the running sequence
  218. idx = torch.cat((idx, idx_next), dim=1) # (batch, n_tokens+1)
  219. return idx
####################################################
# NEW
  222. def generate_text_simple_cached(model, idx, max_new_tokens,
  223. context_size=None, use_cache=True):
  224. model.eval()
  225. ctx_len = context_size or model.pos_emb.num_embeddings
  226. with torch.no_grad():
  227. if use_cache:
  228. # Init cache with full prompt
  229. model.reset_kv_cache()
  230. logits = model(idx[:, -ctx_len:], use_cache=True)
  231. for _ in range(max_new_tokens):
  232. # a) pick the token with the highest log-probability (greedy sampling)
  233. next_idx = logits[:, -1].argmax(dim=-1, keepdim=True)
  234. # b) append it to the running sequence
  235. idx = torch.cat([idx, next_idx], dim=1)
  236. # c) feed model only the new token
  237. logits = model(next_idx, use_cache=True)
  238. else:
  239. for _ in range(max_new_tokens):
  240. logits = model(idx[:, -ctx_len:], use_cache=False)
  241. next_idx = logits[:, -1].argmax(dim=-1, keepdim=True)
  242. idx = torch.cat([idx, next_idx], dim=1)
  243. return idx
####################################################
  245. def main():
  246. GPT_CONFIG_124M = {
  247. "vocab_size": 50257, # Vocabulary size
  248. "context_length": 1024, # Context length
  249. "emb_dim": 768, # Embedding dimension
  250. "n_heads": 12, # Number of attention heads
  251. "n_layers": 12, # Number of layers
  252. "drop_rate": 0.1, # Dropout rate
  253. "qkv_bias": False # Query-Key-Value bias
  254. }
  255. torch.manual_seed(123)
  256. model = GPTModel(GPT_CONFIG_124M)
  257. device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  258. model.to(device)
  259. model.eval() # disable dropout
  260. start_context = "Hello, I am"
  261. tokenizer = tiktoken.get_encoding("gpt2")
  262. encoded = tokenizer.encode(start_context)
  263. encoded_tensor = torch.tensor(encoded, device=device).unsqueeze(0)
  264. print(f"\n{50*'='}\n{22*' '}IN\n{50*'='}")
  265. print("\nInput text:", start_context)
  266. print("Encoded input text:", encoded)
  267. print("encoded_tensor.shape:", encoded_tensor.shape)
  268. if torch.cuda.is_available():
  269. torch.cuda.synchronize()
  270. start = time.time()
  271. # token_ids = generate_text_simple(
  272. # model=model,
  273. # idx=encoded_tensor,
  274. # max_new_tokens=200,
  275. # context_size=GPT_CONFIG_124M["context_length"]
  276. # )
  277. ####################################################
  278. # NEW
  279. token_ids = generate_text_simple_cached(
  280. model=model,
  281. idx=encoded_tensor,
  282. max_new_tokens=200,
  283. )
  284. ####################################################
  285. if torch.cuda.is_available():
  286. torch.cuda.synchronize()
  287. total_time = time.time() - start
  288. decoded_text = tokenizer.decode(token_ids.squeeze(0).tolist())
  289. print(f"\n\n{50*'='}\n{22*' '}OUT\n{50*'='}")
  290. print("\nOutput:", token_ids)
  291. print("Output length:", len(token_ids[0]))
  292. print("Output text:", decoded_text)
  293. print(f"\nTime: {total_time:.2f} sec")
  294. print(f"{int(len(token_ids[0])/total_time)} tokens/sec")
  295. if torch.cuda.is_available():
  296. max_mem_bytes = torch.cuda.max_memory_allocated()
  297. max_mem_gb = max_mem_bytes / (1024 ** 3)
  298. print(f"Max memory allocated: {max_mem_gb:.2f} GB")
# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()