|
|
@@ -27,7 +27,7 @@ class GPTDatasetV1(Dataset):
|
|
|
return self.input_ids[idx], self.target_ids[idx]
|
|
|
|
|
|
|
|
|
-def create_dataloader(txt, batch_size=4, max_length=256, stride=128, shuffle=True):
|
|
|
+def create_dataloader_v1(txt, batch_size=4, max_length=256, stride=128, shuffle=True):
|
|
|
# Initialize the tokenizer
|
|
|
tokenizer = tiktoken.get_encoding("gpt2")
|
|
|
|