app.py 2.5 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879
  1. # Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
  2. # Source for "Build a Large Language Model From Scratch"
  3. # - https://www.manning.com/books/build-a-large-language-model-from-scratch
  4. # Code: https://github.com/rasbt/LLMs-from-scratch
  5. from pathlib import Path
  6. import sys
  7. import tiktoken
  8. import torch
  9. import chainlit
  10. # For llms_from_scratch installation instructions, see:
  11. # https://github.com/rasbt/LLMs-from-scratch/tree/main/pkg
  12. from llms_from_scratch.ch04 import GPTModel
  13. from llms_from_scratch.ch06 import classify_review
  14. device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  15. def get_model_and_tokenizer():
  16. """
  17. Code to load finetuned GPT-2 model generated in chapter 6.
  18. This requires that you run the code in chapter 6 first, which generates the necessary model.pth file.
  19. """
  20. GPT_CONFIG_124M = {
  21. "vocab_size": 50257, # Vocabulary size
  22. "context_length": 1024, # Context length
  23. "emb_dim": 768, # Embedding dimension
  24. "n_heads": 12, # Number of attention heads
  25. "n_layers": 12, # Number of layers
  26. "drop_rate": 0.1, # Dropout rate
  27. "qkv_bias": True # Query-key-value bias
  28. }
  29. tokenizer = tiktoken.get_encoding("gpt2")
  30. model_path = Path("..") / "01_main-chapter-code" / "review_classifier.pth"
  31. if not model_path.exists():
  32. print(
  33. f"Could not find the {model_path} file. Please run the chapter 6 code"
  34. " (ch06.ipynb) to generate the review_classifier.pth file."
  35. )
  36. sys.exit()
  37. # Instantiate model
  38. model = GPTModel(GPT_CONFIG_124M)
  39. # Convert model to classifier as in section 6.5 in ch06.ipynb
  40. num_classes = 2
  41. model.out_head = torch.nn.Linear(in_features=GPT_CONFIG_124M["emb_dim"], out_features=num_classes)
  42. # Then load model weights
  43. checkpoint = torch.load(model_path, map_location=device, weights_only=True)
  44. model.load_state_dict(checkpoint)
  45. model.to(device)
  46. model.eval()
  47. return tokenizer, model
# Load the tokenizer and model once when this module is imported; the
# chainlit message handler below reads them as module-level globals.
tokenizer, model = get_model_and_tokenizer()
  50. @chainlit.on_message
  51. async def main(message: chainlit.Message):
  52. """
  53. The main Chainlit function.
  54. """
  55. user_input = message.content
  56. label = classify_review(user_input, model, tokenizer, device, max_length=120)
  57. await chainlit.Message(
  58. content=f"{label}", # This returns the model response to the interface
  59. ).send()