@@ -47,7 +47,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "PyTorch version: 2.1.0\n"
+ "PyTorch version: 2.3.1\n"
]
}
],
@@ -373,7 +373,7 @@
"metadata": {},
"outputs": [],
"source": [
- "linear.weight = torch.nn.Parameter(embedding.weight.T.detach())"
+ "linear.weight = torch.nn.Parameter(embedding.weight.T)"
},
{