|
|
@@ -257,9 +257,9 @@
|
|
|
" values = values.view(b, num_tokens, self.num_kv_groups, self.head_dim)\n",
|
|
|
"\n",
|
|
|
" # Transpose keys, values, and queries\n",
|
|
|
- " keys = keys.transpose(1, 2) # Shape: (b, num_heads, num_tokens, head_dim)\n",
|
|
|
- " values = values.transpose(1, 2) # Shape: (b, num_heads, num_tokens, head_dim)\n",
|
|
|
- " queries = queries.transpose(1, 2) # Shape: (b, num_query_groups, num_tokens, head_dim)\n",
|
|
|
+ " keys = keys.transpose(1, 2) # Shape: (b, num_kv_groups, num_tokens, head_dim)\n",
|
|
|
+ " values = values.transpose(1, 2) # Shape: (b, num_kv_groups, num_tokens, head_dim)\n",
|
|
|
+ " queries = queries.transpose(1, 2) # Shape: (b, num_heads, num_tokens, head_dim)\n",
|
|
|
"\n",
|
|
|
" # Apply RoPE\n",
|
|
|
" keys = apply_rope(keys, cos, sin)\n",
|