|
|
@@ -239,4 +239,4 @@ if __name__ == "__main__":
|
|
|
# Save and load model
|
|
|
torch.save(model.state_dict(), "model.pth")
|
|
|
model = GPTModel(GPT_CONFIG_124M)
|
|
|
- model.load_state_dict(torch.load("model.pth"), weights_only=True)
|
|
|
+ model.load_state_dict(torch.load("model.pth", weights_only=True))
|