  1. # Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
  2. # Source for "Build a Large Language Model From Scratch"
  3. # - https://www.manning.com/books/build-a-large-language-model-from-scratch
  4. # Code: https://github.com/rasbt/LLMs-from-scratch
  5. # Internal utility functions (not intended for public use)
  6. import ast
  7. import re
  8. import types
  9. from pathlib import Path
  10. import nbformat
  11. def _extract_imports(src: str):
  12. out = []
  13. try:
  14. tree = ast.parse(src)
  15. except SyntaxError:
  16. return out
  17. for node in tree.body:
  18. if isinstance(node, ast.Import):
  19. parts = []
  20. for n in node.names:
  21. parts.append(f"{n.name} as {n.asname}" if n.asname else n.name)
  22. out.append("import " + ", ".join(parts))
  23. elif isinstance(node, ast.ImportFrom):
  24. module = node.module or ""
  25. parts = []
  26. for n in node.names:
  27. parts.append(f"{n.name} as {n.asname}" if n.asname else n.name)
  28. level = "." * node.level if getattr(node, "level", 0) else ""
  29. out.append(f"from {level}{module} import " + ", ".join(parts))
  30. return out
  31. def _extract_defs_and_classes_from_code(src):
  32. lines = src.splitlines()
  33. kept = []
  34. i = 0
  35. while i < len(lines):
  36. line = lines[i]
  37. stripped = line.lstrip()
  38. if stripped.startswith("@"):
  39. j = i + 1
  40. while j < len(lines) and not lines[j].strip():
  41. j += 1
  42. if j < len(lines) and lines[j].lstrip().startswith(("def ", "class ")):
  43. kept.append(line)
  44. i += 1
  45. continue
  46. if stripped.startswith("def ") or stripped.startswith("class "):
  47. kept.append(line)
  48. base_indent = len(line) - len(stripped)
  49. i += 1
  50. while i < len(lines):
  51. nxt = lines[i]
  52. if nxt.strip() == "":
  53. kept.append(nxt)
  54. i += 1
  55. continue
  56. indent = len(nxt) - len(nxt.lstrip())
  57. if indent <= base_indent and not nxt.lstrip().startswith(("#", "@")):
  58. break
  59. kept.append(nxt)
  60. i += 1
  61. continue
  62. i += 1
  63. code = "\n".join(kept)
  64. # General rule:
  65. # replace functions defined like `def load_weights_into_xxx(ClassName, ...`
  66. # with `def load_weights_into_xxx(model, ...`
  67. code = re.sub(
  68. r"(def\s+load_weights_into_\w+\s*\()\s*\w+\s*,",
  69. r"\1model,",
  70. code
  71. )
  72. return code
  73. def import_definitions_from_notebook(nb_dir_or_path, notebook_name=None, *, extra_globals=None):
  74. nb_path = Path(nb_dir_or_path)
  75. if notebook_name is not None:
  76. nb_file = nb_path / notebook_name if nb_path.is_dir() else nb_path
  77. else:
  78. nb_file = nb_path
  79. if not nb_file.exists():
  80. raise FileNotFoundError(f"Notebook not found: {nb_file}")
  81. nb = nbformat.read(nb_file, as_version=4)
  82. import_lines = []
  83. seen = set()
  84. for cell in nb.cells:
  85. if cell.cell_type == "code":
  86. for line in _extract_imports(cell.source):
  87. if line not in seen:
  88. import_lines.append(line)
  89. seen.add(line)
  90. for required in ("import torch", "import torch.nn as nn"):
  91. if required not in seen:
  92. import_lines.append(required)
  93. seen.add(required)
  94. pieces = []
  95. for cell in nb.cells:
  96. if cell.cell_type == "code":
  97. pieces.append(_extract_defs_and_classes_from_code(cell.source))
  98. src = "\n\n".join(import_lines + pieces)
  99. mod_name = nb_file.stem.replace("-", "_").replace(" ", "_") or "notebook_defs"
  100. mod = types.ModuleType(mod_name)
  101. if extra_globals:
  102. mod.__dict__.update(extra_globals)
  103. exec(src, mod.__dict__)
  104. return mod