Browse Source

Add GPTModelFast (#584)

* Add GPTModelFast

* update
Sebastian Raschka 7 months ago
parent
commit
ffd4035144

+ 4 - 2
pkg/llms_from_scratch/README.md

@@ -50,11 +50,12 @@ Once installed, you can import code from any chapter using:
 from llms_from_scratch.ch02 import GPTDatasetV1, create_dataloader_v1
 
 from llms_from_scratch.ch03 import (
-    MultiHeadAttention,
     SelfAttention_v1,
     SelfAttention_v2,
     CausalAttention,
-    MultiHeadAttentionWrapper
+    MultiHeadAttentionWrapper,
+    MultiHeadAttention,
+    PyTorchMultiHeadAttention # Bonus: Faster variant using PyTorch's scaled_dot_product_attention
 )
 
 from llms_from_scratch.ch04 import (
@@ -63,6 +64,7 @@ from llms_from_scratch.ch04 import (
     FeedForward,
     TransformerBlock,
     GPTModel,
+    GPTModelFast # Bonus: Faster variant using PyTorch's scaled_dot_product_attention
     generate_text_simple
 )
 

+ 47 - 0
pkg/llms_from_scratch/ch03.py

@@ -149,3 +149,50 @@ class MultiHeadAttention(nn.Module):
         context_vec = self.out_proj(context_vec)  # optional projection
 
         return context_vec
+
+
+######################
+# Bonus
+######################
+
+
+class PyTorchMultiHeadAttention(nn.Module):
+    def __init__(self, d_in, d_out, num_heads, dropout=0.0, qkv_bias=False):
+        super().__init__()
+
+        assert d_out % num_heads == 0, "embed_dim is indivisible by num_heads"
+
+        self.num_heads = num_heads
+        self.head_dim = d_out // num_heads
+        self.d_out = d_out
+
+        self.qkv = nn.Linear(d_in, 3 * d_out, bias=qkv_bias)
+        self.proj = nn.Linear(d_out, d_out)
+        self.dropout = dropout
+
+    def forward(self, x):
+        batch_size, num_tokens, embed_dim = x.shape
+
+        # (b, num_tokens, embed_dim) --> (b, num_tokens, 3 * embed_dim)
+        qkv = self.qkv(x)
+
+        # (b, num_tokens, 3 * embed_dim) --> (b, num_tokens, 3, num_heads, head_dim)
+        qkv = qkv.view(batch_size, num_tokens, 3, self.num_heads, self.head_dim)
+
+        # (b, num_tokens, 3, num_heads, head_dim) --> (3, b, num_heads, num_tokens, head_dim)
+        qkv = qkv.permute(2, 0, 3, 1, 4)
+
+        # (3, b, num_heads, num_tokens, head_dim) -> 3 times (b, num_heads, num_tokens, head_dim)
+        queries, keys, values = qkv
+
+        use_dropout = 0. if not self.training else self.dropout
+
+        context_vec = nn.functional.scaled_dot_product_attention(
+            queries, keys, values, attn_mask=None, dropout_p=use_dropout, is_causal=True)
+
+        # Combine heads, where self.d_out = self.num_heads * self.head_dim
+        context_vec = context_vec.transpose(1, 2).contiguous().view(batch_size, num_tokens, self.d_out)
+
+        context_vec = self.proj(context_vec)
+
+        return context_vec

+ 88 - 1
pkg/llms_from_scratch/ch04.py

@@ -3,7 +3,7 @@
 #   - https://www.manning.com/books/build-a-large-language-model-from-scratch
 # Code: https://github.com/rasbt/LLMs-from-scratch
 
-from .ch03 import MultiHeadAttention
+from .ch03 import MultiHeadAttention, PyTorchMultiHeadAttention
 import torch
 import torch.nn as nn
 
@@ -128,3 +128,90 @@ def generate_text_simple(model, idx, max_new_tokens, context_size):
         idx = torch.cat((idx, idx_next), dim=1)  # (batch, n_tokens+1)
 
     return idx
+
+######################
+# Bonus
+######################
+
+
+class FeedForwardFast(nn.Module):
+    def __init__(self, cfg):
+        super().__init__()
+        self.layers = nn.Sequential(
+            nn.Linear(cfg["emb_dim"], 4 * cfg["emb_dim"]),
+            nn.GELU(approximate="tanh"),
+            nn.Linear(4 * cfg["emb_dim"], cfg["emb_dim"]),
+        )
+
+    def forward(self, x):
+        return self.layers(x)
+
+
+class TransformerBlockFast(nn.Module):
+    def __init__(self, cfg):
+        super().__init__()
+        self.att = PyTorchMultiHeadAttention(
+            d_in=cfg["emb_dim"],
+            d_out=cfg["emb_dim"],
+            num_heads=cfg["n_heads"],
+            dropout=cfg["drop_rate"],
+            qkv_bias=cfg["qkv_bias"])
+        self.ff = FeedForwardFast(cfg)
+        self.norm1 = nn.LayerNorm(cfg["emb_dim"])
+        self.norm2 = nn.LayerNorm(cfg["emb_dim"])
+        self.drop_shortcut = nn.Dropout(cfg["drop_rate"])
+
+    def forward(self, x):
+        # Shortcut connection for attention block
+        shortcut = x
+        x = self.norm1(x)
+        x = self.att(x)   # Shape [batch_size, num_tokens, emb_size]
+        x = self.drop_shortcut(x)
+        x = x + shortcut  # Add the original input back
+
+        # Shortcut connection for feed-forward block
+        shortcut = x
+        x = self.norm2(x)
+        x = self.ff(x)
+        x = self.drop_shortcut(x)
+        x = x + shortcut  # Add the original input back
+
+        return x
+
+
+class GPTModelFast(nn.Module):
+    """
+    A faster variant of GPTModel optimized for training speed.
+
+    This version is only marginally faster on CPU (~1.02x) but significantly
+    faster on GPU (~2.05x) during training, thanks to optimized CUDA kernels
+    and FlashAttention support.
+
+    Key differences from the original GPTModel:
+    1. Uses PyTorch's built-in LayerNorm instead of a custom implementation.
+    2. Uses PyTorch's built-in GELU instead of a custom implementation.
+    3. Uses PyTorch's scaled_dot_product_attention instead of a custom MultiHeadAttention.
+    4. Automatically enables FlashAttention on compatible GPUs.
+    """
+    def __init__(self, cfg):
+        super().__init__()
+        self.tok_emb = nn.Embedding(cfg["vocab_size"], cfg["emb_dim"])
+        self.pos_emb = nn.Embedding(cfg["context_length"], cfg["emb_dim"])
+        self.drop_emb = nn.Dropout(cfg["drop_rate"])
+
+        self.trf_blocks = nn.Sequential(
+            *[TransformerBlockFast(cfg) for _ in range(cfg["n_layers"])])
+
+        self.final_norm = nn.LayerNorm(cfg["emb_dim"])
+        self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False)
+
+    def forward(self, in_idx):
+        batch_size, seq_len = in_idx.shape
+        tok_embeds = self.tok_emb(in_idx)
+        pos_embeds = self.pos_emb(torch.arange(seq_len, device=in_idx.device))
+        x = tok_embeds + pos_embeds
+        x = self.drop_emb(x)
+        x = self.trf_blocks(x)
+        x = self.final_norm(x)
+        logits = self.out_head(x)
+        return logits

+ 10 - 2
pkg/llms_from_scratch/tests/test_ch03.py

@@ -4,7 +4,7 @@
 # Code: https://github.com/rasbt/LLMs-from-scratch
 
 
-from llms_from_scratch.ch03 import MultiHeadAttention
+from llms_from_scratch.ch03 import MultiHeadAttention, PyTorchMultiHeadAttention
 import torch
 
 
@@ -14,7 +14,15 @@ def test_mha():
     d_in = 256
     d_out = 16
 
-    mha = MultiHeadAttention(d_in, d_out, context_length, 0.0, num_heads=2)
+    mha = MultiHeadAttention(d_in, d_out, context_length, dropout=0.0, num_heads=2)
+
+    batch = torch.rand(8, 6, d_in)
+    context_vecs = mha(batch)
+
+    context_vecs.shape == torch.Size([8, 6, d_out])
+
+    # Test bonus class
+    mha = PyTorchMultiHeadAttention(d_in, d_out, num_heads=2)
 
     batch = torch.rand(8, 6, d_in)
     context_vecs = mha(batch)

+ 16 - 13
pkg/llms_from_scratch/tests/test_ch04.py

@@ -3,26 +3,29 @@
 #   - https://www.manning.com/books/build-a-large-language-model-from-scratch
 # Code: https://github.com/rasbt/LLMs-from-scratch
 
-from llms_from_scratch.ch04 import GPTModel
+from llms_from_scratch.ch04 import GPTModel, GPTModelFast
 from llms_from_scratch.ch04 import generate_text_simple
 
+import pytest
 import torch
 import tiktoken
 
 
-def test_GPTModel():
-    GPT_CONFIG_124M = {
-        "vocab_size": 50257,     # Vocabulary size
-        "context_length": 1024,  # Context length
-        "emb_dim": 768,          # Embedding dimension
-        "n_heads": 12,           # Number of attention heads
-        "n_layers": 12,          # Number of layers
-        "drop_rate": 0.1,        # Dropout rate
-        "qkv_bias": False        # Query-Key-Value bias
-    }
+GPT_CONFIG_124M = {
+    "vocab_size": 50257,     # Vocabulary size
+    "context_length": 1024,  # Context length
+    "emb_dim": 768,          # Embedding dimension
+    "n_heads": 12,           # Number of attention heads
+    "n_layers": 12,          # Number of layers
+    "drop_rate": 0.1,        # Dropout rate
+    "qkv_bias": False        # Query-Key-Value bias
+}
 
+
+@pytest.mark.parametrize("ModelClass", [GPTModel, GPTModelFast])
+def test_gpt_model_variants(ModelClass):
     torch.manual_seed(123)
-    model = GPTModel(GPT_CONFIG_124M)
+    model = ModelClass(GPT_CONFIG_124M)
     model.eval()  # disable dropout
 
     start_context = "Hello, I am"
@@ -47,4 +50,4 @@ def test_GPTModel():
         [15496,   11,   314,   716, 27018, 24086, 47843, 30961, 42348,  7267,
          49706, 43231, 47062, 34657]
     ])
