# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
# Source for "Build a Large Language Model From Scratch"
# - https://www.manning.com/books/build-a-large-language-model-from-scratch
# Code: https://github.com/rasbt/LLMs-from-scratch

import json
import os
import urllib.request

import psutil
import torch
from torch.utils.data import Dataset
from tqdm import tqdm


def download_and_load_file(file_path, url):
    if not os.path.exists(file_path):
        with urllib.request.urlopen(url) as response:
            text_data = response.read().decode("utf-8")
        with open(file_path, "w", encoding="utf-8") as file:
            file.write(text_data)
    # The book originally contained this unnecessary "else" clause:
    # else:
    #     with open(file_path, "r", encoding="utf-8") as file:
    #         text_data = file.read()

    with open(file_path, "r", encoding="utf-8") as file:
        data = json.load(file)

    return data
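
# Illustrative usage (a sketch, not part of the original file; the URL below is
# a placeholder and must point at the actual instruction-data JSON file):
#
# url = "https://example.com/instruction-data.json"  # placeholder URL
# data = download_and_load_file("instruction-data.json", url)
# print("Number of entries:", len(data))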


def format_input(entry):
    instruction_text = (
        f"Below is an instruction that describes a task. "
        f"Write a response that appropriately completes the request."
        f"\n\n### Instruction:\n{entry['instruction']}"
    )

    input_text = f"\n\n### Input:\n{entry['input']}" if entry["input"] else ""

    return instruction_text + input_text
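
# What format_input produces for a hypothetical entry (a sketch; the keys
# follow the Alpaca-style prompt format constructed above):
#
# entry = {"instruction": "Convert the text to uppercase.", "input": "hello"}
# print(format_input(entry))
# # Below is an instruction that describes a task. Write a response that appropriately completes the request.
# #
# # ### Instruction:
# # Convert the text to uppercase.
# #
# # ### Input:
# # hello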


class InstructionDataset(Dataset):
    def __init__(self, data, tokenizer):
        self.data = data

        # Pre-tokenize texts
        self.encoded_texts = []
        for entry in data:
            instruction_plus_input = format_input(entry)
            response_text = f"\n\n### Response:\n{entry['output']}"
            full_text = instruction_plus_input + response_text
            self.encoded_texts.append(
                tokenizer.encode(full_text)
            )

    def __getitem__(self, index):
        return self.encoded_texts[index]

    def __len__(self):
        return len(self.data)
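
# Illustrative usage (a sketch; assumes the tiktoken GPT-2 tokenizer used
# throughout the book and a `data` list loaded via download_and_load_file):
#
# import tiktoken
# tokenizer = tiktoken.get_encoding("gpt2")
# dataset = InstructionDataset(data, tokenizer)
# print(len(dataset), "entries; first entry's leading token IDs:", dataset[0][:5])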


def custom_collate_draft_1(
    batch,
    pad_token_id=50256,
    device="cpu"
):
    # Find the longest sequence in the batch
    # and increase the max length by +1, which will add one extra
    # padding token below
    batch_max_length = max(len(item)+1 for item in batch)

    # Pad and prepare inputs
    inputs_lst = []

    for item in batch:
        new_item = item.copy()
        # Add an <|endoftext|> token
        new_item += [pad_token_id]
        # Pad sequences to batch_max_length
        padded = (
            new_item + [pad_token_id] *
            (batch_max_length - len(new_item))
        )
        # Via padded[:-1], we remove the extra padding token
        # that has been added via the +1 setting in batch_max_length
        # (the extra padding token will be relevant in later code)
        inputs = torch.tensor(padded[:-1])
        inputs_lst.append(inputs)

    # Convert the list of inputs to a tensor and transfer it to the target device
    inputs_tensor = torch.stack(inputs_lst).to(device)
    return inputs_tensor
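
# Worked example (a sketch with made-up token IDs): the shorter sequence is
# padded with the <|endoftext|> token ID 50256 up to the batch maximum.
#
# batch = [
#     [1, 2, 3],
#     [4, 5],
# ]
# print(custom_collate_draft_1(batch))
# # tensor([[    1,     2,     3],
# #         [    4,     5, 50256]])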


def custom_collate_draft_2(
    batch,
    pad_token_id=50256,
    device="cpu"
):
    # Find the longest sequence in the batch
    batch_max_length = max(len(item)+1 for item in batch)

    # Pad and prepare inputs and targets
    inputs_lst, targets_lst = [], []

    for item in batch:
        new_item = item.copy()
        # Add an <|endoftext|> token
        new_item += [pad_token_id]
        # Pad sequences to batch_max_length
        padded = (
            new_item + [pad_token_id] *
            (batch_max_length - len(new_item))
        )
        inputs = torch.tensor(padded[:-1])  # Truncate the last token for inputs
        targets = torch.tensor(padded[1:])  # Shift +1 to the right for targets
        inputs_lst.append(inputs)
        targets_lst.append(targets)

    # Convert the lists of inputs and targets to tensors and transfer them to the target device
    inputs_tensor = torch.stack(inputs_lst).to(device)
    targets_tensor = torch.stack(targets_lst).to(device)
    return inputs_tensor, targets_tensor
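
# Worked example (a sketch; same made-up batch as above). The targets are the
# inputs shifted by one position, the standard next-token-prediction setup:
#
# inputs, targets = custom_collate_draft_2([[1, 2, 3], [4, 5]])
# print(inputs)
# # tensor([[    1,     2,     3],
# #         [    4,     5, 50256]])
# print(targets)
# # tensor([[    2,     3, 50256],
# #         [    5, 50256, 50256]])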


def custom_collate_fn(
    batch,
    pad_token_id=50256,
    ignore_index=-100,
    allowed_max_length=None,
    device="cpu"
):
    # Find the longest sequence in the batch
    batch_max_length = max(len(item)+1 for item in batch)

    # Pad and prepare inputs and targets
    inputs_lst, targets_lst = [], []

    for item in batch:
        new_item = item.copy()
        # Add an <|endoftext|> token
        new_item += [pad_token_id]
        # Pad sequences to batch_max_length
        padded = (
            new_item + [pad_token_id] *
            (batch_max_length - len(new_item))
        )
        inputs = torch.tensor(padded[:-1])  # Truncate the last token for inputs
        targets = torch.tensor(padded[1:])  # Shift +1 to the right for targets

        # New: Replace all but the first padding token in targets by ignore_index
        mask = targets == pad_token_id
        indices = torch.nonzero(mask).squeeze()
        if indices.numel() > 1:
            targets[indices[1:]] = ignore_index

        # New: Optionally truncate to the maximum sequence length
        if allowed_max_length is not None:
            inputs = inputs[:allowed_max_length]
            targets = targets[:allowed_max_length]

        inputs_lst.append(inputs)
        targets_lst.append(targets)

    # Convert the lists of inputs and targets to tensors and transfer them to the target device
    inputs_tensor = torch.stack(inputs_lst).to(device)
    targets_tensor = torch.stack(targets_lst).to(device)
    return inputs_tensor, targets_tensor
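
# Worked example (a sketch): extra padding tokens in the targets are replaced
# with -100, which torch.nn.functional.cross_entropy ignores by default, so
# padding positions do not contribute to the loss. The first padding token is
# kept so the model still learns to emit <|endoftext|> at the end.
#
# inputs, targets = custom_collate_fn([[1, 2, 3], [4, 5]])
# print(targets)
# # tensor([[    2,     3, 50256],
# #         [    5, 50256,  -100]])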


def check_if_running(process_name):
    running = False
    for proc in psutil.process_iter(["name"]):
        if process_name in proc.info["name"]:
            running = True
            break
    return running
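
# Illustrative check (a sketch; "ollama" is the process name of the local
# Ollama server that query_model below talks to):
#
# if not check_if_running("ollama"):
#     raise RuntimeError("Ollama not running. Launch ollama before proceeding.")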


def query_model(
    prompt,
    model="llama3",
    url="http://localhost:11434/api/chat"
):
    # Create the data payload as a dictionary
    data = {
        "model": model,
        "messages": [
            {"role": "user", "content": prompt}
        ],
        "options": {  # Settings below are required for deterministic responses
            "seed": 123,
            "temperature": 0,
            "num_ctx": 2048
        }
    }

    # Convert the dictionary to a JSON-formatted string and encode it to bytes
    payload = json.dumps(data).encode("utf-8")

    # Create a request object, setting the method to POST and adding the necessary headers
    request = urllib.request.Request(
        url,
        data=payload,
        method="POST"
    )
    request.add_header("Content-Type", "application/json")

    # Send the request and capture the response, which arrives as a stream of
    # JSON lines; read and decode it line by line
    response_data = ""
    with urllib.request.urlopen(request) as response:
        while True:
            line = response.readline().decode("utf-8")
            if not line:
                break
            response_json = json.loads(line)
            response_data += response_json["message"]["content"]

    return response_data
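
# Example call (a sketch; assumes a local Ollama server listening on the
# default port 11434 with the llama3 model already pulled):
#
# result = query_model("What do llamas eat?", model="llama3")
# print(result)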


def generate_model_scores(json_data, json_key, model="llama3"):
    scores = []
    for entry in tqdm(json_data, desc="Scoring entries"):
        prompt = (
            f"Given the input `{format_input(entry)}` "
            f"and correct output `{entry['output']}`, "
            f"score the model response `{entry[json_key]}`"
            f" on a scale from 0 to 100, where 100 is the best score. "
            f"Respond with the integer number only."
        )
        score = query_model(prompt, model)
        try:
            scores.append(int(score))
        except ValueError:
            print(f"Could not convert score: {score}")
            continue

    return scores
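
# Example (a sketch; "model_response" is a hypothetical key under which each
# test-set entry stores the fine-tuned model's answer):
#
# scores = generate_model_scores(test_data, "model_response", model="llama3")
# print(f"Average score: {sum(scores) / len(scores):.2f}")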