|
|
@@ -21,34 +21,15 @@ from gpt_download import download_and_load_gpt2
|
|
|
from previous_chapters import GPTModel, load_weights_into_gpt
|
|
|
|
|
|
|
|
|
-def download_and_unzip_spam_data(url, zip_path, extracted_path, data_file_path, test_mode=False):
|
|
|
+def download_and_unzip_spam_data(url, zip_path, extracted_path, data_file_path):
|
|
|
if data_file_path.exists():
|
|
|
print(f"{data_file_path} already exists. Skipping download and extraction.")
|
|
|
return
|
|
|
|
|
|
- if test_mode: # Try multiple times since CI sometimes has connectivity issues
|
|
|
- max_retries = 5
|
|
|
- delay = 5 # delay between retries in seconds
|
|
|
- for attempt in range(max_retries):
|
|
|
- try:
|
|
|
- # Downloading the file
|
|
|
- with urllib.request.urlopen(url, timeout=10) as response:
|
|
|
- with open(zip_path, "wb") as out_file:
|
|
|
- out_file.write(response.read())
|
|
|
- break # if download is successful, break out of the loop
|
|
|
- except urllib.error.URLError as e:
|
|
|
- print(f"Attempt {attempt + 1} failed: {e}")
|
|
|
- if attempt < max_retries - 1:
|
|
|
- time.sleep(delay) # wait before retrying
|
|
|
- else:
|
|
|
- print("Failed to download file after several attempts.")
|
|
|
- return # exit if all retries fail
|
|
|
-
|
|
|
- else: # Code as it appears in the chapter
|
|
|
- # Downloading the file
|
|
|
- with urllib.request.urlopen(url) as response:
|
|
|
- with open(zip_path, "wb") as out_file:
|
|
|
- out_file.write(response.read())
|
|
|
+ # Downloading the file
|
|
|
+ with urllib.request.urlopen(url) as response:
|
|
|
+ with open(zip_path, "wb") as out_file:
|
|
|
+ out_file.write(response.read())
|
|
|
|
|
|
# Unzipping the file
|
|
|
with zipfile.ZipFile(zip_path, "r") as zip_ref:
|
|
|
@@ -277,15 +258,11 @@ if __name__ == "__main__":
|
|
|
data_file_path = Path(extracted_path) / "SMSSpamCollection.tsv"
|
|
|
|
|
|
try:
|
|
|
- download_and_unzip_spam_data(
|
|
|
- url, zip_path, extracted_path, data_file_path, test_mode=args.test_mode
|
|
|
- )
|
|
|
+ download_and_unzip_spam_data(url, zip_path, extracted_path, data_file_path)
|
|
|
except (urllib.error.HTTPError, urllib.error.URLError, TimeoutError) as e:
|
|
|
print(f"Primary URL failed: {e}. Trying backup URL...")
|
|
|
- backup_url = "https://f001.backblazeb2.com/file/LLMs-from-scratch/sms%2Bspam%2Bcollection.zip"
|
|
|
- download_and_unzip_spam_data(
|
|
|
- backup_url, zip_path, extracted_path, data_file_path, test_mode=args.test_mode
|
|
|
- )
|
|
|
+ url = "https://f001.backblazeb2.com/file/LLMs-from-scratch/sms%2Bspam%2Bcollection.zip"
|
|
|
+ download_and_unzip_spam_data(url, zip_path, extracted_path, data_file_path)
|
|
|
|
|
|
df = pd.read_csv(data_file_path, sep="\t", header=None, names=["Label", "Text"])
|
|
|
balanced_df = create_balanced_dataset(df)
|