ollama_evaluate.py

# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
# Source for "Build a Large Language Model From Scratch"
#   - https://www.manning.com/books/build-a-large-language-model-from-scratch
# Code: https://github.com/rasbt/LLMs-from-scratch
#
# A minimal script for scoring instruction-finetuned model responses with Ollama,
# based on the code in chapter 7

import json
import urllib.request

import psutil
from tqdm import tqdm
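
# Note: `psutil` and `tqdm` are third-party packages; if they are not installed yet,
# they can be added with `pip install psutil tqdm`.
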
def query_model(prompt, model="llama3", url="http://localhost:11434/api/chat"):
    # Create the data payload as a dictionary
    data = {
        "model": model,
        "messages": [
            {"role": "user", "content": prompt}
        ],
        # Settings below are required for deterministic responses
        "options": {
            "seed": 123,
            "temperature": 0,
            "num_ctx": 2048
        }
    }

    # Convert the dictionary to a JSON-formatted string and encode it to bytes
    payload = json.dumps(data).encode("utf-8")

    # Create a request object, setting the method to POST and adding necessary headers
    request = urllib.request.Request(url, data=payload, method="POST")
    request.add_header("Content-Type", "application/json")

    # Send the request and capture the response
    response_data = ""
    with urllib.request.urlopen(request) as response:
        # Read and decode the streamed response line by line
        while True:
            line = response.readline().decode("utf-8")
            if not line:
                break
            response_json = json.loads(line)
            response_data += response_json["message"]["content"]

    return response_data
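
# Example usage of query_model (a sketch; assumes the Ollama server is running
# locally on the default port and the "llama3" model has already been pulled):
#
#     result = query_model("What do llamas eat?")
#     print(result)
#
# The /api/chat endpoint streams its reply as one JSON object per line, which is
# why the function above reads the response line by line and concatenates the
# "message" -> "content" fields.
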
def check_if_running(process_name):
    # Return True if any running process has `process_name` in its name
    running = False
    for proc in psutil.process_iter(["name"]):
        if process_name in proc.info["name"]:
            running = True
            break
    return running
def format_input(entry):
    # Format an entry in the Alpaca-style prompt format used during instruction finetuning
    instruction_text = (
        f"Below is an instruction that describes a task. "
        f"Write a response that appropriately completes the request."
        f"\n\n### Instruction:\n{entry['instruction']}"
    )

    input_text = f"\n\n### Input:\n{entry['input']}" if entry["input"] else ""

    return instruction_text + input_text
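
# Example: for a hypothetical test-set entry such as
#
#     entry = {
#         "instruction": "Rewrite the sentence in passive voice.",
#         "input": "The chef cooked the meal.",
#         "output": "The meal was cooked by the chef.",
#         "model_response": "The meal was cooked by the chef.",
#     }
#
# format_input(entry) returns the Alpaca-style prompt:
#
#     Below is an instruction that describes a task. Write a response that appropriately completes the request.
#
#     ### Instruction:
#     Rewrite the sentence in passive voice.
#
#     ### Input:
#     The chef cooked the meal.
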
def main(file_path):
    ollama_running = check_if_running("ollama")

    if not ollama_running:
        raise RuntimeError("Ollama not running. Launch ollama before proceeding.")
    print("Ollama running:", check_if_running("ollama"))

    with open(file_path, "r") as file:
        test_data = json.load(file)

    model = "llama3"
    scores = generate_model_scores(test_data, "model_response", model)
    print(f"Number of scores: {len(scores)} of {len(test_data)}")
    print(f"Average score: {sum(scores)/len(scores):.2f}\n")
def generate_model_scores(json_data, json_key, model="llama3"):
    scores = []
    for entry in tqdm(json_data, desc="Scoring entries"):
        if entry[json_key] == "":
            # Empty model responses automatically receive the lowest score
            scores.append(0)
        else:
            prompt = (
                f"Given the input `{format_input(entry)}` "
                f"and correct output `{entry['output']}`, "
                f"score the model response `{entry[json_key]}`"
                f" on a scale from 0 to 100, where 100 is the best score. "
                f"Respond with the integer number only."
            )
            score = query_model(prompt, model)
            try:
                scores.append(int(score))
            except ValueError:
                print(f"Could not convert score: {score}")
                continue

    return scores
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(
        description="Evaluate model responses with ollama"
    )
    parser.add_argument(
        "--file_path",
        required=True,
        help=(
            "The path to the test dataset `.json` file with the"
            " `'output'` and `'model_response'` keys"
        )
    )
    args = parser.parse_args()

    main(file_path=args.file_path)
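
# Example invocation (a sketch; assumes the chapter 7 code has produced a JSON file,
# e.g. `instruction-data-with-response.json`, whose entries contain the "instruction",
# "input", "output", and "model_response" keys):
#
#     ollama serve    # start the Ollama server in a separate terminal
#     ollama pull llama3
#     python ollama_evaluate.py --file_path instruction-data-with-response.json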