previous_chapters.py

import tiktoken
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader


class GPTDatasetV1(Dataset):
    def __init__(self, txt, tokenizer, max_length, stride):
        self.tokenizer = tokenizer
        self.input_ids = []
        self.target_ids = []

        # Tokenize the entire text
        token_ids = tokenizer.encode(txt)

        # Use a sliding window to chunk the book into overlapping sequences of max_length
        for i in range(0, len(token_ids) - max_length, stride):
            input_chunk = token_ids[i:i + max_length]
            target_chunk = token_ids[i + 1: i + max_length + 1]
            self.input_ids.append(torch.tensor(input_chunk))
            self.target_ids.append(torch.tensor(target_chunk))

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx):
        return self.input_ids[idx], self.target_ids[idx]
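

# Worked example of the sliding window above (added for illustration; the token IDs
# are made up): with token_ids = [1, 2, 3, 4, 5, 6, 7], max_length=4, stride=2,
# the loop yields
#     input_chunk = [1, 2, 3, 4], target_chunk = [2, 3, 4, 5]   (i = 0)
#     input_chunk = [3, 4, 5, 6], target_chunk = [4, 5, 6, 7]   (i = 2)
# i.e. each target sequence is the input sequence shifted one token to the right.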


def create_dataloader_v1(txt, batch_size=4, max_length=256,
                         stride=128, shuffle=True, drop_last=True):
    # Initialize the tokenizer
    tokenizer = tiktoken.get_encoding("gpt2")

    # Create dataset
    dataset = GPTDatasetV1(txt, tokenizer, max_length, stride)

    # Create dataloader
    dataloader = DataLoader(
        dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last)

    return dataloader
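

# Illustrative usage sketch (added; not part of the original file). It shows how the
# dataloader above might be driven from a raw text file; the file name
# "the-verdict.txt" and the hyperparameters are assumptions chosen for demonstration.
#
#     with open("the-verdict.txt", "r", encoding="utf-8") as f:
#         raw_text = f.read()
#     dataloader = create_dataloader_v1(raw_text, batch_size=8, max_length=4, stride=4)
#     inputs, targets = next(iter(dataloader))
#     # inputs and targets each have shape (8, 4); targets are the inputs shifted by one token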


class MultiHeadAttention(nn.Module):
    def __init__(self, d_in, d_out, block_size, dropout, num_heads, qkv_bias=False):
        super().__init__()
        assert d_out % num_heads == 0, "d_out must be divisible by num_heads"

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads  # Reduce the projection dim to match desired output dim

        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out)  # Linear layer to combine head outputs
        self.dropout = nn.Dropout(dropout)
        self.register_buffer('mask', torch.triu(torch.ones(block_size, block_size), diagonal=1))

    def forward(self, x):
        b, num_tokens, d_in = x.shape

        keys = self.W_key(x)  # Shape: (b, num_tokens, d_out)
        queries = self.W_query(x)
        values = self.W_value(x)

        # We implicitly split the matrix by adding a `num_heads` dimension
        # Unroll last dim: (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)
        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
        values = values.view(b, num_tokens, self.num_heads, self.head_dim)
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)

        # Transpose: (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim)
        keys = keys.transpose(1, 2)
        queries = queries.transpose(1, 2)
        values = values.transpose(1, 2)

        # Compute scaled dot-product attention (aka self-attention) with a causal mask
        attn_scores = queries @ keys.transpose(2, 3)  # Dot product for each head

        # Original mask truncated to the number of tokens and converted to boolean
        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]

        # Add a leading dimension so the mask broadcasts over the batch and head dimensions
        mask_unsqueezed = mask_bool.unsqueeze(0)

        # Use the unsqueezed mask to fill attention scores
        attn_scores.masked_fill_(mask_unsqueezed, -torch.inf)

        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # Shape: (b, num_tokens, num_heads, head_dim)
        context_vec = (attn_weights @ values).transpose(1, 2)

        # Combine heads, where self.d_out = self.num_heads * self.head_dim
        context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)
        context_vec = self.out_proj(context_vec)  # optional projection

        return context_vec
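

# Minimal self-check sketch (added for illustration; not part of the original file).
# It runs the attention block above on a random batch to confirm the output shape.
# The dimensions and seed below are arbitrary assumptions.
if __name__ == "__main__":
    torch.manual_seed(123)

    batch_size, context_len, d_in, d_out = 2, 6, 768, 768
    x = torch.rand(batch_size, context_len, d_in)

    mha = MultiHeadAttention(
        d_in=d_in, d_out=d_out, block_size=context_len,
        dropout=0.0, num_heads=12, qkv_bias=False,
    )
    context = mha(x)
    print(context.shape)  # expected: torch.Size([2, 6, 768])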