浏览代码

add HF equivalency tests for standalone nbs (#774)

* add HF equivalency tests for standalone nbs

* update

* update

* update

* update
Sebastian Raschka 3 月之前
父节点
当前提交
80d4732456

+ 4 - 2
.github/workflows/basic-tests-linux-uv.yml

@@ -51,8 +51,10 @@ jobs:
           pytest --ruff ch04/01_main-chapter-code/tests.py
           pytest --ruff ch04/03_kv-cache/tests.py
           pytest --ruff ch05/01_main-chapter-code/tests.py
-          pytest --ruff ch05/07_gpt_to_llama/tests/tests.py
-          pytest --ruff ch05/12_gemma3/tests/test_gemma3.py
+          pytest --ruff ch05/07_gpt_to_llama/tests/tests_rope_and_parts.py
+          pytest --ruff ch05/07_gpt_to_llama/tests/test_llama32_nb.py
+          pytest --ruff ch05/11_qwen3/tests/test_qwen3_nb.py
+          pytest --ruff ch05/12_gemma3/tests/test_gemma3_nb.py
           pytest --ruff ch06/01_main-chapter-code/tests.py
 
       - name: Validate Selected Jupyter Notebooks (uv)

+ 4 - 2
.github/workflows/basic-tests-macos-uv.yml

@@ -50,8 +50,10 @@ jobs:
           pytest --ruff setup/02_installing-python-libraries/tests.py
           pytest --ruff ch04/01_main-chapter-code/tests.py
           pytest --ruff ch05/01_main-chapter-code/tests.py
-          pytest --ruff ch05/07_gpt_to_llama/tests/tests.py
-          pytest --ruff ch05/12_gemma3/tests/test_gemma3.py
+          pytest --ruff ch05/07_gpt_to_llama/tests/tests_rope_and_parts.py
+          pytest --ruff ch05/07_gpt_to_llama/tests/test_llama32_nb.py
+          pytest --ruff ch05/11_qwen3/tests/test_qwen3_nb.py
+          pytest --ruff ch05/12_gemma3/tests/test_gemma3_nb.py
           pytest --ruff ch06/01_main-chapter-code/tests.py
 
       - name: Validate Selected Jupyter Notebooks (uv)

+ 0 - 1
.github/workflows/basic-tests-old-pytorch.yml

@@ -47,7 +47,6 @@ jobs:
         pytest --ruff setup/02_installing-python-libraries/tests.py
         pytest --ruff ch04/01_main-chapter-code/tests.py
         pytest --ruff ch05/01_main-chapter-code/tests.py
-        pytest --ruff ch05/07_gpt_to_llama/tests/tests.py
         pytest --ruff ch06/01_main-chapter-code/tests.py
 
     - name: Validate Selected Jupyter Notebooks

+ 0 - 2
.github/workflows/basic-tests-pip.yml

@@ -41,7 +41,6 @@ jobs:
           source .venv/bin/activate
           pip install --upgrade pip
           pip install -r requirements.txt
-          pip install -r ch05/07_gpt_to_llama/tests/test-requirements-extra.txt
           pip install pytest pytest-ruff nbval
 
       - name: Test Selected Python Scripts
@@ -50,7 +49,6 @@ jobs:
           pytest --ruff setup/02_installing-python-libraries/tests.py
           pytest --ruff ch04/01_main-chapter-code/tests.py
           pytest --ruff ch05/01_main-chapter-code/tests.py
-          pytest --ruff ch05/07_gpt_to_llama/tests/tests.py
           pytest --ruff ch06/01_main-chapter-code/tests.py
 
       - name: Validate Selected Jupyter Notebooks

+ 0 - 1
.github/workflows/basic-tests-pixi.yml

@@ -50,7 +50,6 @@ jobs:
           pytest --ruff setup/02_installing-python-libraries/tests.py
           pytest --ruff ch04/01_main-chapter-code/tests.py
           pytest --ruff ch05/01_main-chapter-code/tests.py
-          pytest --ruff ch05/07_gpt_to_llama/tests/tests.py
           pytest --ruff ch06/01_main-chapter-code/tests.py
 
       - name: Validate Selected Jupyter Notebooks

+ 0 - 2
.github/workflows/basic-tests-pytorch-rc.yml

@@ -33,7 +33,6 @@ jobs:
       run: |
         curl -LsSf https://astral.sh/uv/install.sh | sh
         uv sync --dev --python=3.10  # tests for backwards compatibility
-        uv pip install -r ch05/07_gpt_to_llama/tests/test-requirements-extra.txt
         uv add pytest-ruff nbval
         uv pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cpu
 
@@ -43,7 +42,6 @@ jobs:
         pytest --ruff setup/02_installing-python-libraries/tests.py
         pytest --ruff ch04/01_main-chapter-code/tests.py
         pytest --ruff ch05/01_main-chapter-code/tests.py
-        pytest --ruff ch05/07_gpt_to_llama/tests/tests.py
         pytest --ruff ch06/01_main-chapter-code/tests.py
 
     - name: Validate Selected Jupyter Notebooks

+ 4 - 1
.github/workflows/basic-tests-windows-uv-pip.yml

@@ -43,6 +43,7 @@ jobs:
           pip install tensorflow-io-gcs-filesystem==0.31.0  # Explicit for Windows
           pip install -r ch05/07_gpt_to_llama/tests/test-requirements-extra.txt
           pip install pytest-ruff nbval
+          pip install -e .
 
       - name: Run Python Tests
         shell: bash
@@ -51,7 +52,9 @@ jobs:
           pytest --ruff setup/02_installing-python-libraries/tests.py
           pytest --ruff ch04/01_main-chapter-code/tests.py
           pytest --ruff ch05/01_main-chapter-code/tests.py
-          pytest --ruff ch05/07_gpt_to_llama/tests/tests.py
+          pytest --ruff ch05/07_gpt_to_llama/tests/tests_rope_and_parts.py
+          pytest --ruff ch05/07_gpt_to_llama/tests/test_llama32_nb.py
+          pytest --ruff ch05/11_qwen3/tests/test_qwen3_nb.py
           pytest --ruff ch06/01_main-chapter-code/tests.py
 
       - name: Run Jupyter Notebook Tests

+ 0 - 1
.github/workflows/basic-tests-windows-uv.yml.disabled

@@ -51,7 +51,6 @@ jobs:
           pytest --ruff setup/02_installing-python-libraries/tests.py
           pytest --ruff ch04/01_main-chapter-code/tests.py
           pytest --ruff ch05/01_main-chapter-code/tests.py
-          pytest --ruff ch05/07_gpt_to_llama/tests/tests.py
           pytest --ruff ch06/01_main-chapter-code/tests.py
 
       - name: Run Jupyter Notebook Tests

+ 116 - 0
ch05/07_gpt_to_llama/tests/test_llama32_nb.py

@@ -0,0 +1,116 @@
+# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
+# Source for "Build a Large Language Model From Scratch"
+#   - https://www.manning.com/books/build-a-large-language-model-from-scratch
+# Code: https://github.com/rasbt/LLMs-from-scratch
+
+import importlib
+from pathlib import Path
+
+import pytest
+import torch
+
+from llms_from_scratch.utils import import_definitions_from_notebook
+
+
+transformers_installed = importlib.util.find_spec("transformers") is not None
+
+
+@pytest.fixture
+def nb_imports():
+    nb_dir = Path(__file__).resolve().parents[1]
+    mod = import_definitions_from_notebook(nb_dir, "standalone-llama32.ipynb")
+    return mod
+
+
+@pytest.fixture
+def dummy_input():
+    torch.manual_seed(123)
+    return torch.randint(0, 100, (1, 8))  # batch size 1, seq length 8
+
+
+@pytest.fixture
+def dummy_cfg_base():
+    return {
+        "vocab_size": 100,
+        "emb_dim": 32,            # hidden_size
+        "hidden_dim": 64,         # intermediate_size (FFN)
+        "n_layers": 2,
+        "n_heads": 4,
+        "head_dim": 8,
+        "n_kv_groups": 1,
+        "dtype": torch.float32,
+        "rope_base": 500_000.0,
+        "rope_freq": {
+            "factor": 8.0,
+            "low_freq_factor": 1.0,
+            "high_freq_factor": 4.0,
+            "original_context_length": 8192,
+        },
+        "context_length": 64,
+    }
+
+
+@torch.inference_mode()
+def test_dummy_llama3_forward(dummy_cfg_base, dummy_input, nb_imports):
+    torch.manual_seed(123)
+    model = nb_imports.Llama3Model(dummy_cfg_base)
+    out = model(dummy_input)
+    assert out.shape == (1, dummy_input.size(1), dummy_cfg_base["vocab_size"])
+
+
+@torch.inference_mode()
+@pytest.mark.skipif(not transformers_installed, reason="transformers not installed")
+def test_llama3_base_equivalence_with_transformers(nb_imports):
+    from transformers.models.llama import LlamaConfig, LlamaForCausalLM
+    cfg = {
+        "vocab_size": 257,
+        "context_length": 8192,
+        "emb_dim": 32,
+        "n_heads": 4,
+        "n_layers": 2,
+        "hidden_dim": 64,
+        "n_kv_groups": 2,
+        "rope_base": 500_000.0,
+        "rope_freq": {
+            "factor": 32.0,
+            "low_freq_factor": 1.0,
+            "high_freq_factor": 4.0,
+            "original_context_length": 8192,
+        },
+        "dtype": torch.float32,
+    }
+
+    ours = nb_imports.Llama3Model(cfg)
+
+    hf_cfg = LlamaConfig(
+        vocab_size=cfg["vocab_size"],
+        hidden_size=cfg["emb_dim"],
+        num_attention_heads=cfg["n_heads"],
+        num_key_value_heads=cfg["n_kv_groups"],
+        num_hidden_layers=cfg["n_layers"],
+        intermediate_size=cfg["hidden_dim"],
+        max_position_embeddings=cfg["context_length"],
+        rms_norm_eps=1e-5,
+        attention_bias=False,
+        rope_theta=cfg["rope_base"],
+        tie_word_embeddings=False,
+        attn_implementation="eager",
+        torch_dtype=torch.float32,
+        rope_scaling={
+            "type": "llama3",
+            "factor": cfg["rope_freq"]["factor"],
+            "low_freq_factor": cfg["rope_freq"]["low_freq_factor"],
+            "high_freq_factor": cfg["rope_freq"]["high_freq_factor"],
+            "original_max_position_embeddings": cfg["rope_freq"]["original_context_length"],
+        },
+    )
+    theirs = LlamaForCausalLM(hf_cfg)
+
+    hf_state = theirs.state_dict()
+    nb_imports.load_weights_into_llama(ours, {"n_layers": cfg["n_layers"], "hidden_dim": cfg["hidden_dim"]}, hf_state)
+
+    x = torch.randint(0, cfg["vocab_size"], (2, 8), dtype=torch.long)
+    ours_logits = ours(x)
+    theirs_logits = theirs(x).logits.to(ours_logits.dtype)
+
+    torch.testing.assert_close(ours_logits, theirs_logits, rtol=1e-5, atol=1e-5)

+ 0 - 0
ch05/07_gpt_to_llama/tests/tests.py → ch05/07_gpt_to_llama/tests/tests_rope_and_parts.py


+ 122 - 0
ch05/11_qwen3/tests/test_qwen3_nb.py

@@ -0,0 +1,122 @@
+# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
+# Source for "Build a Large Language Model From Scratch"
+#   - https://www.manning.com/books/build-a-large-language-model-from-scratch
+# Code: https://github.com/rasbt/LLMs-from-scratch
+
+import importlib
+from pathlib import Path
+
+import pytest
+import torch
+
+from llms_from_scratch.utils import import_definitions_from_notebook
+
+
+transformers_installed = importlib.util.find_spec("transformers") is not None
+
+
+@pytest.fixture
+def nb_imports():
+    nb_dir = Path(__file__).resolve().parents[1]
+    mod = import_definitions_from_notebook(nb_dir, "standalone-qwen3.ipynb")
+    return mod
+
+
+@pytest.fixture
+def dummy_input():
+    torch.manual_seed(123)
+    return torch.randint(0, 100, (1, 8))  # batch size 1, seq length 8
+
+
+@pytest.fixture
+def dummy_cfg_base():
+    return {
+        "vocab_size": 100,
+        "emb_dim": 32,
+        "hidden_dim": 64,
+        "n_layers": 2,
+        "n_heads": 4,
+        "head_dim": 8,
+        "n_kv_groups": 1,
+        "qk_norm": False,
+        "dtype": torch.float32,
+        "rope_base": 10000,
+        "context_length": 64,
+        "num_experts": 0,
+    }
+
+
+@pytest.fixture
+def dummy_cfg_moe(dummy_cfg_base):
+    cfg = dummy_cfg_base.copy()
+    cfg.update({
+        "num_experts": 4,
+        "num_experts_per_tok": 2,
+        "moe_intermediate_size": 64,
+    })
+    return cfg
+
+
+@torch.inference_mode()
+def test_dummy_qwen3_forward(dummy_cfg_base, dummy_input, nb_imports):
+    torch.manual_seed(123)
+    model = nb_imports.Qwen3Model(dummy_cfg_base)
+    out = model(dummy_input)
+    assert out.shape == (1, dummy_input.size(1), dummy_cfg_base["vocab_size"]), \
+        f"Expected shape (1, seq_len, vocab_size), got {out.shape}"
+
+
+@torch.inference_mode()
+@pytest.mark.skipif(not transformers_installed, reason="transformers not installed")
+def test_qwen3_base_equivalence_with_transformers(nb_imports):
+    from transformers import Qwen3Config, Qwen3ForCausalLM
+
+    # Tiny config so the test is fast
+    cfg = {
+        "vocab_size": 257,
+        "context_length": 8,
+        "emb_dim": 32,
+        "n_heads": 4,
+        "n_layers": 2,
+        "hidden_dim": 64,
+        "head_dim": 8,
+        "qk_norm": True,
+        "n_kv_groups": 2,
+        "rope_base": 1_000_000.0,
+        "rope_local_base": 10_000.0,
+        "sliding_window": 4,
+        "layer_types": ["full_attention", "full_attention"],
+        "dtype": torch.float32,
+        "query_pre_attn_scalar": 256,
+    }
+    model = nb_imports.Qwen3Model(cfg)
+
+    hf_cfg = Qwen3Config(
+        vocab_size=cfg["vocab_size"],
+        max_position_embeddings=cfg["context_length"],
+        hidden_size=cfg["emb_dim"],
+        num_attention_heads=cfg["n_heads"],
+        num_hidden_layers=cfg["n_layers"],
+        intermediate_size=cfg["hidden_dim"],
+        head_dim=cfg["head_dim"],
+        num_key_value_heads=cfg["n_kv_groups"],
+        rope_theta=cfg["rope_base"],
+        rope_local_base_freq=cfg["rope_local_base"],
+        layer_types=cfg["layer_types"],
+        sliding_window=cfg["sliding_window"],
+        tie_word_embeddings=False,
+        attn_implementation="eager",
+        torch_dtype=torch.float32,
+        query_pre_attn_scalar=cfg["query_pre_attn_scalar"],
+        rope_scaling={"rope_type": "default"},
+    )
+    hf_model = Qwen3ForCausalLM(hf_cfg)
+
+    hf_state = hf_model.state_dict()
+    param_config = {"n_layers": cfg["n_layers"], "hidden_dim": cfg["hidden_dim"]}
+    nb_imports.load_weights_into_qwen(model, param_config, hf_state)
+
+    x = torch.randint(0, cfg["vocab_size"], (2, cfg["context_length"]), dtype=torch.long)
+    ours_logits = model(x)
+    theirs_logits = hf_model(x).logits
+    torch.testing.assert_close(ours_logits, theirs_logits, rtol=1e-5, atol=1e-5)

+ 13 - 78
ch05/12_gemma3/tests/test_gemma3.py → ch05/12_gemma3/tests/test_gemma3_nb.py

@@ -4,77 +4,21 @@
 # Code: https://github.com/rasbt/LLMs-from-scratch
 
 import importlib
-import types
-import re
 from pathlib import Path
 
-import nbformat
 import pytest
 import torch
 
+from llms_from_scratch.utils import import_definitions_from_notebook
+
+
 transformers_installed = importlib.util.find_spec("transformers") is not None
 
 
-def _extract_defs_and_classes_from_code(src):
-    lines = src.splitlines()
-    kept = []
-    i = 0
-    while i < len(lines):
-        line = lines[i]
-        stripped = line.lstrip()
-        # Keep decorators attached to the next def/class
-        if stripped.startswith("@"):
-            # Look ahead: if the next non-empty line starts with def/class, keep decorator
-            j = i + 1
-            while j < len(lines) and not lines[j].strip():
-                j += 1
-            if j < len(lines) and lines[j].lstrip().startswith(("def ", "class ")):
-                kept.append(line)
-                i += 1
-                continue
-        if stripped.startswith("def ") or stripped.startswith("class "):
-            kept.append(line)
-            # capture until we leave the indentation block
-            base_indent = len(line) - len(stripped)
-            i += 1
-            while i < len(lines):
-                nxt = lines[i]
-                if nxt.strip() == "":
-                    kept.append(nxt)
-                    i += 1
-                    continue
-                indent = len(nxt) - len(nxt.lstrip())
-                if indent <= base_indent and not nxt.lstrip().startswith(("#", "@")):
-                    break
-                kept.append(nxt)
-                i += 1
-            continue
-        i += 1
-    code = "\n".join(kept)
-    code = re.sub(r"def\s+load_weights_into_gemma\s*\(\s*Gemma3Model\s*,",
-                  "def load_weights_into_gemma(model,",
-                  code)
-    return code
-
-
-def import_definitions_from_notebook(nb_dir_or_path, notebook_name):
-    nb_path = Path(nb_dir_or_path)
-    if nb_path.is_dir():
-        nb_file = nb_path / notebook_name
-    else:
-        nb_file = nb_path
-    if not nb_file.exists():
-        raise FileNotFoundError(f"Notebook not found: {nb_file}")
-
-    nb = nbformat.read(nb_file, as_version=4)
-    pieces = ["import torch", "import torch.nn as nn"]
-    for cell in nb.cells:
-        if cell.cell_type == "code":
-            pieces.append(_extract_defs_and_classes_from_code(cell.source))
-    src = "\n\n".join(pieces)
-
-    mod = types.ModuleType("gemma3_defs")
-    exec(src, mod.__dict__)
+@pytest.fixture
+def nb_imports():
+    nb_dir = Path(__file__).resolve().parents[1]
+    mod = import_definitions_from_notebook(nb_dir, "standalone-gemma3.ipynb")
     return mod
 
 
@@ -106,25 +50,16 @@ def dummy_cfg_base():
 
 
 @torch.inference_mode()
-def test_dummy_gemma3_forward(dummy_cfg_base, dummy_input):
-    nb_dir = Path(__file__).resolve().parents[1]
-    mod = import_definitions_from_notebook(nb_dir, "standalone-gemma3.ipynb")
-    Gemma3Model = mod.Gemma3Model
-
+def test_dummy_gemma3_forward(dummy_cfg_base, dummy_input, nb_imports):
     torch.manual_seed(123)
-    model = Gemma3Model(dummy_cfg_base)
+    model = nb_imports.Gemma3Model(dummy_cfg_base)
     out = model(dummy_input)
-    assert out.shape == (1, dummy_input.size(1), dummy_cfg_base["vocab_size"]),         f"Expected shape (1, seq_len, vocab_size), got {out.shape}"
+    assert out.shape == (1, dummy_input.size(1), dummy_cfg_base["vocab_size"])
 
 
 @torch.inference_mode()
 @pytest.mark.skipif(not transformers_installed, reason="transformers not installed")
-def test_gemma3_base_equivalence_with_transformers():
-    nb_dir = Path(__file__).resolve().parents[1]
-    mod = import_definitions_from_notebook(nb_dir, "standalone-gemma3.ipynb")
-    Gemma3Model = mod.Gemma3Model
-    load_weights_into_gemma = mod.load_weights_into_gemma
-
+def test_gemma3_base_equivalence_with_transformers(nb_imports):
     from transformers import Gemma3TextConfig, Gemma3ForCausalLM
 
     # Tiny config so the test is fast
@@ -145,7 +80,7 @@ def test_gemma3_base_equivalence_with_transformers():
         "dtype": torch.float32,
         "query_pre_attn_scalar": 256,
     }
