Quellcode durchsuchen

Add missing device transfer in gpt_generate.py (#436)

Sebastian Raschka vor 1 Jahr
Ursprung
Commit
f61c008c5d
1 geänderte Dateien mit 1 neuen und 1 gelöschten Zeilen
  1. 1 1
      ch05/01_main-chapter-code/gpt_generate.py

+ 1 - 1
ch05/01_main-chapter-code/gpt_generate.py

@@ -270,7 +270,7 @@ def main(gpt_config, input_prompt, model_size):
 
     token_ids = generate(
         model=gpt,
-        idx=text_to_token_ids(input_prompt, tokenizer),
+        idx=text_to_token_ids(input_prompt, tokenizer).to(device),
         max_new_tokens=25,
         context_size=gpt_config["context_length"],
         top_k=50,