Explorar el Código

Fix data download if UCI is temporarily down (#592)

Sebastian Raschka hace 7 meses
padre
commit
222803737d
Se ha modificado 1 fichero con 8 adiciones y 31 borrados
  1. 8 31
      ch06/01_main-chapter-code/gpt_class_finetune.py

+ 8 - 31
ch06/01_main-chapter-code/gpt_class_finetune.py

@@ -21,34 +21,15 @@ from gpt_download import download_and_load_gpt2
 from previous_chapters import GPTModel, load_weights_into_gpt
 
 
-def download_and_unzip_spam_data(url, zip_path, extracted_path, data_file_path, test_mode=False):
+def download_and_unzip_spam_data(url, zip_path, extracted_path, data_file_path):
     if data_file_path.exists():
         print(f"{data_file_path} already exists. Skipping download and extraction.")
         return
 
-    if test_mode:  # Try multiple times since CI sometimes has connectivity issues
-        max_retries = 5
-        delay = 5  # delay between retries in seconds
-        for attempt in range(max_retries):
-            try:
-                # Downloading the file
-                with urllib.request.urlopen(url, timeout=10) as response:
-                    with open(zip_path, "wb") as out_file:
-                        out_file.write(response.read())
-                break  # if download is successful, break out of the loop
-            except urllib.error.URLError as e:
-                print(f"Attempt {attempt + 1} failed: {e}")
-                if attempt < max_retries - 1:
-                    time.sleep(delay)  # wait before retrying
-                else:
-                    print("Failed to download file after several attempts.")
-                    return  # exit if all retries fail
-
-    else:  # Code as it appears in the chapter
-        # Downloading the file
-        with urllib.request.urlopen(url) as response:
-            with open(zip_path, "wb") as out_file:
-                out_file.write(response.read())
+    # Downloading the file
+    with urllib.request.urlopen(url) as response:
+        with open(zip_path, "wb") as out_file:
+            out_file.write(response.read())
 
     # Unzipping the file
     with zipfile.ZipFile(zip_path, "r") as zip_ref:
@@ -277,15 +258,11 @@ if __name__ == "__main__":
     data_file_path = Path(extracted_path) / "SMSSpamCollection.tsv"
 
     try:
-        download_and_unzip_spam_data(
-            url, zip_path, extracted_path, data_file_path, test_mode=args.test_mode
-        )
+        download_and_unzip_spam_data(url, zip_path, extracted_path, data_file_path)
     except (urllib.error.HTTPError, urllib.error.URLError, TimeoutError) as e:
         print(f"Primary URL failed: {e}. Trying backup URL...")
-        backup_url = "https://f001.backblazeb2.com/file/LLMs-from-scratch/sms%2Bspam%2Bcollection.zip"
-        download_and_unzip_spam_data(
-            backup_url, zip_path, extracted_path, data_file_path, test_mode=args.test_mode
-        )
+        url = "https://f001.backblazeb2.com/file/LLMs-from-scratch/sms%2Bspam%2Bcollection.zip"
+        download_and_unzip_spam_data(url, zip_path, extracted_path, data_file_path)
 
     df = pd.read_csv(data_file_path, sep="\t", header=None, names=["Label", "Text"])
     balanced_df = create_balanced_dataset(df)