app_own.py

# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
# Source for "Build a Large Language Model From Scratch"
# - https://www.manning.com/books/build-a-large-language-model-from-scratch
# Code: https://github.com/rasbt/LLMs-from-scratch

from pathlib import Path
import sys

import tiktoken
import torch
import chainlit

from previous_chapters import (
    generate,
    GPTModel,
    text_to_token_ids,
    token_ids_to_text,
)
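
# For reference, the two text helpers imported above are the chapter 5
# utilities; roughly (a sketch, assuming the book's tiktoken-based versions):
#   text_to_token_ids(text, tokenizer): encodes the text and adds a batch
#       dimension, returning a torch tensor of shape (1, num_tokens)
#   token_ids_to_text(token_ids, tokenizer): squeezes the batch dimension and
#       decodes the token IDs back into a string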

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def get_model_and_tokenizer():
    """
    Loads a GPT-2 model with the pretrained weights generated in chapter 5.
    This requires that you first run the code in chapter 5, which generates
    the necessary model.pth file.
    """
    GPT_CONFIG_124M = {
        "vocab_size": 50257,     # Vocabulary size
        "context_length": 256,   # Shortened context length (orig: 1024)
        "emb_dim": 768,          # Embedding dimension
        "n_heads": 12,           # Number of attention heads
        "n_layers": 12,          # Number of layers
        "drop_rate": 0.1,        # Dropout rate
        "qkv_bias": False        # Query-key-value bias
    }
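    # NOTE: This config must match the architecture that produced model.pth in
    # chapter 5; otherwise, load_state_dict() below fails with shape mismatches.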

    tokenizer = tiktoken.get_encoding("gpt2")

    model_path = Path("..") / "01_main-chapter-code" / "model.pth"
    if not model_path.exists():
        print(
            f"Could not find the {model_path} file. Please run the chapter 5 "
            "code (ch05.ipynb) to generate the model.pth file."
        )
        sys.exit()

    checkpoint = torch.load(model_path, weights_only=True)
    model = GPTModel(GPT_CONFIG_124M)
    model.load_state_dict(checkpoint)
    model.to(device)
    model.eval()  # Disable dropout so inference is deterministic
    return tokenizer, model, GPT_CONFIG_124M


# Obtain the necessary tokenizer and model files for the chainlit function below
tokenizer, model, model_config = get_model_and_tokenizer()


@chainlit.on_message
async def main(message: chainlit.Message):
  45. """
  46. The main Chainlit function.
  47. """
  48. token_ids = generate( # function uses `with torch.no_grad()` internally already
  49. model=model,
  50. idx=text_to_token_ids(message.content, tokenizer).to(device), # The user text is provided via as `message.content`
  51. max_new_tokens=50,
  52. context_size=model_config["context_length"],
  53. top_k=1,
  54. temperature=0.0
  55. )
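    # Note: the book's generate() returns the prompt tokens plus the newly
    # generated ones, so the reply echoes the user's input before continuing it.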
    text = token_ids_to_text(token_ids, tokenizer)

    await chainlit.Message(
        content=f"{text}",  # This returns the model response to the interface
    ).send()
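

# To try the app locally, start the Chainlit dev server from this directory
# (assumes chainlit is installed, e.g. via `pip install chainlit`):
#
#   chainlit run app_own.py
#
# and open the printed localhost URL in a browser.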