Bladeren bron

Fix bug in masking when kv cache is used. (#697)

* Fix bug in masking when kv cache is used.

* add tests

* dd tests

* upd

* add kv cache test to gh workflow

* explicit mask slicing

* upd

---------

Co-authored-by: rasbt <mail@sebastianraschka.com>
Martin Ma 5 maanden geleden
bovenliggende
commit
6522be94be

+ 1 - 0
.github/workflows/basic-tests-linux-uv.yml

@@ -49,6 +49,7 @@ jobs:
           source .venv/bin/activate
           pytest --ruff setup/02_installing-python-libraries/tests.py
           pytest --ruff ch04/01_main-chapter-code/tests.py
+          pytest --ruff ch04/03_kv-cache/tests.py
           pytest --ruff ch05/01_main-chapter-code/tests.py
           pytest --ruff ch05/07_gpt_to_llama/tests/tests.py
           pytest --ruff ch06/01_main-chapter-code/tests.py

+ 30 - 18
ch04/03_kv-cache/README.md

@@ -86,6 +86,18 @@ def forward(self, x, use_cache=False):
         keys, values = self.cache_k, self.cache_v
     else:
         keys, values = keys_new, values_new
+        
+    # ...
+    
+    num_tokens_Q = queries.shape[-2]
+    num_tokens_K = keys.shape[-2]
+    if use_cache:
+        mask_bool = self.mask.bool()[
+            self.ptr_current_pos:self.ptr_current_pos + num_tokens_Q, :num_tokens_K
+        ]
+        self.ptr_current_pos += num_tokens_Q
+    else:
+        mask_bool = self.mask.bool()[:num_tokens_Q, :num_tokens_K]
 ```
 
 &nbsp;
@@ -98,6 +110,7 @@ When generating texts, between independent sequences (for instance to text gener
 ```python
 def reset_cache(self):
     self.cache_k, self.cache_v = None, None
+    self.ptr_current_pos = 0
 ```
 
 &nbsp;
@@ -157,30 +170,29 @@ def reset_kv_cache(self):
 With the changes to the `GPTModel`, `TransformerBlock`, and `MultiHeadAttention`, finally, here's how we use the KV cache in a simple text generation function:
 
 ```python
-def generate_text_simple_cached(model, idx, max_new_tokens, use_cache=True):
+def generate_text_simple_cached(model, idx, max_new_tokens, 
+                                context_size=None, use_cache=True):
     model.eval()
+    ctx_len = context_size or model.pos_emb.num_embeddings
 
-    ctx_len = model.pos_emb.num_embeddings  # max supported length, e.g. 1024
-    if use_cache:
-        # Init cache with full prompt
-        model.reset_kv_cache()
-        with torch.no_grad():
+    with torch.no_grad():
+        if use_cache:
+            # Init cache with full prompt
+            model.reset_kv_cache()
             logits = model(idx[:, -ctx_len:], use_cache=True)
 
-        for _ in range(max_new_tokens):
-            # a) pick the token with the highest log-probability (greedy sampling)
-            next_idx = logits[:, -1].argmax(dim=-1, keepdim=True)
-            # b) append it to the running sequence
-            idx = torch.cat([idx, next_idx], dim=1)
-            # c) feed model only the new token
-            with torch.no_grad():
+            for _ in range(max_new_tokens):
+                # a) pick the token with the highest log-probability (greedy sampling)
+                next_idx = logits[:, -1].argmax(dim=-1, keepdim=True)
+                # b) append it to the running sequence
+                idx = torch.cat([idx, next_idx], dim=1)
+                # c) feed model only the new token
                 logits = model(next_idx, use_cache=True)
-    else:
-        for _ in range(max_new_tokens):
-            with torch.no_grad():
+        else:
+            for _ in range(max_new_tokens):
                 logits = model(idx[:, -ctx_len:], use_cache=False)
-            next_idx = logits[:, -1].argmax(dim=-1, keepdim=True)
-            idx = torch.cat([idx, next_idx], dim=1)
+                next_idx = logits[:, -1].argmax(dim=-1, keepdim=True)
+                idx = torch.cat([idx, next_idx], dim=1)
 
     return idx
 ```

+ 32 - 20
ch04/03_kv-cache/gpt_with_kv_cache.py

@@ -27,7 +27,7 @@ class MultiHeadAttention(nn.Module):
         self.dropout = nn.Dropout(dropout)
         self.register_buffer(
             "mask",
-            torch.triu(torch.ones(context_length, context_length),diagonal=1),
+            torch.triu(torch.ones(context_length, context_length), diagonal=1),
             persistent=False
         )
 
@@ -35,6 +35,7 @@ class MultiHeadAttention(nn.Module):
         # NEW
         self.register_buffer("cache_k", None, persistent=False)
         self.register_buffer("cache_v", None, persistent=False)
+        self.ptr_current_pos = 0
         ####################################################
 
     def forward(self, x, use_cache=False):
@@ -71,8 +72,19 @@ class MultiHeadAttention(nn.Module):
         # Compute scaled dot-product attention (aka self-attention) with a causal mask
         attn_scores = queries @ keys.transpose(2, 3)  # Dot product for each head
 
+        ####################################################
+        # NEW
+        num_tokens_Q = queries.shape[-2]
+        num_tokens_K = keys.shape[-2]
+        if use_cache:
+            mask_bool = self.mask.bool()[
+                self.ptr_current_pos:self.ptr_current_pos + num_tokens_Q, :num_tokens_K
+            ]
+            self.ptr_current_pos += num_tokens_Q
+        ####################################################
         # Original mask truncated to the number of tokens and converted to boolean
-        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]
+        else:
+            mask_bool = self.mask.bool()[:num_tokens_Q, :num_tokens_K]
 
         # Use the mask to fill attention scores
         attn_scores.masked_fill_(mask_bool, -torch.inf)
@@ -93,6 +105,7 @@ class MultiHeadAttention(nn.Module):
     # NEW
     def reset_cache(self):
         self.cache_k, self.cache_v = None, None
+        self.ptr_current_pos = 0
     ####################################################
 
 
@@ -264,30 +277,29 @@ def generate_text_simple(model, idx, max_new_tokens, context_size):
 
 ####################################################
 # NEW
-def generate_text_simple_cached(model, idx, max_new_tokens, use_cache=True):
+def generate_text_simple_cached(model, idx, max_new_tokens,
+                                context_size=None, use_cache=True):
     model.eval()
+    ctx_len = context_size or model.pos_emb.num_embeddings
 
-    ctx_len = model.pos_emb.num_embeddings  # max supported length, e.g. 1024
-    if use_cache:
-        # Init cache with full prompt
-        model.reset_kv_cache()
-        with torch.no_grad():
+    with torch.no_grad():
+        if use_cache:
+            # Init cache with full prompt
+            model.reset_kv_cache()
             logits = model(idx[:, -ctx_len:], use_cache=True)
 
-        for _ in range(max_new_tokens):
-            # a) pick the token with the highest log-probability (greedy sampling)
-            next_idx = logits[:, -1].argmax(dim=-1, keepdim=True)
-            # b) append it to the running sequence
-            idx = torch.cat([idx, next_idx], dim=1)
-            # c) feed model only the new token
-            with torch.no_grad():
+            for _ in range(max_new_tokens):
+                # a) pick the token with the highest log-probability (greedy sampling)
+                next_idx = logits[:, -1].argmax(dim=-1, keepdim=True)
+                # b) append it to the running sequence
+                idx = torch.cat([idx, next_idx], dim=1)
+                # c) feed model only the new token
                 logits = model(next_idx, use_cache=True)
-    else:
-        for _ in range(max_new_tokens):
-            with torch.no_grad():
+        else:
+            for _ in range(max_new_tokens):
                 logits = model(idx[:, -ctx_len:], use_cache=False)
-            next_idx = logits[:, -1].argmax(dim=-1, keepdim=True)
-            idx = torch.cat([idx, next_idx], dim=1)
+                next_idx = logits[:, -1].argmax(dim=-1, keepdim=True)
+                idx = torch.cat([idx, next_idx], dim=1)
 
     return idx
 ####################################################

+ 15 - 19
ch04/03_kv-cache/gpt_with_kv_cache_optimized.py

@@ -171,7 +171,8 @@ class TransformerBlock(nn.Module):
             num_heads=cfg["n_heads"],
             dropout=cfg["drop_rate"],
             qkv_bias=cfg["qkv_bias"],
-            window_size=cfg["kv_window_size"])  # NEW
+            window_size=cfg["kv_window_size"] if "kv_window_size" in cfg else cfg["context_length"]   # NEW
+        )
         self.ff = FeedForward(cfg)
         self.norm1 = LayerNorm(cfg["emb_dim"])
         self.norm2 = LayerNorm(cfg["emb_dim"])
@@ -289,30 +290,25 @@ def generate_text_simple(model, idx, max_new_tokens, context_size):
 
 ####################################################
 # NEW
-def generate_text_simple_cached(model, idx, max_new_tokens, use_cache=True):
+def generate_text_simple_cached(model, idx, max_new_tokens, context_size=None, use_cache=True):
     model.eval()
 
-    ctx_len = model.pos_emb.num_embeddings  # max supported length, e.g. 1024
-    if use_cache:
-        # Init cache with full prompt
-        model.reset_kv_cache()
-        with torch.no_grad():
+    ctx_len = context_size or model.pos_emb.num_embeddings
+
+    with torch.no_grad():
+        if use_cache:
+            model.reset_kv_cache()
             logits = model(idx[:, -ctx_len:], use_cache=True)
 
-        for _ in range(max_new_tokens):
-            # a) pick the token with the highest log-probability (greedy sampling)
-            next_idx = logits[:, -1].argmax(dim=-1, keepdim=True)
-            # b) append it to the running sequence
-            idx = torch.cat([idx, next_idx], dim=1)
-            # c) feed model only the new token
-            with torch.no_grad():
+            for _ in range(max_new_tokens):
+                next_idx = logits[:, -1].argmax(dim=-1, keepdim=True)
+                idx = torch.cat([idx, next_idx], dim=1)
                 logits = model(next_idx, use_cache=True)
-    else:
-        for _ in range(max_new_tokens):
-            with torch.no_grad():
+        else:
+            for _ in range(max_new_tokens):
                 logits = model(idx[:, -ctx_len:], use_cache=False)
-            next_idx = logits[:, -1].argmax(dim=-1, keepdim=True)
-            idx = torch.cat([idx, next_idx], dim=1)
+                next_idx = logits[:, -1].argmax(dim=-1, keepdim=True)
+                idx = torch.cat([idx, next_idx], dim=1)
 
     return idx
 ####################################################

+ 101 - 0
ch04/03_kv-cache/tests.py

@@ -0,0 +1,101 @@
+# Code to test the GPT model implementation against the KV cache variants
+
+import pytest
+import torch
+import tiktoken
+
+from gpt_ch04 import GPTModel as GPTModelBase
+from gpt_ch04 import generate_text_simple
+
+from gpt_with_kv_cache import GPTModel as GPTModelKV1
+from gpt_with_kv_cache_optimized import GPTModel as GPTModelKV2
+from gpt_with_kv_cache import generate_text_simple_cached
+
+
+GPT_CONFIG_124M = {
+    "vocab_size": 50257,
+    "context_length": 1024,
+    "emb_dim": 768,
+    "n_heads": 12,
+    "n_layers": 12,
+    "drop_rate": 0.1,
+    "qkv_bias": False,
+}
+
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+
+@pytest.mark.parametrize("ModelClass", [GPTModelBase, GPTModelKV1, GPTModelKV2])
+def test_gpt_model_equivalence_not_cached(ModelClass):
+    torch.manual_seed(123)
+
+    model = ModelClass(GPT_CONFIG_124M).to(device)
+    model.eval()
+
+    tokenizer = tiktoken.get_encoding("gpt2")
+    prompt = "Hello, I am"
+    encoded = tokenizer.encode(prompt)
+    encoded_tensor = torch.tensor(encoded, device=device).unsqueeze(0)
+
+    model_name = ModelClass.__module__ + "." + ModelClass.__name__
+
+    token_ids = generate_text_simple(
+        model=model,
+        idx=encoded_tensor,
+        max_new_tokens=30,
+        context_size=GPT_CONFIG_124M["context_length"]
+    )
+
+    if not hasattr(test_gpt_model_equivalence_not_cached, "results"):
+        test_gpt_model_equivalence_not_cached.results = []
+
+    test_gpt_model_equivalence_not_cached.results.append((model_name, token_ids))
+
+    if len(test_gpt_model_equivalence_not_cached.results) == 3:
+        base_name, base_output = test_gpt_model_equivalence_not_cached.results[0]
+        for other_name, other_output in test_gpt_model_equivalence_not_cached.results[1:]:
+            assert torch.equal(base_output, other_output), (
+                f"Mismatch between {base_name} and {other_name}"
+            )
+
+
+@pytest.mark.parametrize("ModelClass", [GPTModelBase, GPTModelKV1, GPTModelKV2])
+def test_gpt_model_equivalence_cached(ModelClass):
+    torch.manual_seed(123)
+
+    model = ModelClass(GPT_CONFIG_124M).to(device)
+    model.eval()
+
+    tokenizer = tiktoken.get_encoding("gpt2")
+    prompt = "Hello, I am"
+    encoded_tensor = torch.tensor(tokenizer.encode(prompt), device=device).unsqueeze(0)
+
+    model_name = ModelClass.__module__ + "." + ModelClass.__name__
+
+    if ModelClass is GPTModelBase:
+        token_ids = generate_text_simple(
+            model=model,
+            idx=encoded_tensor,
+            max_new_tokens=30,
+            context_size=GPT_CONFIG_124M["context_length"]
+        )
+    else:
+        token_ids = generate_text_simple_cached(
+            model=model,
+            idx=encoded_tensor,
+            max_new_tokens=30,
+            context_size=GPT_CONFIG_124M["context_length"]
+        )
+
+    if not hasattr(test_gpt_model_equivalence_cached, "results"):
+        test_gpt_model_equivalence_cached.results = []
+
+    test_gpt_model_equivalence_cached.results.append((model_name, token_ids))
+
+    if len(test_gpt_model_equivalence_cached.results) == 3:
+        base_name, base_output = test_gpt_model_equivalence_cached.results[0]
+        for other_name, other_output in test_gpt_model_equivalence_cached.results[1:]:
+            assert torch.equal(base_output, other_output), (
+                f"Mismatch between {base_name} and {other_name}"
+            )