-    torch.equal(expect, out)
+    assert torch.equal(expect, out), "Generated output does not match expected output"

+ 37 - 41
pkg/llms_from_scratch/tests/test_ch05.py

@@ -4,7 +4,7 @@
 # Code: https://github.com/rasbt/LLMs-from-scratch
 
 from llms_from_scratch.ch02 import create_dataloader_v1
-from llms_from_scratch.ch04 import GPTModel
+from llms_from_scratch.ch04 import GPTModel, GPTModelFast
 from llms_from_scratch.ch05 import train_model_simple
 
 import os
@@ -16,60 +16,47 @@ import torch
 from torch.utils.data import Subset, DataLoader
 
 
-@pytest.mark.parametrize("file_name", ["the-verdict.txt"])
-def test_train_simple(tmp_path, file_name):
+GPT_CONFIG_124M = {
+    "vocab_size": 50257,
+    "context_length": 256,  # Shortened for test speed
+    "emb_dim": 768,
+    "n_heads": 12,
+    "n_layers": 12,
+    "drop_rate": 0.1,
+    "qkv_bias": False
+}
 
-    GPT_CONFIG_124M = {
-        "vocab_size": 50257,    # Vocabulary size
-        "context_length": 256,  # Shortened context length (orig: 1024)
-        "emb_dim": 768,         # Embedding dimension
-        "n_heads": 12,          # Number of attention heads
-        "n_layers": 12,         # Number of layers
-        "drop_rate": 0.1,       # Dropout rate
-        "qkv_bias": False       # Query-key-value bias
-    }
+OTHER_SETTINGS = {
+    "learning_rate": 5e-4,
+    "num_epochs": 2,
+    "batch_size": 1,
+    "weight_decay": 0.1
+}
 
