Browse Source

use probas in argmax

rasbt 1 year ago
parent
commit
c88e8edf72
1 changed files with 22 additions and 11 deletions
  1. 22 11
      appendix-A/03_main-chapter-code/code-part1.ipynb

+ 22 - 11
appendix-A/03_main-chapter-code/code-part1.ipynb

@@ -37,7 +37,7 @@
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "2.0.1\n"
+      "2.2.1\n"
      ]
     }
    ],
@@ -591,13 +591,13 @@
      "output_type": "stream",
      "text": [
       "Parameter containing:\n",
-      "tensor([[-0.0064,  0.0004, -0.0903,  ..., -0.1316,  0.0910,  0.0363],\n",
-      "        [ 0.1354,  0.1124, -0.0476,  ...,  0.0578,  0.1014,  0.0008],\n",
-      "        [ 0.0975, -0.0478,  0.0298,  ...,  0.0416,  0.0849,  0.1314],\n",
+      "tensor([[ 0.0956,  0.1280, -0.0696,  ...,  0.0961,  0.0631,  0.1349],\n",
+      "        [ 0.0983,  0.0580, -0.0574,  ...,  0.0981,  0.0370,  0.0516],\n",
+      "        [-0.0429, -0.1411, -0.1399,  ...,  0.0767,  0.0019,  0.1400],\n",
       "        ...,\n",
-      "        [ 0.0118,  0.0240,  0.0420,  ..., -0.1305, -0.0517, -0.0826],\n",
-      "        [-0.0323,  0.1073,  0.0215,  ..., -0.1264, -0.1100,  0.1232],\n",
-      "        [ 0.0861,  0.0403, -0.0545,  ...,  0.1352,  0.0817, -0.0938]],\n",
+      "        [-0.0777, -0.0726,  0.1273,  ..., -0.0613,  0.0491, -0.1381],\n",
+      "        [-0.0830, -0.0969, -0.0473,  ...,  0.0762,  0.1318, -0.1174],\n",
+      "        [ 0.0468, -0.0213,  0.0387,  ...,  0.0639,  0.0927, -0.0668]],\n",
       "       requires_grad=True)\n"
      ]
     }
@@ -881,10 +881,21 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 37,
    "id": "4db4d7f4-82da-44a4-b94e-ee04665d9c3c",
    "metadata": {},
-   "outputs": [],
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Batch 1: tensor([[-1.2000,  3.1000],\n",
+      "        [-0.5000,  2.6000]]) tensor([0, 0])\n",
+      "Batch 2: tensor([[ 2.3000, -1.1000],\n",
+      "        [-0.9000,  2.9000]]) tensor([1, 0])\n"
+     ]
+    }
+   ],
    "source": [
     "for idx, (x, y) in enumerate(train_loader):\n",
     "    print(f\"Batch {idx+1}:\", x, y)"
@@ -1000,7 +1011,7 @@
     "probas = torch.softmax(outputs, dim=1)\n",
     "print(probas)\n",
     "\n",
-    "predictions = torch.argmax(outputs, dim=1)\n",
+    "predictions = torch.argmax(probas, dim=1)\n",
     "print(predictions)"
    ]
   },
@@ -1254,7 +1265,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.10.6"
+   "version": "3.11.4"
   }
  },
  "nbformat": 4,