Sebastian Raschka 5 сар өмнө
parent
commit
2af686d70b

+ 1 - 0
README.md

@@ -112,6 +112,7 @@ Several folders contain optional materials as a bonus for interested readers:
   - [Understanding PyTorch Buffers](ch03/03_understanding-buffers/understanding-buffers.ipynb)
 - **Chapter 4: Implementing a GPT model from scratch**
   - [FLOPS Analysis](ch04/02_performance-analysis/flops-analysis.ipynb)
+  - [KV Cache](ch04/03_kv-cache)
 - **Chapter 5: Pretraining on unlabeled data:**
   - [Alternative Weight Loading Methods](ch05/02_alternative_weight_loading/)
   - [Pretraining GPT on the Project Gutenberg Dataset](ch05/03_bonus_pretraining_on_gutenberg)

+ 220 - 0
ch04/03_kv-cache/README.md

@@ -0,0 +1,220 @@
+# Bonus Material: KV Cache
+
+
+
+**This folder implements the addition of a KV cache to the GPT model.** 
+
+ 
+## Overview
+
+In short, a KV cache stores intermediate key (K) and value (V) computations for reuse during inference, which results in a substantial speed-up when generating responses. The downside is that it adds some complexity to the code, increases memory usage, and can't be used during training. However, the inference speed-ups are often well worth the trade-offs in code complexity and memory when deploying LLMs.
+
+ 
+## How it works
+
+Imagine the LLM is generating some text. Concretely, suppose the LLM is given the following prompt: "Time flies".
+
+The figure below shows an excerpt of the underlying attention score computation using a modified graphic from Chapter 3 with the key and value vectors highlighted:
+
+<img src="https://sebastianraschka.com/images/LLMs-from-scratch-images/bonus/kv-cache/kv-cache-attn-1.png?3" width=800>
+
+Now, as we learned in Chapters 2 and 4, LLMs generate one word (or token) at a time. Suppose the LLM generated the word "fast" so that the prompt for the next round becomes "Time flies fast". This is illustrated in the next figure below:
+
+<img src="https://sebastianraschka.com/images/LLMs-from-scratch-images/bonus/kv-cache/kv-cache-attn-2.png?3" width=800>
+
+As we can see, based on comparing the previous 2 figures, the keys, and value vectors for the first two tokens are exactly the same, and it would be wasteful to recompute them in each next-token text generation round.
+
+So, the idea of the KV cache is to implement a caching mechanism that stores the previously generated key and value vectors for reuse, which helps us to avoid unnecessary recomputations.
+
+&nbsp;
+
+## KV cache implementation
+
+There are many ways to implement a KV cache, with the main idea being that we only compute the key and value tensors for the newly generated tokens in each generation step.
+
+I opted for a simple one that emphasizes code readability. I think it's easiest to just scroll through the code changes to see how it's implemented.
+
+There are two files in this folder:
+
+1. [`gpt_ch04.py`](gpt_ch04.py): Self-contained code taken from Chapter 3 and 4 to implement the LLM and run the simple text generation function
+2. [`gpt_with_kv_cache.py`](gpt_with_kv_cache.py): The same as above, but with the necessary changes made to implement the KV cache. 
+
+You can either 
+
+a. Open the [`gpt_with_kv_cache.py`](gpt_with_kv_cache.py) file and look out for the `# NEW` sections that mark the new changes:
+
+<img src="https://sebastianraschka.com/images/LLMs-from-scratch-images/bonus/kv-cache/new-sections.png?3" width=800>
+
+b. Check out the two code files via a file diff tool of your choice to compare the changes:
+
+<img src="https://sebastianraschka.com/images/LLMs-from-scratch-images/bonus/kv-cache/file-diff.png?3" width=800>
+
+To summarize the implementation details, here's a short walkthrough.
+
+&nbsp;
+
+### 1. Registering the cache buffers
+
+Inside the `MultiHeadAttention` constructor we add two non-persistent buffers, `cache_k` and `cache_v`, which will hold concatenated keys and values across steps:
+
+```python
+self.register_buffer("cache_k", None, persistent=False)
+self.register_buffer("cache_v", None, persistent=False)
+```
+
+&nbsp;
+
+### 2. Forward pass with `use_cache` flag
+
+Next, we extend the `forward` method of the `MultiHeadAttention` class to accept `use_cache` argument. After projecting the new chunk of tokens into `keys_new`, `values_new` and `queries`, we either initialize the kv cache or append to our cache:
+
+```python
+def forward(self, x, use_cache=False):
+    b, num_tokens, d_in = x.shape
+
+    keys_new = self.W_key(x)  # Shape: (b, num_tokens, d_out)
+    values_new = self.W_value(x)
+    queries = self.W_query(x)
+    #...
+
+    if use_cache:
+        if self.cache_k is None:
+            self.cache_k, self.cache_v = keys_new, values_new
+        else:
+            self.cache_k = torch.cat([self.cache_k, keys_new], dim=1)
+            self.cache_v = torch.cat([self.cache_v, values_new], dim=1)
+        keys, values = self.cache_k, self.cache_v
+    else:
+        keys, values = keys_new, values_new
+```
+
+&nbsp;
+
+
+### 3. Clearing the cache
+
+When generating texts, between independent sequences (for instance to text generation calls) we must reset both buffers, so we also add a cache resetting method the to the `MultiHeadAttention` class:
+
+```python
+def reset_cache(self):
+    self.cache_k, self.cache_v = None, None
+```
+
+&nbsp;
+
+### 4. Propagating `use_cache` in the full model
+
+With the changes to the `MultiHeadAttention` class in plass, we now modify the  `GPTModel` class. First, we add a position tracking for the token indices to the instructor:
+
+```python
+self.current_pos = 0
+```
+
+Then, we replace the one-liner block call with an explicit loop, passing `use_cache` through each transformer block:
+
+```python
+def forward(self, in_idx, use_cache=False):
+    # ...
+ 
+    if use_cache:
+        pos_ids = torch.arange(
+            self.current_pos, self.current_pos + seq_len,            
+            device=in_idx.device, dtype=torch.long
+        )
+        self.current_pos += seq_len
+    else:
+        pos_ids = torch.arange(
+            0, seq_len, device=in_idx.device, dtype=torch.long
+        )
+    
+    pos_embeds = self.pos_emb(pos_ids).unsqueeze(0)
+    x = tok_embeds + pos_embeds
+    # ...
+    for blk in self.trf_blocks:
+        x = blk(x, use_cache=use_cache)
+```
+
+The above change then also requires a small modification to the `TransformerBlock` class to accept the `use_cache` argument:
+```python
+    def forward(self, x, use_cache=False):
+        # ...
+        self.att(x, use_cache=use_cache)
+```
+
+Lastly, we add a model-level reset to `GPTModel` to clear all block caches at once for our convenience:
+
+```python
+def reset_kv_cache(self):
+    for blk in self.trf_blocks:
+        blk.att.reset_cache()
+    self.current_pos = 0
+```
+
+&nbsp;
+
+### 5. Using the cache in generation
+
+With the changes to the `GPTModel`, `TransformerBlock`, and `MultiHeadAttention`, finally, here's how we use the KV cache in a simple text generation function:
+
+```python
+def generate_text_simple_cached(model, idx, max_new_tokens):
+    model.eval()
+    model.reset_kv_cache()
+
+    # Init cache with full prompt
+    logits = model(idx, use_cache=True)
+
+    for _ in range(max_new_tokens):
+        last_logits = logits[:, -1]
+        next_idx = last_logits.argmax(dim=-1, keepdim=True)
+        idx = torch.cat([idx, next_idx], dim=1)
+
+        logits = model(next_idx, use_cache=True)
+
+    return idx
+```
+
+&nbsp;
+
+## Simple performance comparison
+
+After covering the KV cache on a conceptual level, the big question is how well it actually performs in practice on a small example. To give the implementation a try, we can run the two aforementioned code files as Python scripts, which will run the small 124 M parameter LLM to generate 200 new tokens (given a 4-token prompt "Hello, I am" to start with):
+
+```bash
+pip install -r https://raw.githubusercontent.com/rasbt/LLMs-from-scratch/refs/heads/main/requirements.txt
+
+python gpt_ch04.py
+
+python gpt_with_kv_cache.py
+```
+
+On a Mac Mini with M4 chip (CPU), the results are as follows:
+
+|                         | Tokens/sec |
+| ----------------------- | ---------- |
+| `gpt_ch04.py`           | 27         |
+| `gpt_with_kv_cache.py`  | 110        |
+
+So, as we can see, we already get a ~5x speed-up with a small 124 M parameter model and a short 200-token sequence length. (Note that this implementation is optimized for code readability and not optimized for CUDA or MPS runtime speed, which would require pre-allocating tensors instead of reinstating and concatenating them.)
+
+**Note:** The model generates "gibberish" in both cases, i.e., text that looks like this: 
+
+> Output text: Hello, I am Featureiman Byeswickattribute argue logger Normandy Compton analogous bore ITVEGIN ministriesysics Kle functional recountrictionchangingVirgin embarrassedgl ...
+
+This is because we haven't trained the model, yet. The next chapter trains the model, and you can use the KV-cache on the trained model (however, the KV cache is only meant to be used during inference) to generate coherent text. Here, we are using the untrained model to keep the code simple(r).
+
+What's more important, though, is that both the `gpt_ch04.py` and `gpt_with_kv_cache.py` implementations produce exactly the same text. This tells us that the KV cache is implemented correctly -- it is easy to make indexing mistakes that can lead to divergent results.
+
+
+&nbsp;
+
+## KV cache advantages and disadvantages 
+
+As sequence length increases, the benefits and downsides of a KV cache become more pronounced in the following ways:
+
+- [Good] **Computational efficiency increases**: Without caching, the attention at step *t* must compare the new query with *t* previous keys, so the cumulative work scales quadratically, O(n²). With a cache, each key and value is computed once and then reused, reducing the total per-step complexity to linear, O(n).
+
+- [Bad] **Memory usage increases linearly**: Each new token appends to the KV cache. For long sequences and larger LLMs, the cumulative KV cache grows larger, which can consume a significant or even prohibitive amount of (GPU) memory. As a workaround, we can truncate the KV cache, but this adds even more complexity (but again, it may well be worth it when deploying LLMs.)
+
+
+