-    OTHER_SETTINGS = {
-        "learning_rate": 5e-4,
-        "num_epochs": 2,
-        "batch_size": 1,
-        "weight_decay": 0.1
-    }
 
+@pytest.mark.parametrize("ModelClass", [GPTModel, GPTModelFast])
+def test_train_simple(tmp_path, ModelClass):
     torch.manual_seed(123)
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
     ##############################
     # Download data if necessary
     ##############################
-
     file_path = tmp_path / "the-verdict.txt"
     url = "https://raw.githubusercontent.com/rasbt/LLMs-from-scratch/main/ch02/01_main-chapter-code/the-verdict.txt"
 
     if not os.path.exists(file_path):
         with urllib.request.urlopen(url) as response:
-            text_data = response.read().decode('utf-8')
-        with open(file_path, "w", encoding="utf-8") as file:
-            file.write(text_data)
+            text_data = response.read().decode("utf-8")
+        with open(file_path, "w", encoding="utf-8") as f:
+            f.write(text_data)
     else:
-        with open(file_path, "r", encoding="utf-8") as file:
-            text_data = file.read()
-
-    ##############################
-    # Initialize model
-    ##############################
-
-    model = GPTModel(GPT_CONFIG_124M)
-    model.to(device)  # no assignment model = model.to(device) necessary for nn.Module classes
-    optimizer = torch.optim.AdamW(
-        model.parameters(), lr=OTHER_SETTINGS["learning_rate"], weight_decay=OTHER_SETTINGS["weight_decay"]
-    )
+        with open(file_path, "r", encoding="utf-8") as f:
+            text_data = f.read()
 
     ##############################
     # Set up dataloaders
     ##############################
