Explorar o código

Add missing device transfer in gpt_generate.py (#436)

Sebastian Raschka hai 1 ano
pai
achega
f61c008c5d
Modificáronse 1 ficheiros con 1 adicións e 1 borrados
  1. 1 1
      ch05/01_main-chapter-code/gpt_generate.py

+ 1 - 1
ch05/01_main-chapter-code/gpt_generate.py

@@ -270,7 +270,7 @@ def main(gpt_config, input_prompt, model_size):
 
     token_ids = generate(
         model=gpt,
-        idx=text_to_token_ids(input_prompt, tokenizer),
+        idx=text_to_token_ids(input_prompt, tokenizer).to(device),
         max_new_tokens=25,
         context_size=gpt_config["context_length"],
         top_k=50,