-    model = Gemma3Model(cfg)
+    model = nb_imports.Gemma3Model(cfg)
 
     hf_cfg = Gemma3TextConfig(
         vocab_size=cfg["vocab_size"],
@@ -170,7 +105,7 @@ def test_gemma3_base_equivalence_with_transformers():
 
     hf_state = hf_model.state_dict()
     param_config = {"n_layers": cfg["n_layers"], "hidden_dim": cfg["hidden_dim"]}
-    load_weights_into_gemma(model, param_config, hf_state)
+    nb_imports.load_weights_into_gemma(model, param_config, hf_state)
 
     x = torch.randint(0, cfg["vocab_size"], (2, cfg["context_length"]), dtype=torch.long)
     ours_logits = model(x)

+ 1 - 1
pkg/llms_from_scratch/qwen3.py

@@ -116,7 +116,7 @@ QWEN3_CONFIG_30B_A3B = {
     "dtype": torch.bfloat16,
     "num_experts": 128,
     "num_experts_per_tok": 8,
-        "moe_intermediate_size": 768,
+    "moe_intermediate_size": 768,
 }
 
 

+ 124 - 0
pkg/llms_from_scratch/utils.py

@@ -0,0 +1,124 @@
+# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
+# Source for "Build a Large Language Model From Scratch"
+#   - https://www.manning.com/books/build-a-large-language-model-from-scratch
+# Code: https://github.com/rasbt/LLMs-from-scratch
+
+# Internal utility functions (not intended for public use)
+
+import ast
+import re
+import types
+from pathlib import Path
+
+import nbformat
+
+
+def _extract_imports(src: str):
+    out = []
+    try:
+        tree = ast.parse(src)
+    except SyntaxError:
+        return out
+    for node in tree.body:
+        if isinstance(node, ast.Import):
+            parts = []
+            for n in node.names:
+                parts.append(f"{n.name} as {n.asname}" if n.asname else n.name)
+            out.append("import " + ", ".join(parts))
+        elif isinstance(node, ast.ImportFrom):
+            module = node.module or ""
+            parts = []
+            for n in node.names:
+                parts.append(f"{n.name} as {n.asname}" if n.asname else n.name)
+            level = "." * node.level if getattr(node, "level", 0) else ""
+            out.append(f"from {level}{module} import " + ", ".join(parts))
+    return out
+
+
+def _extract_defs_and_classes_from_code(src):
+    lines = src.splitlines()
+    kept = []
+    i = 0
+    while i < len(lines):
+        line = lines[i]
+        stripped = line.lstrip()
+        if stripped.startswith("@"):
+            j = i + 1
+            while j < len(lines) and not lines[j].strip():
+                j += 1
+            if j < len(lines) and lines[j].lstrip().startswith(("def ", "class ")):
+                kept.append(line)
+                i += 1
+                continue
+        if stripped.startswith("def ") or stripped.startswith("class "):
+            kept.append(line)
+            base_indent = len(line) - len(stripped)
+            i += 1
+            while i < len(lines):
+                nxt = lines[i]
+                if nxt.strip() == "":
+                    kept.append(nxt)
+                    i += 1
+                    continue
+                indent = len(nxt) - len(nxt.lstrip())
+                if indent <= base_indent and not nxt.lstrip().startswith(("#", "@")):
+                    break
+                kept.append(nxt)
+                i += 1
+            continue
+        i += 1
+
+    code = "\n".join(kept)
+
+    # General rule:
+    # replace functions defined like `def load_weights_into_xxx(ClassName, ...`
+    # with `def load_weights_into_xxx(model, ...`
+    code = re.sub(
+        r"(def\s+load_weights_into_\w+\s*\()\s*\w+\s*,",
+        r"\1model,",
+        code
+    )
+    return code
+
+
+def import_definitions_from_notebook(nb_dir_or_path, notebook_name=None, *, extra_globals=None):
+    nb_path = Path(nb_dir_or_path)
+    if notebook_name is not None:
+        nb_file = nb_path / notebook_name if nb_path.is_dir() else nb_path
+    else:
+        nb_file = nb_path
+
+    if not nb_file.exists():
+        raise FileNotFoundError(f"Notebook not found: {nb_file}")
+
+    nb = nbformat.read(nb_file, as_version=4)
+
+    import_lines = []
+    seen = set()
+    for cell in nb.cells:
+        if cell.cell_type == "code":
+            for line in _extract_imports(cell.source):
+                if line not in seen:
+                    import_lines.append(line)
+                    seen.add(line)
+
+    for required in ("import torch", "import torch.nn as nn"):
+        if required not in seen:
+            import_lines.append(required)
+            seen.add(required)
+
+    pieces = []
+    for cell in nb.cells:
+        if cell.cell_type == "code":
+            pieces.append(_extract_defs_and_classes_from_code(cell.source))
+
+    src = "\n\n".join(import_lines + pieces)
+
+    mod_name = nb_file.stem.replace("-", "_").replace(" ", "_") or "notebook_defs"
+    mod = types.ModuleType(mod_name)
+
+    if extra_globals:
+        mod.__dict__.update(extra_globals)
+
+    exec(src, mod.__dict__)
+    return mod

+ 1 - 0
pyproject.toml

@@ -30,6 +30,7 @@ dev = [
     "llms-from-scratch",
     "twine>=6.1.0",
     "tokenizers>=0.21.1",
+    "safetensors>=0.6.2",
 ]
 
 [tool.ruff]