Browse Source

fixed gqa qkv code comments (#660)

Daniel Kleine 5 tháng trước cách đây
mục cha
commit
c2cfb47b1a

+ 3 - 3
ch05/07_gpt_to_llama/converting-llama2-to-llama3.ipynb

@@ -501,9 +501,9 @@
     "        ################################################\n",
     "\n",
     "        # Transpose keys, values, and queries\n",
-    "        keys = keys.transpose(1, 2)  # Shape: (b, num_heads, num_tokens, head_dim)\n",
-    "        values = values.transpose(1, 2)  # Shape: (b, num_heads, num_tokens, head_dim)\n",
-    "        queries = queries.transpose(1, 2)  # Shape: (b, num_query_groups, num_tokens, head_dim)\n",
+    "        keys = keys.transpose(1, 2)  # Shape: (b, num_kv_groups, num_tokens, head_dim)\n",
+    "        values = values.transpose(1, 2)  # Shape: (b, num_kv_groups, num_tokens, head_dim)\n",
+    "        queries = queries.transpose(1, 2)  # Shape: (b, num_heads, num_tokens, head_dim)\n",
     "\n",
     "        ##################### NEW #####################\n",
     "        # Apply RoPE\n",

+ 3 - 3
ch05/07_gpt_to_llama/standalone-llama32.ipynb

@@ -257,9 +257,9 @@
     "        values = values.view(b, num_tokens, self.num_kv_groups, self.head_dim)\n",
     "\n",
     "        # Transpose keys, values, and queries\n",
-    "        keys = keys.transpose(1, 2)  # Shape: (b, num_heads, num_tokens, head_dim)\n",
-    "        values = values.transpose(1, 2)  # Shape: (b, num_heads, num_tokens, head_dim)\n",
-    "        queries = queries.transpose(1, 2)  # Shape: (b, num_query_groups, num_tokens, head_dim)\n",
+    "        keys = keys.transpose(1, 2)  # Shape: (b, num_kv_groups, num_tokens, head_dim)\n",
+    "        values = values.transpose(1, 2)  # Shape: (b, num_kv_groups, num_tokens, head_dim)\n",
+    "        queries = queries.transpose(1, 2)  # Shape: (b, num_heads, num_tokens, head_dim)\n",
     "\n",
     "        # Apply RoPE\n",
     "        keys = apply_rope(keys, cos, sin)\n",

+ 3 - 3
pkg/llms_from_scratch/llama3.py

@@ -166,9 +166,9 @@ class GroupedQueryAttention(nn.Module):
         values = values.view(b, num_tokens, self.num_kv_groups, self.head_dim)
 
         # Transpose keys, values, and queries
-        keys = keys.transpose(1, 2)  # Shape: (b, num_heads, num_tokens, head_dim)
-        values = values.transpose(1, 2)  # Shape: (b, num_heads, num_tokens, head_dim)
-        queries = queries.transpose(1, 2)  # Shape: (b, num_query_groups, num_tokens, head_dim)
+        keys = keys.transpose(1, 2)  # Shape: (b, num_kv_groups, num_tokens, head_dim)
+        values = values.transpose(1, 2)  # Shape: (b, num_kv_groups, num_tokens, head_dim)
+        queries = queries.transpose(1, 2)  # Shape: (b, num_heads, num_tokens, head_dim)
 
         # Apply RoPE
         keys = apply_rope(keys, cos, sin)