gpt_with_kv_cache_optimized.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386
  1. # This file collects all the relevant code that we covered thus far
  2. # throughout Chapters 3-4.
  3. # This file can be run as a standalone script.
  4. import time
  5. import tiktoken
  6. import torch
  7. import torch.nn as nn
  8. #####################################
  9. # Chapter 3
  10. #####################################
  11. class MultiHeadAttention(nn.Module):
  12. def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False, max_seq_len=None, window_size=None):
  13. super().__init__()
  14. assert d_out % num_heads == 0, "d_out must be divisible by num_heads"
  15. self.d_out = d_out
  16. self.num_heads = num_heads
  17. self.head_dim = d_out // num_heads # Reduce the projection dim to match desired output dim
  18. self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
  19. self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
  20. self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
  21. self.out_proj = nn.Linear(d_out, d_out) # Linear layer to combine head outputs
  22. self.dropout = nn.Dropout(dropout)
  23. ####################################################
  24. # NEW
  25. self.max_seq_len = max_seq_len or context_length
  26. self.window_size = window_size or self.max_seq_len
  27. self.register_buffer("cache_k", None, persistent=False)
  28. self.register_buffer("cache_v", None, persistent=False)
  29. ####################################################
  30. def forward(self, x, use_cache=False):
  31. b, num_tokens, d_in = x.shape
  32. keys_new = self.W_key(x) # Shape: (b, num_tokens, d_out)
  33. values_new = self.W_value(x)
  34. queries = self.W_query(x)
  35. # We implicitly split the matrix by adding a `num_heads` dimension
  36. # Unroll last dim: (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)
  37. keys_new = keys_new.view(b, num_tokens, self.num_heads, self.head_dim)
  38. values_new = values_new.view(b, num_tokens, self.num_heads, self.head_dim)
  39. queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)
  40. # Transpose: (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim)
  41. keys_new = keys_new.transpose(1, 2)
  42. values_new = values_new.transpose(1, 2)
  43. queries = queries.transpose(1, 2)
  44. ####################################################
  45. # NEW
  46. if use_cache:
  47. if self.cache_k is None or self.cache_k.size(0) != b:
  48. self.cache_k = torch.zeros(b, self.num_heads,
  49. self.window_size, self.head_dim,
  50. device=x.device)
  51. self.cache_v = torch.zeros_like(self.cache_k)
  52. self.ptr_cur = 0 # pointer to next free slot
  53. # if incoming chunk would overflow discard oldest tokens
  54. if self.ptr_cur + num_tokens > self.window_size:
  55. overflow = self.ptr_cur + num_tokens - self.window_size
  56. # shift everything left by `overflow` (cheap view-copy)
  57. self.cache_k[:, :, :-overflow, :] = self.cache_k[:, :, overflow:, :].clone()
  58. self.cache_v[:, :, :-overflow, :] = self.cache_v[:, :, overflow:, :].clone()
  59. self.ptr_cur -= overflow # pointer after shift
  60. self.cache_k[:, :, self.ptr_cur:self.ptr_cur + num_tokens, :] = keys_new
  61. self.cache_v[:, :, self.ptr_cur:self.ptr_cur + num_tokens, :] = values_new
  62. self.ptr_cur += num_tokens
  63. keys = self.cache_k[:, :, :self.ptr_cur, :]
  64. values = self.cache_v[:, :, :self.ptr_cur, :]
  65. else:
  66. keys, values = keys_new, values_new
  67. self.ptr_cur = 0 # keep pointer sane if you interleave modes
  68. ####################################################
  69. # Compute scaled dot-product attention (aka self-attention) with a causal mask
  70. attn_scores = queries @ keys.transpose(2, 3) # Dot product for each head
  71. ####################################################
  72. # NEW
  73. K = attn_scores.size(-1)
  74. if num_tokens == K:
  75. # No cache → use the pre‑baked triangular mask slice
  76. causal_mask = torch.triu(torch.ones(num_tokens, K, device=x.device, dtype=torch.bool), diagonal=1)
  77. else:
  78. # Cached: need to offset the diagonal by (K − num_tokens)
  79. offset = K - num_tokens # number of tokens already in cache before this chunk
  80. row_idx = torch.arange(num_tokens, device=x.device).unsqueeze(1) # (num_tokens, 1)
  81. col_idx = torch.arange(K, device=x.device).unsqueeze(0) # (1, K)
  82. causal_mask = row_idx + offset < col_idx # True where j > i+offset
  83. ####################################################
  84. # Use the mask to fill attention scores
  85. attn_scores.masked_fill_(causal_mask.unsqueeze(0).unsqueeze(0), -torch.inf)
  86. attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
  87. attn_weights = self.dropout(attn_weights)
  88. # Shape: (b, num_tokens, num_heads, head_dim)
  89. context_vec = (attn_weights @ values).transpose(1, 2)
  90. # Combine heads, where self.d_out = self.num_heads * self.head_dim
  91. context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)
  92. context_vec = self.out_proj(context_vec) # optional projection
  93. return context_vec
  94. ####################################################
  95. # NEW
  96. def reset_cache(self):
  97. self.cache_k, self.cache_v = None, None
  98. ####################################################
  99. #####################################
  100. # Chapter 4
  101. #####################################
  102. class LayerNorm(nn.Module):
  103. def __init__(self, emb_dim):
  104. super().__init__()
  105. self.eps = 1e-5
  106. self.scale = nn.Parameter(torch.ones(emb_dim))
  107. self.shift = nn.Parameter(torch.zeros(emb_dim))
  108. def forward(self, x):
  109. mean = x.mean(dim=-1, keepdim=True)
  110. var = x.var(dim=-1, keepdim=True, unbiased=False)
  111. norm_x = (x - mean) / torch.sqrt(var + self.eps)
  112. return self.scale * norm_x + self.shift
  113. class GELU(nn.Module):
  114. def __init__(self):
  115. super().__init__()
  116. def forward(self, x):
  117. return 0.5 * x * (1 + torch.tanh(
  118. torch.sqrt(torch.tensor(2.0 / torch.pi)) *
  119. (x + 0.044715 * torch.pow(x, 3))
  120. ))
  121. class FeedForward(nn.Module):
  122. def __init__(self, cfg):
  123. super().__init__()
  124. self.layers = nn.Sequential(
  125. nn.Linear(cfg["emb_dim"], 4 * cfg["emb_dim"]),
  126. GELU(),
  127. nn.Linear(4 * cfg["emb_dim"], cfg["emb_dim"]),
  128. )
  129. def forward(self, x):
  130. return self.layers(x)
  131. class TransformerBlock(nn.Module):
  132. def __init__(self, cfg):
  133. super().__init__()
  134. self.att = MultiHeadAttention(
  135. d_in=cfg["emb_dim"],
  136. d_out=cfg["emb_dim"],
  137. context_length=cfg["context_length"],
  138. num_heads=cfg["n_heads"],
  139. dropout=cfg["drop_rate"],
  140. qkv_bias=cfg["qkv_bias"],
  141. window_size=cfg["kv_window_size"] if "kv_window_size" in cfg else cfg["context_length"] # NEW
  142. )
  143. self.ff = FeedForward(cfg)
  144. self.norm1 = LayerNorm(cfg["emb_dim"])
  145. self.norm2 = LayerNorm(cfg["emb_dim"])
  146. self.drop_shortcut = nn.Dropout(cfg["drop_rate"])
  147. def forward(self, x, use_cache=False):
  148. # Shortcut connection for attention block
  149. shortcut = x
  150. x = self.norm1(x)
  151. # x = self.att(x) # Shape [batch_size, num_tokens, emb_size]
  152. ####################################################
  153. # NEW
  154. x = self.att(x, use_cache=use_cache)
  155. ####################################################
  156. x = self.drop_shortcut(x)
  157. x = x + shortcut # Add the original input back
  158. # Shortcut connection for feed-forward block
  159. shortcut = x
  160. x = self.norm2(x)
  161. x = self.ff(x)
  162. x = self.drop_shortcut(x)
  163. x = x + shortcut # Add the original input back
  164. return x
  165. class GPTModel(nn.Module):
  166. def __init__(self, cfg):
  167. super().__init__()
  168. self.tok_emb = nn.Embedding(cfg["vocab_size"], cfg["emb_dim"])
  169. self.pos_emb = nn.Embedding(cfg["context_length"], cfg["emb_dim"])
  170. self.drop_emb = nn.Dropout(cfg["drop_rate"])
  171. # self.trf_blocks = nn.Sequential(
  172. # *[TransformerBlock(cfg) for _ in range(cfg["n_layers"])])
  173. ####################################################
  174. # NEW
  175. self.trf_blocks = nn.ModuleList(
  176. [TransformerBlock(cfg) for _ in range(cfg["n_layers"])])
  177. self.ptr_current_pos = 0
  178. ####################################################
  179. self.final_norm = LayerNorm(cfg["emb_dim"])
  180. self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False)
  181. def forward(self, in_idx, use_cache=False):
  182. batch_size, seq_len = in_idx.shape
  183. tok_embeds = self.tok_emb(in_idx)
  184. # pos_embeds = self.pos_emb(torch.arange(seq_len, device=in_idx.device))
  185. ####################################################
  186. # NEW
  187. if use_cache:
  188. pos_ids = torch.arange(self.ptr_current_pos, self.ptr_current_pos + seq_len, device=in_idx.device, dtype=torch.long)
  189. self.ptr_current_pos += seq_len
  190. else:
  191. pos_ids = torch.arange(0, seq_len, device=in_idx.device, dtype=torch.long)
  192. pos_embeds = self.pos_emb(pos_ids).unsqueeze(0)
  193. ####################################################
  194. x = tok_embeds + pos_embeds # Shape [batch_size, num_tokens, emb_size]
  195. x = self.drop_emb(x)
  196. # x = self.trf_blocks(x)
  197. ####################################################
  198. # NEW
  199. for blk in self.trf_blocks:
  200. x = blk(x, use_cache=use_cache)
  201. ####################################################
  202. x = self.final_norm(x)
  203. logits = self.out_head(x)
  204. return logits
  205. ####################################################
  206. # NEW
  207. def reset_kv_cache(self):
  208. for blk in self.trf_blocks:
  209. blk.att.reset_cache()
  210. self.ptr_current_pos = 0
  211. ####################################################
  212. def generate_text_simple(model, idx, max_new_tokens, context_size):
  213. # idx is (B, T) array of indices in the current context
  214. for _ in range(max_new_tokens):
  215. # Crop current context if it exceeds the supported context size
  216. # E.g., if LLM supports only 5 tokens, and the context size is 10
  217. # then only the last 5 tokens are used as context
  218. idx_cond = idx[:, -context_size:]
  219. # Get the predictions
  220. with torch.no_grad():
  221. logits = model(idx_cond)
  222. # Focus only on the last time step
  223. # (batch, n_token, vocab_size) becomes (batch, vocab_size)
  224. logits = logits[:, -1, :]
  225. # Get the idx of the vocab entry with the highest logits value
  226. idx_next = torch.argmax(logits, dim=-1, keepdim=True) # (batch, 1)
  227. # Append sampled index to the running sequence
  228. idx = torch.cat((idx, idx_next), dim=1) # (batch, n_tokens+1)
  229. return idx
  230. ####################################################
  231. # NEW
  232. def generate_text_simple_cached(model, idx, max_new_tokens, context_size=None, use_cache=True):
  233. model.eval()
  234. ctx_len = context_size or model.pos_emb.num_embeddings
  235. with torch.no_grad():
  236. if use_cache:
  237. model.reset_kv_cache()
  238. logits = model(idx[:, -ctx_len:], use_cache=True)
  239. for _ in range(max_new_tokens):
  240. next_idx = logits[:, -1].argmax(dim=-1, keepdim=True)
  241. idx = torch.cat([idx, next_idx], dim=1)
  242. logits = model(next_idx, use_cache=True)
  243. else:
  244. for _ in range(max_new_tokens):
  245. logits = model(idx[:, -ctx_len:], use_cache=False)
  246. next_idx = logits[:, -1].argmax(dim=-1, keepdim=True)
  247. idx = torch.cat([idx, next_idx], dim=1)
  248. return idx
  249. ####################################################
  250. def main():
  251. GPT_CONFIG_124M = {
  252. "vocab_size": 50257, # Vocabulary size
  253. "context_length": 1024, # Context length
  254. "emb_dim": 768, # Embedding dimension
  255. "n_heads": 12, # Number of attention heads
  256. "n_layers": 12, # Number of layers
  257. "drop_rate": 0.1, # Dropout rate
  258. "qkv_bias": False, # Query-Key-Value bias
  259. "kv_window_size": 1024 # NEW: KV cache window size
  260. }
  261. torch.manual_seed(123)
  262. model = GPTModel(GPT_CONFIG_124M)
  263. device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  264. model.to(device)
  265. model.eval() # disable dropout
  266. start_context = "Hello, I am"
  267. tokenizer = tiktoken.get_encoding("gpt2")
  268. encoded = tokenizer.encode(start_context)
  269. encoded_tensor = torch.tensor(encoded, device=device).unsqueeze(0)
  270. print(f"\n{50*'='}\n{22*' '}IN\n{50*'='}")
  271. print("\nInput text:", start_context)
  272. print("Encoded input text:", encoded)
  273. print("encoded_tensor.shape:", encoded_tensor.shape)
  274. if torch.cuda.is_available():
  275. torch.cuda.synchronize()
  276. start = time.time()
  277. # token_ids = generate_text_simple(
  278. # model=model,
  279. # idx=encoded_tensor,
  280. # max_new_tokens=200,
  281. # context_size=GPT_CONFIG_124M["context_length"]
  282. # )
  283. ####################################################
  284. # NEW
  285. token_ids = generate_text_simple_cached(
  286. model=model,
  287. idx=encoded_tensor,
  288. max_new_tokens=200,
  289. )
  290. ####################################################
  291. if torch.cuda.is_available():
  292. torch.cuda.synchronize()
  293. total_time = time.time() - start
  294. decoded_text = tokenizer.decode(token_ids.squeeze(0).tolist())
  295. print(f"\n\n{50*'='}\n{22*' '}OUT\n{50*'='}")
  296. print("\nOutput:", token_ids)
  297. print("Output length:", len(token_ids[0]))
  298. print("Output text:", decoded_text)
  299. print(f"\nTime: {total_time:.2f} sec")
  300. print(f"{int(len(token_ids[0])/total_time)} tokens/sec")
  301. if torch.cuda.is_available():
  302. max_mem_bytes = torch.cuda.max_memory_allocated()
  303. max_mem_gb = max_mem_bytes / (1024 ** 3)
  304. print(f"Max memory allocated: {max_mem_gb:.2f} GB")
  305. if __name__ == "__main__":
  306. main()