+ 257 - 0
ch04/03_kv-cache/gpt_ch04.py

@@ -0,0 +1,257 @@
+# This file collects all the relevant code that we covered thus far
+# throughout Chapters 3-4.
+# This file can be run as a standalone script.
+
+import time
+import tiktoken
+import torch
+import torch.nn as nn
+
+
+#####################################
+# Chapter 3
+#####################################
+class MultiHeadAttention(nn.Module):
+    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
+        super().__init__()
+        assert d_out % num_heads == 0, "d_out must be divisible by num_heads"
+
+        self.d_out = d_out
+        self.num_heads = num_heads
+        self.head_dim = d_out // num_heads  # Reduce the projection dim to match desired output dim
+
+        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
+        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
+        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
+        self.out_proj = nn.Linear(d_out, d_out)  # Linear layer to combine head outputs
+        self.dropout = nn.Dropout(dropout)
+        self.register_buffer(
+            "mask",
+            torch.triu(torch.ones(context_length, context_length),diagonal=1),
+            persistent=False
+        )
+
+    def forward(self, x):
+        b, num_tokens, d_in = x.shape
+
+        keys = self.W_key(x)  # Shape: (b, num_tokens, d_out)
+        values = self.W_value(x)
+        queries = self.W_query(x)
+
+        # We implicitly split the matrix by adding a `num_heads` dimension
+        # Unroll last dim: (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)
+        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
+        values = values.view(b, num_tokens, self.num_heads, self.head_dim)
+        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)
+
+        # Transpose: (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim)
+        keys = keys.transpose(1, 2)
+        queries = queries.transpose(1, 2)
+        values = values.transpose(1, 2)
+
+        # Compute scaled dot-product attention (aka self-attention) with a causal mask
+        attn_scores = queries @ keys.transpose(2, 3)  # Dot product for each head
+
+        # Original mask truncated to the number of tokens and converted to boolean
+        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]
+
+        # Use the mask to fill attention scores
+        attn_scores.masked_fill_(mask_bool, -torch.inf)
+
+        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
+        attn_weights = self.dropout(attn_weights)
+
+        # Shape: (b, num_tokens, num_heads, head_dim)
+        context_vec = (attn_weights @ values).transpose(1, 2)
+
+        # Combine heads, where self.d_out = self.num_heads * self.head_dim
+        context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)
+        context_vec = self.out_proj(context_vec)  # optional projection
+
+        return context_vec
+
+
+#####################################
+# Chapter 4
+#####################################
+class LayerNorm(nn.Module):
+    def __init__(self, emb_dim):
+        super().__init__()
+        self.eps = 1e-5
+        self.scale = nn.Parameter(torch.ones(emb_dim))
+        self.shift = nn.Parameter(torch.zeros(emb_dim))
+
+    def forward(self, x):
+        mean = x.mean(dim=-1, keepdim=True)
+        var = x.var(dim=-1, keepdim=True, unbiased=False)
+        norm_x = (x - mean) / torch.sqrt(var + self.eps)
+        return self.scale * norm_x + self.shift
+
+
+class GELU(nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x):
+        return 0.5 * x * (1 + torch.tanh(
+            torch.sqrt(torch.tensor(2.0 / torch.pi)) *
+            (x + 0.044715 * torch.pow(x, 3))
+        ))
+
+
+class FeedForward(nn.Module):
+    def __init__(self, cfg):
+        super().__init__()
+        self.layers = nn.Sequential(
+            nn.Linear(cfg["emb_dim"], 4 * cfg["emb_dim"]),
+            GELU(),
+            nn.Linear(4 * cfg["emb_dim"], cfg["emb_dim"]),
+        )
+
+    def forward(self, x):
+        return self.layers(x)
+
+
+class TransformerBlock(nn.Module):
+    def __init__(self, cfg):
+        super().__init__()
+        self.att = MultiHeadAttention(
+            d_in=cfg["emb_dim"],
+            d_out=cfg["emb_dim"],
+            context_length=cfg["context_length"],
+            num_heads=cfg["n_heads"],
+            dropout=cfg["drop_rate"],
+            qkv_bias=cfg["qkv_bias"])
+        self.ff = FeedForward(cfg)
+        self.norm1 = LayerNorm(cfg["emb_dim"])
+        self.norm2 = LayerNorm(cfg["emb_dim"])
+        self.drop_shortcut = nn.Dropout(cfg["drop_rate"])
+
+    def forward(self, x):
+        # Shortcut connection for attention block
+        shortcut = x
+        x = self.norm1(x)
+        x = self.att(x)   # Shape [batch_size, num_tokens, emb_size]
+        x = self.drop_shortcut(x)
+        x = x + shortcut  # Add the original input back
+
+        # Shortcut connection for feed-forward block
+        shortcut = x
+        x = self.norm2(x)
+        x = self.ff(x)
+        x = self.drop_shortcut(x)
+        x = x + shortcut  # Add the original input back
+
+        return x
+
+
+class GPTModel(nn.Module):
+    def __init__(self, cfg):
+        super().__init__()
+        self.tok_emb = nn.Embedding(cfg["vocab_size"], cfg["emb_dim"])
+        self.pos_emb = nn.Embedding(cfg["context_length"], cfg["emb_dim"])
+        self.drop_emb = nn.Dropout(cfg["drop_rate"])
+
+        self.trf_blocks = nn.Sequential(
+            *[TransformerBlock(cfg) for _ in range(cfg["n_layers"])])
+
+        self.final_norm = LayerNorm(cfg["emb_dim"])
+        self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False)
+
+    def forward(self, in_idx):
+        batch_size, seq_len = in_idx.shape
+        tok_embeds = self.tok_emb(in_idx)
+        pos_embeds = self.pos_emb(torch.arange(seq_len, device=in_idx.device))
+        x = tok_embeds + pos_embeds  # Shape [batch_size, num_tokens, emb_size]
+        x = self.drop_emb(x)
+        x = self.trf_blocks(x)
+        x = self.final_norm(x)
+        logits = self.out_head(x)
+        return logits
+
+
+def generate_text_simple(model, idx, max_new_tokens, context_size):
+    # idx is (B, T) array of indices in the current context
+    for _ in range(max_new_tokens):
+
+        # Crop current context if it exceeds the supported context size
+        # E.g., if LLM supports only 5 tokens, and the context size is 10
+        # then only the last 5 tokens are used as context
+        idx_cond = idx[:, -context_size:]
+
+        # Get the predictions
+        with torch.no_grad():
+            logits = model(idx_cond)
+
+        # Focus only on the last time step
+        # (batch, n_token, vocab_size) becomes (batch, vocab_size)
+        logits = logits[:, -1, :]
+
+        # Get the idx of the vocab entry with the highest logits value
+        idx_next = torch.argmax(logits, dim=-1, keepdim=True)  # (batch, 1)
+
+        # Append sampled index to the running sequence
+        idx = torch.cat((idx, idx_next), dim=1)  # (batch, n_tokens+1)
+
+    return idx
+
+
+def main():
+    GPT_CONFIG_124M = {
+        "vocab_size": 50257,     # Vocabulary size
+        "context_length": 1024,  # Context length
+        "emb_dim": 768,          # Embedding dimension
+        "n_heads": 12,           # Number of attention heads
+        "n_layers": 12,          # Number of layers
+        "drop_rate": 0.1,        # Dropout rate
+        "qkv_bias": False        # Query-Key-Value bias
+    }
+
+    torch.manual_seed(123)
+    model = GPTModel(GPT_CONFIG_124M)
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    model.to(device)
+    model.eval()  # disable dropout
+
+    start_context = "Hello, I am"
+
+    tokenizer = tiktoken.get_encoding("gpt2")
+    encoded = tokenizer.encode(start_context)
+    encoded_tensor = torch.tensor(encoded, device=device).unsqueeze(0)
+
+    print(f"\n{50*'='}\n{22*' '}IN\n{50*'='}")
+    print("\nInput text:", start_context)
+    print("Encoded input text:", encoded)
+    print("encoded_tensor.shape:", encoded_tensor.shape)
+
+    if torch.cuda.is_available():
+        torch.cuda.synchronize()
+    start = time.time()
+
+    token_ids = generate_text_simple(
+        model=model,
+        idx=encoded_tensor,
+        max_new_tokens=200,
+        context_size=GPT_CONFIG_124M["context_length"]
+    )
+    if torch.cuda.is_available():
+        torch.cuda.synchronize()
+    total_time = time.time() - start
+
+    decoded_text = tokenizer.decode(token_ids.squeeze(0).tolist())
+
+    print(f"\n\n{50*'='}\n{22*' '}OUT\n{50*'='}")
+    print("\nOutput:", token_ids)
+    print("Output length:", len(token_ids[0]))
+    print("Output text:", decoded_text)
+
+    print(f"\nTime: {total_time:.2f} sec")
+    print(f"{int(len(token_ids[0])/total_time)} tokens/sec")
+    if torch.cuda.is_available():
+        max_mem_bytes = torch.cuda.max_memory_allocated()
+        max_mem_gb = max_mem_bytes / (1024 ** 3)
+        print(f"Max memory allocated: {max_mem_gb:.2f} GB")
+
+
+if __name__ == "__main__":
+    main()