-
-    # Train/validation ratio
     train_ratio = 0.90
     split_idx = int(train_ratio * len(text_data))
 
@@ -93,16 +80,25 @@ def test_train_simple(tmp_path, file_name):
         num_workers=0
     )
 
+    # Limit to 1 batch for speed
+    train_subset = Subset(train_loader.dataset, range(1))
+    one_batch_train_loader = DataLoader(train_subset, batch_size=1)
+    val_subset = Subset(val_loader.dataset, range(1))
+    one_batch_val_loader = DataLoader(val_subset, batch_size=1)
+
     ##############################
     # Train model
     ##############################
+    model = ModelClass(GPT_CONFIG_124M)
+    model.to(device)
 
-    tokenizer = tiktoken.get_encoding("gpt2")
+    optimizer = torch.optim.AdamW(
+        model.parameters(),
+        lr=OTHER_SETTINGS["learning_rate"],
+        weight_decay=OTHER_SETTINGS["weight_decay"]
+    )
 
-    train_subset = Subset(train_loader.dataset, range(1))
-    one_batch_train_loader = DataLoader(train_subset, batch_size=1)
-    val_subset = Subset(val_loader.dataset, range(1))
-    one_batch_val_loader = DataLoader(val_subset, batch_size=1)
+    tokenizer = tiktoken.get_encoding("gpt2")
 
     train_losses, val_losses, tokens_seen = train_model_simple(
         model, one_batch_train_loader, one_batch_val_loader, optimizer, device,

+ 1 - 1
pyproject.toml

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "llms-from-scratch"
-version = "1.0.0"
+version = "1.0.1"
 description = "Implement a ChatGPT-like LLM in PyTorch from scratch, step by step"
 readme = "README.md"
 requires-python = ">=3.10"