Преглед изворни кода

Add missing device transfer in gpt_generate.py (#436)

Sebastian Raschka пре 1 година
родитељ
комит
f61c008c5d
1 измењених фајлова са 1 додато и 1 уклоњено
  1. 1 1
      ch05/01_main-chapter-code/gpt_generate.py

+ 1 - 1
ch05/01_main-chapter-code/gpt_generate.py

@@ -270,7 +270,7 @@ def main(gpt_config, input_prompt, model_size):
 
     token_ids = generate(
         model=gpt,
-        idx=text_to_token_ids(input_prompt, tokenizer),
+        idx=text_to_token_ids(input_prompt, tokenizer).to(device),
         max_new_tokens=25,
         context_size=gpt_config["context_length"],
         top_k=50,