app.py

# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
# Source for "Build a Large Language Model From Scratch"
# - https://www.manning.com/books/build-a-large-language-model-from-scratch
# Code: https://github.com/rasbt/LLMs-from-scratch

from pathlib import Path
import sys

import tiktoken
import torch
import chainlit

# For llms_from_scratch installation instructions, see:
# https://github.com/rasbt/LLMs-from-scratch/tree/main/pkg
from llms_from_scratch.ch04 import GPTModel
from llms_from_scratch.ch05 import (
    generate,
    text_to_token_ids,
    token_ids_to_text,
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
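
# Note: the model weights and the tokenized user input are both moved onto this
# device below, so generation runs on the GPU whenever one is available.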

def get_model_and_tokenizer():
    """
    Loads a GPT-2 model with the finetuned weights generated in chapter 7.
    This requires that you run the chapter 7 code first, which produces the
    necessary gpt2-medium355M-sft.pth file.
    """
    GPT_CONFIG_355M = {
        "vocab_size": 50257,     # Vocabulary size
        "context_length": 1024,  # Context length
        "emb_dim": 1024,         # Embedding dimension
        "n_heads": 16,           # Number of attention heads
        "n_layers": 24,          # Number of layers
        "drop_rate": 0.0,        # Dropout rate
        "qkv_bias": True         # Query-key-value bias
    }

    tokenizer = tiktoken.get_encoding("gpt2")

    model_path = Path("..") / "01_main-chapter-code" / "gpt2-medium355M-sft.pth"
    if not model_path.exists():
        print(
            f"Could not find the {model_path} file. Please run the chapter 7 code "
            "(ch07.ipynb) to generate the gpt2-medium355M-sft.pth file."
        )
        sys.exit()

    # The checkpoint is the state_dict saved via torch.save(model.state_dict(), ...)
    # in chapter 7, so it can be loaded directly into a freshly initialized GPTModel
    checkpoint = torch.load(model_path, weights_only=True)
    model = GPTModel(GPT_CONFIG_355M)
    model.load_state_dict(checkpoint)
    model.to(device)

    return tokenizer, model, GPT_CONFIG_355M

def extract_response(response_text, input_text):
    return response_text[len(input_text):].replace("### Response:", "").strip()
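
# Hypothetical usage example (the literal strings below are illustrative only):
# if `input_text` is the full prompt and the generated `response_text` is
#   "<prompt>### Response:\nThe capital of France is Paris."
# then extract_response() drops the prompt prefix, removes the "### Response:"
# marker, and returns "The capital of France is Paris."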

# Obtain the necessary tokenizer and model files for the chainlit function below
tokenizer, model, model_config = get_model_and_tokenizer()

@chainlit.on_message
async def main(message: chainlit.Message):
    """
    The main Chainlit function.
    """
    torch.manual_seed(123)

    prompt = f"""Below is an instruction that describes a task. Write a response
that appropriately completes the request.

### Instruction:
{message.content}
"""

    token_ids = generate(  # The generate function uses `with torch.no_grad()` internally already
        model=model,
        idx=text_to_token_ids(prompt, tokenizer).to(device),  # The user text is provided via `message.content`
        max_new_tokens=35,
        context_size=model_config["context_length"],
        eos_id=50256  # GPT-2's <|endoftext|> token id, used to stop generation early
    )
    text = token_ids_to_text(token_ids, tokenizer)
    response = extract_response(text, prompt)

    await chainlit.Message(
        content=f"{response}",  # This returns the model response to the interface
    ).send()
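
# To try the chat interface, start the app with the Chainlit CLI from this
# directory (assuming Chainlit is installed in the active environment):
#
#   chainlit run app.py
#
# Chainlit then serves the `main` handler above as a local web app
# (by default at http://localhost:8000).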