浏览代码

Add missing device transfer in gpt_generate.py (#436)

Sebastian Raschka 1 年之前
父节点
当前提交
f61c008c5d
共有 1 个文件被更改,包括 1 次插入1 次删除
  1. 1 1
      ch05/01_main-chapter-code/gpt_generate.py

+ 1 - 1
ch05/01_main-chapter-code/gpt_generate.py

@@ -270,7 +270,7 @@ def main(gpt_config, input_prompt, model_size):
 
     token_ids = generate(
         model=gpt,
-        idx=text_to_token_ids(input_prompt, tokenizer),
+        idx=text_to_token_ids(input_prompt, tokenizer).to(device),
         max_new_tokens=25,
         context_size=gpt_config["context_length"],
         top_k=50,