+ 353 - 0
ch04/03_kv-cache/gpt_with_kv_cache.py

@@ -0,0 +1,353 @@
+# This file collects all the relevant code that we covered thus far
+# throughout Chapters 3-4.
+# This file can be run as a standalone script.
+
+import time
+import tiktoken
+import torch
+import torch.nn as nn
+
+
+#####################################
+# Chapter 3
+#####################################
+class MultiHeadAttention(nn.Module):
+    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
+        super().__init__()
+        assert d_out % num_heads == 0, "d_out must be divisible by num_heads"
+
+        self.d_out = d_out
+        self.num_heads = num_heads
+        self.head_dim = d_out // num_heads  # Reduce the projection dim to match desired output dim
+
+        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
+        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
+        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
+        self.out_proj = nn.Linear(d_out, d_out)  # Linear layer to combine head outputs
+        self.dropout = nn.Dropout(dropout)
+        self.register_buffer(
+            "mask",
+            torch.triu(torch.ones(context_length, context_length),diagonal=1),
+            persistent=False
+        )
+
+        ####################################################
+        # NEW
+        self.register_buffer("cache_k", None, persistent=False)
+        self.register_buffer("cache_v", None, persistent=False)
+        ####################################################
+
+    def forward(self, x, use_cache=False):
+        b, num_tokens, d_in = x.shape
+
+        keys_new = self.W_key(x)  # Shape: (b, num_tokens, d_out)
+        values_new = self.W_value(x)
+        queries = self.W_query(x)
+
+        # We implicitly split the matrix by adding a `num_heads` dimension
+        # Unroll last dim: (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)
+        keys_new = keys_new.view(b, num_tokens, self.num_heads, self.head_dim)
+        values_new = values_new.view(b, num_tokens, self.num_heads, self.head_dim)
+        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)
+
+        ####################################################
+        # NEW
+        if use_cache:
+            if self.cache_k is None:
+                self.cache_k, self.cache_v = keys_new, values_new
+            else:
+                self.cache_k = torch.cat([self.cache_k, keys_new], dim=1)
+                self.cache_v = torch.cat([self.cache_v, values_new], dim=1)
+            keys, values = self.cache_k, self.cache_v
+        else:
+            keys, values = keys_new, values_new
+        ####################################################
+
+        # Transpose: (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim)
+        keys = keys.transpose(1, 2)
+        queries = queries.transpose(1, 2)
+        values = values.transpose(1, 2)
+
+        # Compute scaled dot-product attention (aka self-attention) with a causal mask
+        attn_scores = queries @ keys.transpose(2, 3)  # Dot product for each head
+
+        # Original mask truncated to the number of tokens and converted to boolean
+        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]
+
+        # Use the mask to fill attention scores
+        attn_scores.masked_fill_(mask_bool, -torch.inf)
+
+        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
+        attn_weights = self.dropout(attn_weights)
+
+        # Shape: (b, num_tokens, num_heads, head_dim)
+        context_vec = (attn_weights @ values).transpose(1, 2)
+
+        # Combine heads, where self.d_out = self.num_heads * self.head_dim
+        context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)
+        context_vec = self.out_proj(context_vec)  # optional projection
+
+        return context_vec
+
+    ####################################################
+    # NEW
+    def reset_cache(self):
+        self.cache_k, self.cache_v = None, None
+    ####################################################
+
+
+#####################################
+# Chapter 4
+#####################################
+class LayerNorm(nn.Module):
+    def __init__(self, emb_dim):
+        super().__init__()
+        self.eps = 1e-5
+        self.scale = nn.Parameter(torch.ones(emb_dim))
+        self.shift = nn.Parameter(torch.zeros(emb_dim))
+
+    def forward(self, x):
+        mean = x.mean(dim=-1, keepdim=True)
+        var = x.var(dim=-1, keepdim=True, unbiased=False)
+        norm_x = (x - mean) / torch.sqrt(var + self.eps)
+        return self.scale * norm_x + self.shift
+
+
+class GELU(nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x):
+        return 0.5 * x * (1 + torch.tanh(
+            torch.sqrt(torch.tensor(2.0 / torch.pi)) *
+            (x + 0.044715 * torch.pow(x, 3))
+        ))
+
+
+class FeedForward(nn.Module):
+    def __init__(self, cfg):
+        super().__init__()
+        self.layers = nn.Sequential(
+            nn.Linear(cfg["emb_dim"], 4 * cfg["emb_dim"]),
+            GELU(),
+            nn.Linear(4 * cfg["emb_dim"], cfg["emb_dim"]),
+        )
+
+    def forward(self, x):
+        return self.layers(x)
+
+
+class TransformerBlock(nn.Module):
+    def __init__(self, cfg):
+        super().__init__()
+        self.att = MultiHeadAttention(
+            d_in=cfg["emb_dim"],
+            d_out=cfg["emb_dim"],
+            context_length=cfg["context_length"],
+            num_heads=cfg["n_heads"],
+            dropout=cfg["drop_rate"],
+            qkv_bias=cfg["qkv_bias"])
+        self.ff = FeedForward(cfg)
+        self.norm1 = LayerNorm(cfg["emb_dim"])
+        self.norm2 = LayerNorm(cfg["emb_dim"])
+        self.drop_shortcut = nn.Dropout(cfg["drop_rate"])
+
+    def forward(self, x, use_cache=False):
+        # Shortcut connection for attention block
+        shortcut = x
+        x = self.norm1(x)
+
+        # x = self.att(x)   # Shape [batch_size, num_tokens, emb_size]
+        ####################################################
+        # NEW
+        x = self.att(x, use_cache=use_cache)
+        ####################################################
+
+        x = self.drop_shortcut(x)
+        x = x + shortcut  # Add the original input back
+
+        # Shortcut connection for feed-forward block
+        shortcut = x
+        x = self.norm2(x)
+        x = self.ff(x)
+        x = self.drop_shortcut(x)
+        x = x + shortcut  # Add the original input back
+
+        return x
+
+
+class GPTModel(nn.Module):
+    def __init__(self, cfg):
+        super().__init__()
+        self.tok_emb = nn.Embedding(cfg["vocab_size"], cfg["emb_dim"])
+        self.pos_emb = nn.Embedding(cfg["context_length"], cfg["emb_dim"])
+        self.drop_emb = nn.Dropout(cfg["drop_rate"])
+
+        # self.trf_blocks = nn.Sequential(
+        #    *[TransformerBlock(cfg) for _ in range(cfg["n_layers"])])
+        ####################################################
+        # NEW
+        self.trf_blocks = nn.ModuleList(
+            [TransformerBlock(cfg) for _ in range(cfg["n_layers"])])
+
+        self.current_pos = 0
+        ####################################################
+
+        self.final_norm = LayerNorm(cfg["emb_dim"])
+        self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False)
+
+    def forward(self, in_idx, use_cache=False):
+        batch_size, seq_len = in_idx.shape
+        tok_embeds = self.tok_emb(in_idx)
+
+        # pos_embeds = self.pos_emb(torch.arange(seq_len, device=in_idx.device))
+
+        ####################################################
+        # NEW
+
+        if use_cache:
+            pos_ids = torch.arange(self.current_pos, self.current_pos + seq_len, device=in_idx.device, dtype=torch.long)
+            self.current_pos += seq_len
+        else:
+            pos_ids = torch.arange(0, seq_len, device=in_idx.device, dtype=torch.long)
+        pos_embeds = self.pos_emb(pos_ids).unsqueeze(0)
+        ####################################################
+
+        x = tok_embeds + pos_embeds  # Shape [batch_size, num_tokens, emb_size]
+        x = self.drop_emb(x)
+
+        # x = self.trf_blocks(x)
+        ####################################################
+        # NEW
+        for blk in self.trf_blocks:
+            x = blk(x, use_cache=use_cache)
+        ####################################################
+
+        x = self.final_norm(x)
+        logits = self.out_head(x)
+        return logits
+
+    ####################################################
+    # NEW
+    def reset_kv_cache(self):
+        for blk in self.trf_blocks:
+            blk.att.reset_cache()
+        self.current_pos = 0
+    ####################################################
+
+
+def generate_text_simple(model, idx, max_new_tokens, context_size):
+    # idx is (B, T) array of indices in the current context
+    for _ in range(max_new_tokens):
+
+        # Crop current context if it exceeds the supported context size
+        # E.g., if LLM supports only 5 tokens, and the context size is 10
+        # then only the last 5 tokens are used as context
+        idx_cond = idx[:, -context_size:]
+
+        # Get the predictions
+        with torch.no_grad():
+            logits = model(idx_cond)
+
+        # Focus only on the last time step
+        # (batch, n_token, vocab_size) becomes (batch, vocab_size)
+        logits = logits[:, -1, :]
+
+        # Get the idx of the vocab entry with the highest logits value
+        idx_next = torch.argmax(logits, dim=-1, keepdim=True)  # (batch, 1)
+
+        # Append sampled index to the running sequence
+        idx = torch.cat((idx, idx_next), dim=1)  # (batch, n_tokens+1)
+
+    return idx
+
+
+####################################################
+# NEW
+def generate_text_simple_cached(model, idx, max_new_tokens):
+    model.eval()
+    model.reset_kv_cache()
+
+    # Init cache with full prompt
+    logits = model(idx, use_cache=True)
+
+    for _ in range(max_new_tokens):
+        last_logits = logits[:, -1]
+        next_idx = last_logits.argmax(dim=-1, keepdim=True)
+        idx = torch.cat([idx, next_idx], dim=1)
+
+        logits = model(next_idx, use_cache=True)
+
+    return idx
+####################################################
+
+
+def main():
+    GPT_CONFIG_124M = {
+        "vocab_size": 50257,     # Vocabulary size
+        "context_length": 1024,  # Context length
+        "emb_dim": 768,          # Embedding dimension
+        "n_heads": 12,           # Number of attention heads
+        "n_layers": 12,          # Number of layers
+        "drop_rate": 0.1,        # Dropout rate
+        "qkv_bias": False        # Query-Key-Value bias
+    }
+
+    torch.manual_seed(123)
+    model = GPTModel(GPT_CONFIG_124M)
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    model.to(device)
+    model.eval()  # disable dropout
+
+    start_context = "Hello, I am"
+
+    tokenizer = tiktoken.get_encoding("gpt2")
+    encoded = tokenizer.encode(start_context)
+    encoded_tensor = torch.tensor(encoded, device=device).unsqueeze(0)
+
+    print(f"\n{50*'='}\n{22*' '}IN\n{50*'='}")
+    print("\nInput text:", start_context)
+    print("Encoded input text:", encoded)
+    print("encoded_tensor.shape:", encoded_tensor.shape)
+
+    if torch.cuda.is_available():
+        torch.cuda.synchronize()
+    start = time.time()
+
+    # token_ids = generate_text_simple(
+    #     model=model,
+    #     idx=encoded_tensor,
+    #     max_new_tokens=200,
+    #     context_size=GPT_CONFIG_124M["context_length"]
+    # )
+
+    ####################################################
+    # NEW
+    token_ids = generate_text_simple_cached(
+        model=model,
+        idx=encoded_tensor,
+        max_new_tokens=200,
+    )
+    ####################################################
+
+    if torch.cuda.is_available():
+        torch.cuda.synchronize()
+    total_time = time.time() - start
+
+    decoded_text = tokenizer.decode(token_ids.squeeze(0).tolist())
+
+    print(f"\n\n{50*'='}\n{22*' '}OUT\n{50*'='}")
+    print("\nOutput:", token_ids)
+    print("Output length:", len(token_ids[0]))
+    print("Output text:", decoded_text)
+
+    print(f"\nTime: {total_time:.2f} sec")
+    print(f"{int(len(token_ids[0])/total_time)} tokens/sec")
+    if torch.cuda.is_available():
+        max_mem_bytes = torch.cuda.max_memory_allocated()
+        max_mem_gb = max_mem_bytes / (1024 ** 3)
+        print(f"Max memory allocated: {max_mem_gb:.2f} GB")
+
+
+if __name__ == "__main__":
+    main()

+ 1 - 0
ch04/README.md

@@ -9,6 +9,7 @@
 ## Bonus Materials
 
 - [02_performance-analysis](02_performance-analysis) contains optional code analyzing the performance of the GPT model(s) implemented in the main chapter
+- [03_kv-cache](03_kv-cache) implements a KV cache to speed up the text generation during inference
 - [ch05/07_gpt_to_llama](../ch05/07_gpt_to_llama) contains a step-by-step guide for converting a GPT architecture implementation to Llama 3.2 and loads pretrained weights from Meta AI (it might be interesting to look at alternative architectures after completing chapter 4, but you can also save that for after reading chapter 5)