
Chainlit bonus material fixes (#361)

* fix cmd

* moved idx to device

* improved code with clone().detach()

* fixed path

* fix: added extra line for pep8

* updated .gitignore

* Update ch05/06_user_interface/app_orig.py

* Update ch05/06_user_interface/app_own.py

* Apply suggestions from code review

---------

Co-authored-by: Sebastian Raschka <mail@sebastianraschka.com>
Daniel Kleine, 1 year ago
parent
commit eefe4bf12b
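One bullet above mentions improving the code with `clone().detach()`; that particular change is not part of the hunks shown below. For context, the PyTorch idiom it refers to is the recommended way to copy an existing tensor: calling `torch.tensor(t)` on a tensor `t` emits a UserWarning suggesting `t.clone().detach()` instead. A minimal sketch of the idiom (the variable names are illustrative, not taken from this repo):

```python
import torch

t = torch.tensor([1.0, 2.0, 3.0])

# Discouraged: torch.tensor(t) copy-constructs from a tensor and emits
# "UserWarning: To copy construct from a tensor, it is recommended to
# use sourceTensor.clone().detach() ...".
# copy_warn = torch.tensor(t)

# Recommended: clone() copies the data, and detach() removes the copy
# from the autograd graph.
copy_ok = t.clone().detach()
```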

+ 1 - 0
.gitignore

@@ -92,6 +92,7 @@ ch07/04_preference-tuning-with-dpo/loss-plot.pdf
 # Other
 ch05/06_user_interface/chainlit.md
 ch05/06_user_interface/.chainlit
+ch05/06_user_interface/.files
 
 # Temporary OS-related files
 .DS_Store

+ 1 - 1
ch05/06_user_interface/README.md

@@ -17,7 +17,7 @@ To implement this user interface, we use the open-source [Chainlit Python packag
 
 First, we install the `chainlit` package via
 
-```python
+```bash
 pip install chainlit
 ```
 

+ 4 - 4
ch05/06_user_interface/app_orig.py

@@ -16,6 +16,8 @@ from previous_chapters import (
     token_ids_to_text,
 )
 
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
 
 def get_model_and_tokenizer():
     """
@@ -44,8 +46,6 @@ def get_model_and_tokenizer():
 
     BASE_CONFIG.update(model_configs[CHOOSE_MODEL])
 
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
     settings, params = download_and_load_gpt2(model_size=model_size, models_dir="gpt2")
 
     gpt = GPTModel(BASE_CONFIG)
@@ -67,9 +67,9 @@ async def main(message: chainlit.Message):
     """
     The main Chainlit function.
     """
-    token_ids = generate(
+    token_ids = generate(  # function uses `with torch.no_grad()` internally already
         model=model,
-        idx=text_to_token_ids(message.content, tokenizer),  # The user text is provided via as `message.content`
+        idx=text_to_token_ids(message.content, tokenizer).to(device),  # The user text is provided via `message.content`
         max_new_tokens=50,
         context_size=model_config["context_length"],
         top_k=1,
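
Taken together, the two hunks above move `device` to module scope, where both `get_model_and_tokenizer` and the Chainlit message handler can see it, and send the tokenized user input to that device before calling `generate`. The model weights live on `device`, so the input IDs must be there too; a stripped-down sketch of that constraint (the `Linear` layer merely stands in for the GPT model):

```python
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = torch.nn.Linear(4, 2).to(device)  # stand-in for the GPT model
idx = torch.rand(1, 4)                    # stand-in for the token IDs (CPU tensor)

# On a CUDA machine, model(idx) would raise:
# "RuntimeError: Expected all tensors to be on the same device ..."
out = model(idx.to(device))  # moving the input to the model's device works
```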

+ 5 - 5
ch05/06_user_interface/app_own.py

@@ -17,6 +17,8 @@ from previous_chapters import (
     token_ids_to_text,
 )
 
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
 
 def get_model_and_tokenizer():
     """
@@ -34,8 +36,6 @@ def get_model_and_tokenizer():
         "qkv_bias": False       # Query-key-value bias
     }
 
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
     tokenizer = tiktoken.get_encoding("gpt2")
 
     model_path = Path("..") / "01_main-chapter-code" / "model.pth"
@@ -43,7 +43,7 @@ def get_model_and_tokenizer():
         print(f"Could not find the {model_path} file. Please run the chapter 5 code (ch05.ipynb) to generate the model.pth file.")
         sys.exit()
 
-    checkpoint = torch.load("model.pth", weights_only=True)
+    checkpoint = torch.load(model_path, weights_only=True)
     model = GPTModel(GPT_CONFIG_124M)
     model.load_state_dict(checkpoint)
     model.to(device)
@@ -60,9 +60,9 @@ async def main(message: chainlit.Message):
     """
     The main Chainlit function.
     """
-    token_ids = generate(
+    token_ids = generate(  # function uses `with torch.no_grad()` internally already
         model=model,
-        idx=text_to_token_ids(message.content, tokenizer),  # The user text is provided via as `message.content`
+        idx=text_to_token_ids(message.content, tokenizer).to(device),  # The user text is provided via `message.content`
         max_new_tokens=50,
         context_size=model_config["context_length"],
         top_k=1,
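
The `torch.load` change fixes a path mismatch: the existence check used `model_path` (pointing one directory up), but the load used the literal string `"model.pth"`, which is resolved relative to the current working directory and fails whenever the checkpoint exists only at `model_path`. A minimal sketch of the corrected check-then-load pattern, as it appears in the diff:

```python
import sys
from pathlib import Path

import torch

model_path = Path("..") / "01_main-chapter-code" / "model.pth"
if not model_path.exists():
    print(f"Could not find the {model_path} file.")
    sys.exit()

# Load the same path that was checked; weights_only=True restricts
# unpickling to tensors and primitive containers, which is safer than
# a full pickle load.
checkpoint = torch.load(model_path, weights_only=True)
```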