# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
# Source for "Build a Large Language Model From Scratch"
#   - https://www.manning.com/books/build-a-large-language-model-from-scratch
# Code: https://github.com/rasbt/LLMs-from-scratch
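
# To try this app locally, install Chainlit (e.g., `pip install chainlit`) and,
# assuming this file is saved as app.py, start the user interface with:
#   chainlit run app.py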

from pathlib import Path
import sys

import tiktoken
import torch
import chainlit

# For llms_from_scratch installation instructions, see:
# https://github.com/rasbt/LLMs-from-scratch/tree/main/pkg
from llms_from_scratch.ch04 import GPTModel
from llms_from_scratch.ch05 import (
    generate,
    text_to_token_ids,
    token_ids_to_text,
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def get_model_and_tokenizer():
    """
    Loads the GPT-2 model with the finetuned weights generated in chapter 7.
    This requires running the chapter 7 code first, which produces the
    necessary gpt2-medium355M-sft.pth file.
    """

    GPT_CONFIG_355M = {
        "vocab_size": 50257,     # Vocabulary size
        "context_length": 1024,  # Context length
        "emb_dim": 1024,         # Embedding dimension
        "n_heads": 16,           # Number of attention heads
        "n_layers": 24,          # Number of layers
        "drop_rate": 0.0,        # Dropout rate
        "qkv_bias": True         # Query-key-value bias
    }
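
    # The settings above correspond to the GPT-2 medium (355M-parameter)
    # configuration used for instruction finetuning in chapter 7.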

    tokenizer = tiktoken.get_encoding("gpt2")

    model_path = Path("..") / "01_main-chapter-code" / "gpt2-medium355M-sft.pth"
    if not model_path.exists():
        print(
            f"Could not find the {model_path} file. Please run the chapter 7 code "
            "(ch07.ipynb) to generate the gpt2-medium355M-sft.pth file."
        )
        sys.exit()
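
    # Note: `weights_only=True` restricts loading to plain tensor data rather than
    # arbitrary pickled objects; this assumes a PyTorch version that supports the
    # argument (introduced in PyTorch 1.13).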
    checkpoint = torch.load(model_path, weights_only=True)
    model = GPTModel(GPT_CONFIG_355M)
    model.load_state_dict(checkpoint)
    model.to(device)

    return tokenizer, model, GPT_CONFIG_355M


def extract_response(response_text, input_text):
    return response_text[len(input_text):].replace("### Response:", "").strip()
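
# Example (hypothetical strings): if the model's full output were
#   "<prompt text>### Response:\nThe capital of France is Paris."
# extract_response() would strip the echoed prompt and the "### Response:"
# marker and return "The capital of France is Paris."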


# Load the tokenizer and model used by the Chainlit function below
tokenizer, model, model_config = get_model_and_tokenizer()


@chainlit.on_message
async def main(message: chainlit.Message):
    """
    The main Chainlit function.
    """
    torch.manual_seed(123)
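
    # The prompt below follows the Alpaca-style instruction format the model was
    # finetuned on in chapter 7; the user's chat message fills the
    # "### Instruction" field.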
    prompt = f"""Below is an instruction that describes a task. Write a response
that appropriately completes the request.

### Instruction:
{message.content}
"""
    token_ids = generate(  # function uses `with torch.no_grad()` internally already
        model=model,
        idx=text_to_token_ids(prompt, tokenizer).to(device),  # The user text is provided via `message.content`
        max_new_tokens=35,
        context_size=model_config["context_length"],
        eos_id=50256
    )

    text = token_ids_to_text(token_ids, tokenizer)
    response = extract_response(text, prompt)

    await chainlit.Message(
        content=f"{response}",  # This returns the model response to the interface
    ).send()