prepare_dataset.py

# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
# Source for "Build a Large Language Model From Scratch"
#   - https://www.manning.com/books/build-a-large-language-model-from-scratch
# Code: https://github.com/rasbt/LLMs-from-scratch

"""
Script that processes the Project Gutenberg files into fewer larger files.
"""

import argparse
import os
import re
from tqdm import tqdm
from gutenberg.src.cleanup import strip_headers


def is_english(text, threshold=0.9):
    # Heuristic language filter: treat the text as English if more than
    # `threshold` of its characters fall within the ASCII range
    ascii_chars = sum(1 for c in text if ord(c) < 128)
    return ascii_chars / len(text) > threshold


def combine_files(file_paths, target_dir, max_size_mb=500, separator="<|endoftext|>", fallback_encoding="latin1"):
    if not os.path.exists(target_dir):
        os.makedirs(target_dir)

    current_content = []
    current_size = 0
    file_counter = 1

    for file_path in tqdm(file_paths):
        try:
            with open(file_path, "r", encoding="utf-8") as file:
                content = file.read()
        except UnicodeDecodeError:
            # Attempt to read the file with a fallback encoding
            tqdm.write(f"Warning: UnicodeDecodeError encountered. Trying fallback encoding for {file_path}")
            with open(file_path, "r", encoding=fallback_encoding) as file:
                content = file.read()

        if not is_english(content):
            tqdm.write(f"Skipping {file_path} as it does not contain primarily English text.")
            continue
        content = strip_headers(content)

        # Regular expression to replace multiple blank lines with a single blank line
        content = re.sub(r'\n\s*\n', '\n\n', content)
        estimated_size = len(content.encode("utf-8"))

        # Once the current batch would exceed the size limit, flush it to disk
        # and start a new batch with the current file's content
        if current_size + estimated_size > max_size_mb * 1024 * 1024:
            target_file_path = os.path.join(target_dir, f"combined_{file_counter}.txt")
            with open(target_file_path, "w", encoding="utf-8") as target_file:
                target_file.write(separator.join(current_content))
            file_counter += 1
            current_content = [content]
            current_size = estimated_size
        else:
            current_content.append(content)
            current_size += estimated_size

    # Write any remaining content that did not fill up a whole batch
    if current_content:
        target_file_path = os.path.join(target_dir, f"combined_{file_counter}.txt")
        with open(target_file_path, "w", encoding="utf-8") as target_file:
            target_file.write(separator.join(current_content))
    return file_counter


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Preprocess and combine text files for pretraining")
    parser.add_argument("--data_dir", type=str, default="gutenberg/data/raw",
                        help="Directory containing the downloaded raw training data")
    parser.add_argument("--max_size_mb", type=int, default=500,
                        help="The maximum file size for each concatenated file in megabytes")
    parser.add_argument("--output_dir", type=str, default="gutenberg_preprocessed",
                        help="Directory where the preprocessed data will be saved")
    args = parser.parse_args()

    all_files = [os.path.join(path, name) for path, subdirs, files in os.walk(args.data_dir)
                 for name in files if name.endswith((".txt", ".txt.utf8"))]

    print(f"{len(all_files)} file(s) to process.")
    file_counter = combine_files(all_files, args.output_dir, max_size_mb=args.max_size_mb)
    print(f"{file_counter} file(s) saved in {os.path.abspath(args.output_dir)}")