@@ -17,6 +17,8 @@ from previous_chapters import (
     token_ids_to_text,
 )
 
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
 
 def get_model_and_tokenizer():
     """
@@ -34,8 +36,6 @@ def get_model_and_tokenizer():
         "qkv_bias": False  # Query-key-value bias
     }
 
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
     tokenizer = tiktoken.get_encoding("gpt2")
 
     model_path = Path("..") / "01_main-chapter-code" / "model.pth"
@@ -43,7 +43,7 @@ def get_model_and_tokenizer():
         print(f"Could not find the {model_path} file. Please run the chapter 5 code (ch05.ipynb) to generate the model.pth file.")
         sys.exit()
 
-    checkpoint = torch.load("model.pth", weights_only=True)
+    checkpoint = torch.load(model_path, weights_only=True)
     model = GPTModel(GPT_CONFIG_124M)
     model.load_state_dict(checkpoint)
     model.to(device)
@@ -60,9 +60,9 @@ async def main(message: chainlit.Message):
     """
     The main Chainlit function.
     """
-    token_ids = generate(
+    token_ids = generate(  # function uses `with torch.no_grad()` internally already
         model=model,
-        idx=text_to_token_ids(message.content, tokenizer),  # The user text is provided via `message.content`
+        idx=text_to_token_ids(message.content, tokenizer).to(device),  # The user text is provided via `message.content`
         max_new_tokens=50,
         context_size=model_config["context_length"],
         top_k=1,