Эх сурвалжийг харах

added pkg fixes (#676)

Co-authored-by: Sebastian Raschka <mail@sebastianraschka.com>
Daniel Kleine 5 сар өмнө
parent
commit
14c054d36c

+ 1 - 1
pkg/llms_from_scratch/ch07.py

@@ -9,7 +9,7 @@ import psutil
 import urllib
 
 
 import torch
-import tqdm
+from tqdm import tqdm
 from torch.utils.data import Dataset
 
 
 
 

+ 4 - 15
pkg/llms_from_scratch/llama3.py

@@ -309,22 +309,11 @@ class Llama3Tokenizer:
             special_tokens=self.special,
         )
 
 
-    def encode(self, text, bos=False, eos=False, allowed_special=set()):
-        ids: list[int] = []
-
-        if bos:
-            ids.append(self.special_tokens["<|begin_of_text|>"])
-
-        # delegate to underlying tiktoken.Encoding.encode
-        ids.extend(
-            self.model.encode(
-                text,
-                allowed_special=allowed_special,
-            )
-        )
+    def encode(self, text, bos=False, eos=False):
+        ids = ([self.special["<|begin_of_text|>"]] if bos else []) \
+              + self.model.encode(text)
         if eos:
-            ids.append(self.special_tokens["<|end_of_text|>"])
-
+            ids.append(self.special["<|end_of_text|>"])
         return ids
 
 
     def decode(self, ids):

+ 1 - 1
pkg/llms_from_scratch/tests/test_llama3.py

@@ -199,7 +199,7 @@ def test_gpt_model_variants(ModelClass, generate_fn, llama3_weights_path):
 
 
     torch.manual_seed(123)
     model = ModelClass(LLAMA32_CONFIG_1B)
-    model.load_state_dict(torch.load(llama3_weights_path))
+    model.load_state_dict(torch.load(llama3_weights_path, weights_only=True))
     model.eval()
 
 
     start_context = "Llamas eat"