|
|
@@ -967,32 +967,6 @@
|
|
|
"print(\"Output shape:\", output.shape)"
|
|
|
]
|
|
|
},
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": 22,
|
|
|
- "id": "01e737a6-fc99-42bb-9f7e-4da899168811",
|
|
|
- "metadata": {},
|
|
|
- "outputs": [
|
|
|
- {
|
|
|
- "name": "stdout",
|
|
|
- "output_type": "stream",
|
|
|
- "text": [
|
|
|
- "Input shape: torch.Size([2, 4, 768])\n",
|
|
|
- "Output shape: torch.Size([2, 4, 768])\n"
|
|
|
- ]
|
|
|
- }
|
|
|
- ],
|
|
|
- "source": [
|
|
|
- "torch.manual_seed(123)\n",
|
|
|
- "\n",
|
|
|
- "x = torch.rand(2, 4, 768) # Shape: [batch_size, num_tokens, emb_dim]\n",
|
|
|
- "block = TransformerBlock(GPT_CONFIG_124M)\n",
|
|
|
- "output = block(x)\n",
|
|
|
- "\n",
|
|
|
- "print(\"Input shape:\", x.shape)\n",
|
|
|
- "print(\"Output shape:\", output.shape)"
|
|
|
- ]
|
|
|
- },
|
|
|
{
|
|
|
"cell_type": "markdown",
|
|
|
"id": "91f502e4-f3e4-40cb-8268-179eec002394",
|
|
|
@@ -1114,44 +1088,6 @@
|
|
|
"print(out)"
|
|
|
]
|
|
|
},
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": 44,
|
|
|
- "id": "252b78c2-4404-483b-84fe-a412e55c16fc",
|
|
|
- "metadata": {},
|
|
|
- "outputs": [
|
|
|
- {
|
|
|
- "name": "stdout",
|
|
|
- "output_type": "stream",
|
|
|
- "text": [
|
|
|
- "Input batch:\n",
|
|
|
- " tensor([[6109, 3626, 6100, 345],\n",
|
|
|
- " [6109, 1110, 6622, 257]])\n",
|
|
|
- "\n",
|
|
|
- "Output shape: torch.Size([2, 4, 50257])\n",
|
|
|
- "tensor([[[-0.0055, 0.3224, 0.2185, ..., 0.2539, 0.4578, -0.4747],\n",
|
|
|
- " [ 0.2663, -0.2975, -0.5040, ..., -0.3903, 0.5328, -0.4224],\n",
|
|
|
- " [ 1.1146, -0.0923, 0.1303, ..., 0.1521, -0.4494, 0.0276],\n",
|
|
|
- " [-0.8239, 0.1174, -0.2566, ..., 1.1197, 0.1036, -0.3993]],\n",
|
|
|
- "\n",
|
|
|
- " [[-0.1027, 0.1752, -0.1048, ..., 0.2258, 0.1559, -0.8747],\n",
|
|
|
- " [ 0.2230, 0.1246, 0.0492, ..., 0.8573, -0.2933, 0.3036],\n",
|
|
|
- " [ 0.9409, 1.3068, -0.1610, ..., 0.8244, 0.1763, 0.0811],\n",
|
|
|
- " [ 0.4395, 0.2753, 0.1540, ..., 1.3410, -0.3709, 0.1643]]],\n",
|
|
|
- " grad_fn=<UnsafeViewBackward0>)\n"
|
|
|
- ]
|
|
|
- }
|
|
|
- ],
|
|
|
- "source": [
|
|
|
- "torch.manual_seed(123)\n",
|
|
|
- "model = GPTModel(GPT_CONFIG_124M)\n",
|
|
|
- "\n",
|
|
|
- "out = model(batch)\n",
|
|
|
- "print(\"Input batch:\\n\", batch)\n",
|
|
|
- "print(\"\\nOutput shape:\", out.shape)\n",
|
|
|
- "print(out)"
|
|
|
- ]
|
|
|
- },
|
|
|
{
|
|
|
"cell_type": "markdown",
|
|
|
"id": "6d616e7a-568b-4921-af29-bd3f4683cd2e",
|