@@ -35,6 +35,7 @@
  "source": [
  "import tiktoken\n",
  "import torch\n",
+ "import torch.nn as nn\n",
  "from torch.utils.data import Dataset, DataLoader\n",
  "\n",
  "\n",
@@ -86,8 +87,8 @@
  "block_size = max_len\n",
  "\n",
  "\n",
- "token_embedding_layer = torch.nn.Embedding(vocab_size, output_dim)\n",
- "pos_embedding_layer = torch.nn.Embedding(vocab_size, output_dim)\n",
+ "token_embedding_layer = nn.Embedding(vocab_size, output_dim)\n",
+ "pos_embedding_layer = nn.Embedding(vocab_size, output_dim)\n",
  "\n",
  "max_length = 4\n",
  "dataloader = create_dataloader(raw_text, batch_size=8, max_length=max_length, stride=5)"
@@ -152,17 +153,15 @@
  "metadata": {},
  "outputs": [],
  "source": [
- "import torch\n",
- "\n",
- "class CausalSelfAttention(torch.nn.Module):\n",
+ "class CausalSelfAttention(nn.Module):\n",
  "\n",
  "    def __init__(self, d_in, d_out, block_size, dropout):\n",
  "        super().__init__()\n",
  "        self.d_out = d_out\n",
- "        self.W_query = torch.nn.Linear(d_in, d_out, bias=False)\n",
- "        self.W_key = torch.nn.Linear(d_in, d_out, bias=False)\n",
- "        self.W_value = torch.nn.Linear(d_in, d_out, bias=False)\n",
- "        self.dropout = torch.nn.Dropout(dropout) # New\n",
+ "        self.W_query = nn.Linear(d_in, d_out, bias=False)\n",
+ "        self.W_key = nn.Linear(d_in, d_out, bias=False)\n",
+ "        self.W_value = nn.Linear(d_in, d_out, bias=False)\n",
+ "        self.dropout = nn.Dropout(dropout) # New\n",
  "        self.register_buffer('mask', torch.triu(torch.ones(block_size, block_size), diagonal=1)) # New\n",
  "\n",
  "    def forward(self, x):\n",
@@ -181,14 +180,14 @@
  "        return context_vec\n",
  "\n",
  "\n",
- "class MultiHeadAttentionWrapper(torch.nn.Module):\n",
+ "class MultiHeadAttentionWrapper(nn.Module):\n",
  "    def __init__(self, d_in, d_out, block_size, dropout, num_heads):\n",
  "        super().__init__()\n",
- "        self.heads = torch.nn.ModuleList(\n",
+ "        self.heads = nn.ModuleList(\n",
  "            [CausalSelfAttention(d_in, d_out, block_size, dropout) \n",
  "             for _ in range(num_heads)]\n",
  "        )\n",
- "        self.out_proj = torch.nn.Linear(d_out*num_heads, d_out*num_heads)\n",
+ "        self.out_proj = nn.Linear(d_out*num_heads, d_out*num_heads)\n",
  "\n",
  "    def forward(self, x):\n",
  "        context_vec = torch.cat([head(x) for head in self.heads], dim=-1)\n",
@@ -241,10 +240,7 @@
  "metadata": {},
  "outputs": [],
  "source": [
- "import torch\n",
- "\n",
- "\n",
- "class MultiHeadAttention(torch.nn.Module):\n",
+ "class MultiHeadAttention(nn.Module):\n",
  "    def __init__(self, d_in, d_out, block_size, dropout, num_heads):\n",
  "        super().__init__()\n",
  "        assert d_out % num_heads == 0, \"d_out must be divisible by n_heads\"\n",
@@ -253,30 +249,48 @@
  "        self.num_heads = num_heads\n",
  "        self.head_dim = d_out // num_heads # Reduce the projection dim to match desired output dim\n",
  "\n",
- "        self.W_query = torch.nn.Linear(d_in, d_out, bias=False)\n",
- "        self.W_key = torch.nn.Linear(d_in, d_out, bias=False)\n",
- "        self.W_value = torch.nn.Linear(d_in, d_out, bias=False)\n",
- "        self.out_proj = torch.nn.Linear(d_out, d_out) # Linear layer to combine head outputs\n",
- "        self.dropout = torch.nn.Dropout(dropout)\n",
+ "        self.W_query = nn.Linear(d_in, d_out, bias=False)\n",
+ "        self.W_key = nn.Linear(d_in, d_out, bias=False)\n",
+ "        self.W_value = nn.Linear(d_in, d_out, bias=False)\n",
+ "        self.out_proj = nn.Linear(d_out, d_out) # Linear layer to combine head outputs\n",
+ "        self.dropout = nn.Dropout(dropout)\n",
  "        self.register_buffer('mask', torch.triu(torch.ones(block_size, block_size), diagonal=1))\n",
  "\n",
  "    def forward(self, x):\n",
- "        b, n_tokens, d_in = x.shape\n",
+ "        b, num_tokens, d_in = x.shape\n",
  "\n",
- "        # Split into multiple heads\n",
- "        keys = self.W_key(x).view(b, n_tokens, self.num_heads, self.head_dim).transpose(1, 2)\n",
- "        queries = self.W_query(x).view(b, n_tokens, self.num_heads, self.head_dim).transpose(1, 2)\n",
- "        values = self.W_value(x).view(b, n_tokens, self.num_heads, self.head_dim).transpose(1, 2)\n",
+ "        keys = self.W_key(x) # Shape: (b, num_tokens, d_out)\n",
+ "        queries = self.W_query(x)\n",
+ "        values = self.W_value(x)\n",
+ "\n",
+ "        # We implicitly split the matrix by adding a `num_heads` dimension\n",
+ "        # Unroll last dim: (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)\n",
+ "        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim) \n",
+ "        values = values.view(b, num_tokens, self.num_heads, self.head_dim)\n",
+ "        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)\n",
+ "\n",
+ "        # Transpose: (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim)\n",
+ "        keys = keys.transpose(1, 2)\n",
+ "        queries = queries.transpose(1, 2)\n",
+ "        values = values.transpose(1, 2)\n",
  "\n",
- "        # Compute scaled dot-product attention for each head\n",
+ "        # Compute scaled dot-product attention (aka self-attention) with a causal mask\n",
  "        attn_scores = queries @ keys.transpose(2, 3) # Dot product for each head\n",
- "        attn_scores.masked_fill_(self.mask.bool()[:n_tokens, :n_tokens].unsqueeze(0).unsqueeze(0), -torch.inf)\n",
+ "        # Original mask truncated to the number of tokens and converted to boolean\n",
+ "        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]\n",
+ "        # Unsqueeze the mask twice to match dimensions\n",
+ "        mask_unsqueezed = mask_bool.unsqueeze(0).unsqueeze(0)\n",
+ "        # Use the unsqueezed mask to fill attention scores\n",
+ "        attn_scores.masked_fill_(mask_unsqueezed, -torch.inf)\n",
+ "        \n",
  "        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)\n",
  "        attn_weights = self.dropout(attn_weights)\n",
- "        context_vec = (attn_weights @ values).transpose(1, 2) # Shape: (b, T, n_heads, head_dim)\n",
+ "\n",
+ "        # Shape: (b, num_tokens, num_heads, head_dim)\n",
+ "        context_vec = (attn_weights @ values).transpose(1, 2) \n",
  "        \n",
  "        # Combine heads, where self.d_out = self.num_heads * self.head_dim\n",
- "        context_vec = context_vec.contiguous().view(b, n_tokens, self.d_out)\n",
+ "        context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)\n",
  "        context_vec = self.out_proj(context_vec) # optional projection\n",
  "\n",
  "        return context_vec"
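
For quick verification, a minimal usage sketch (not part of the diff) of the refactored MultiHeadAttention class, assuming the notebook cells above have been executed; the batch size, dimensions, dropout value, and seed below are illustrative assumptions only.

# Usage sketch: MultiHeadAttention is the class defined in the diff above;
# all concrete numbers here are assumed for illustration.
import torch

torch.manual_seed(123)

batch_size, block_size, d_in, d_out, num_heads = 2, 4, 6, 6, 2
x = torch.randn(batch_size, block_size, d_in)  # (b, num_tokens, d_in)

mha = MultiHeadAttention(d_in, d_out, block_size, dropout=0.0, num_heads=num_heads)
context_vec = mha(x)
print(context_vec.shape)  # expected: torch.Size([2, 4, 6])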