|
|
@@ -41,7 +41,7 @@ class GPTDatasetV1(Dataset):
|
|
|
|
|
|
|
|
|
def create_dataloader_v1(txt, batch_size=4, max_length=256,
|
|
|
- stride=128, shuffle=True, drop_last=True):
|
|
|
+ stride=128, shuffle=True, drop_last=True, num_workers=0):
|
|
|
# Initialize the tokenizer
|
|
|
tokenizer = tiktoken.get_encoding("gpt2")
|
|
|
|
|
|
@@ -50,7 +50,7 @@ def create_dataloader_v1(txt, batch_size=4, max_length=256,
|
|
|
|
|
|
# Create dataloader
|
|
|
dataloader = DataLoader(
|
|
|
- dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last)
|
|
|
+ dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, num_workers=num_workers)
|
|
|
|
|
|
return dataloader
|
|
|
|