find-near-duplicates.py

# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
# Source for "Build a Large Language Model From Scratch"
# - https://www.manning.com/books/build-a-large-language-model-from-scratch
# Code: https://github.com/rasbt/LLMs-from-scratch
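
# Usage sketch (illustrative; the file names below are placeholders, not from the repo;
# the flags match the argparse options defined at the bottom of this script):
#   python find-near-duplicates.py --json_file my_dataset.json --threshold 0.9
#   python find-near-duplicates.py --json_file my_dataset.json \
#       --remove_duplicates --json_output_file my_dataset_deduped.json
# Run without --json_file to check the small example_data list defined below.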

import argparse
import json
import re

from sklearn import __version__ as sklearn_version
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity

# Sample JSON dataset
example_data = [
    {"instruction": "What is the capital of Italy?",
     "input": "", "output": "The capital of Italy is Rome."
     },
    {"instruction": "What's the capital city of Italy?",
     "input": "", "output": "The capital city is Rome."
     },
    {"instruction": "Identify the main verb in the sentence: 'The cat sleeps on the couch.'",
     "input": "", "output": "The verb is 'sleeps'."
     },
    {"instruction": "Identify the verb in the following sentence: The cat sleeps on the couch.",
     "input": "", "output": "The verb in the sentence is \"sleeps.\""
     },
    # ...
]


def preprocess_text(text):
    # Lowercase the text
    text = text.lower()

    # Remove punctuation
    text = re.sub(r'[^\w\s]', '', text)
    return text
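
# For example (an illustrative call, not part of the original file):
#   preprocess_text("What's the capital of Italy?")  ->  "whats the capital of italy"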


def find_near_duplicates(json_data, threshold=0.75, key="instruction"):
    """The higher the threshold, the more similar the texts have to be to match"""

    # Extract the texts for the given key, remembering each entry's original
    # index so the similarity-matrix rows can be mapped back to json_data
    # (a plain list comprehension would misalign indices once empty entries are skipped)
    indexed_texts = [
        (index, preprocess_text(item[key]))
        for index, item in enumerate(json_data) if item[key]
    ]

    near_duplicates = []
    indices_to_remove = set()

    if not indexed_texts:
        # Return the data unchanged rather than an empty dict so callers
        # that reassign the result keep a valid dataset
        return json_data, near_duplicates

    indices, texts = zip(*indexed_texts)

    # Vectorize the text data; character 1- to 3-grams make the similarity
    # robust to small spelling and wording differences
    vectorizer = TfidfVectorizer(stop_words=None, analyzer='char', ngram_range=(1, 3))
    tfidf_matrix = vectorizer.fit_transform(texts)

    # Compute cosine similarity between each pair of entries
    cos_sim_matrix = cosine_similarity(tfidf_matrix)

    # Find pairs of near-duplicate entries based on the threshold
    for i in range(len(cos_sim_matrix)):
        for j in range(i+1, len(cos_sim_matrix)):
            if cos_sim_matrix[i, j] > threshold:
                idx_i, idx_j = indices[i], indices[j]
                if len(json_data[idx_i][key]) <= 1 or len(json_data[idx_j][key]) <= 1:
                    continue
                near_duplicates.append((json_data[idx_i], json_data[idx_j], cos_sim_matrix[i, j]))
                if key in ("input", "output"):  # Don't remove duplicates based on the instruction
                    indices_to_remove.add(idx_j)  # Mark the second entry for removal

    # Remove the near-duplicate entries
    filtered_json_data = [item for index, item in enumerate(json_data) if index not in indices_to_remove]
    return filtered_json_data, near_duplicates
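
# A minimal sketch of calling find_near_duplicates directly (illustrative,
# using the example_data list defined above):
#   filtered, dupes = find_near_duplicates(example_data, threshold=0.75, key="instruction")
#   for first, second, score in dupes:
#       print(f"{score:.2f}: {first['instruction']} | {second['instruction']}")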


def find_print_and_remove_near_duplicates(json_data, remove_duplicates=False, threshold=0.75):
    """
    For each key in the first JSON object, searches the whole list for
    near-duplicate values and prints any pairs that are found.
    If remove_duplicates is True, the cleaned list is carried over to the
    next key and returned at the end.
    """
    for key in json_data[0].keys():
        if remove_duplicates:
            json_data, near_duplicates = find_near_duplicates(json_data, key=key, threshold=threshold)
        else:
            _, near_duplicates = find_near_duplicates(json_data, key=key, threshold=threshold)
        separator = 50 * '='
        print(f"\n\n{separator}\nSearching '{key}' for duplicates ...\n{separator}")
        if not near_duplicates:
            print("No duplicates found")
        else:
            for dup in near_duplicates:
                print(
                    f"Duplicate pair found with similarity {dup[2]:.2f}:\n"
                    f"1. {dup[0][key]}\n2. {dup[1][key]}\n"
                )
    return json_data


if __name__ == "__main__":
    print("scikit-learn version:", sklearn_version)

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--json_file",
        type=str,
        help="Path to the dataset JSON file"
    )
    parser.add_argument(
        "--threshold",
        type=float,
        default=0.9,
        help="A sensitivity threshold between 0 and 1 where 1 is strictest"
    )
    parser.add_argument(
        "--remove_duplicates",
        action='store_true',
        default=False,
        help=(
            "Removes duplicates based on the 'input' or 'output' keys "
            "(but not the 'instruction') and saves the cleaned JSON file as --json_output_file"
        )
    )
    parser.add_argument(
        "--json_output_file",
        type=str,
        help="Path for saving the cleaned JSON output file"
    )
    args = parser.parse_args()

    if args.remove_duplicates and not args.json_output_file:
        raise ValueError(
            "Provide an output file via --json_output_file "
            "to save the cleaned JSON data."
        )

    if not args.json_file:
        json_data = example_data
    else:
        with open(args.json_file, "r") as file:
            json_data = json.load(file)

    json_data = find_print_and_remove_near_duplicates(
        json_data=json_data,
        remove_duplicates=args.remove_duplicates,
        threshold=args.threshold
    )

    if args.remove_duplicates:
        with open(args.json_output_file, "w") as file:
            json.dump(json_data, file, indent=4)
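
# Sketch: using the functions above as a module from another script
# (hypothetical usage, not part of the original; assumes this file has been
# renamed to find_near_duplicates.py, since hyphenated names are not importable):
#   from find_near_duplicates import find_print_and_remove_near_duplicates
#   cleaned = find_print_and_remove_near_duplicates(my_records, remove_duplicates=True, threshold=0.9)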