|
|
@@ -1769,36 +1769,6 @@
|
|
|
"print(\"\\nSecond head:\\n\", second_res)"
|
|
|
]
|
|
|
},
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": 40,
|
|
|
- "id": "08c2a3fd-e674-4d69-9ef4-ea94b788e937",
|
|
|
- "metadata": {},
|
|
|
- "outputs": [
|
|
|
- {
|
|
|
- "data": {
|
|
|
- "text/plain": [
|
|
|
- "2360064"
|
|
|
- ]
|
|
|
- },
|
|
|
- "execution_count": 40,
|
|
|
- "metadata": {},
|
|
|
- "output_type": "execute_result"
|
|
|
- }
|
|
|
- ],
|
|
|
- "source": [
|
|
|
- "block_size = 1024\n",
|
|
|
- "d_in, d_out = 768, 768\n",
|
|
|
- "num_heads = 12\n",
|
|
|
- "\n",
|
|
|
- "mha = MultiHeadAttention(d_in, d_out, block_size, 0.0, num_heads)\n",
|
|
|
- "\n",
|
|
|
- "def count_parameters(model):\n",
|
|
|
- " return sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
|
|
|
- "\n",
|
|
|
- "count_parameters(mha)"
|
|
|
- ]
|
|
|
- },
|
|
|
{
|
|
|
"cell_type": "markdown",
|
|
|
"id": "dec671bf-7938-4304-ad1e-75d9920e7f43",
|