python_environment_check.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128
  1. # Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
  2. # Source for "Build a Large Language Model From Scratch"
  3. # - https://www.manning.com/books/build-a-large-language-model-from-scratch
  4. # Code: https://github.com/rasbt/LLMs-from-scratch
  5. from importlib.metadata import PackageNotFoundError, import_module, version as get_version
  6. from os.path import dirname, exists, join, realpath
  7. from packaging.version import parse as version_parse
  8. from packaging.requirements import Requirement
  9. from packaging.specifiers import SpecifierSet
  10. import platform
  11. import sys
  12. if version_parse(platform.python_version()) < version_parse("3.9"):
  13. print("[FAIL] We recommend Python 3.9 or newer but found version %s" % sys.version)
  14. else:
  15. print("[OK] Your Python version is %s" % platform.python_version())
  16. def get_packages(pkgs):
  17. """
  18. Returns a dictionary mapping package names (in lowercase) to their installed version.
  19. """
  20. PACKAGE_MODULE_OVERRIDES = {
  21. "tensorflow-cpu": ["tensorflow", "tensorflow_cpu"],
  22. }
  23. result = {}
  24. for p in pkgs:
  25. # Determine possible module names to try.
  26. module_names = PACKAGE_MODULE_OVERRIDES.get(p.lower(), [p])
  27. version_found = None
  28. for module_name in module_names:
  29. try:
  30. imported = import_module(module_name)
  31. version_found = getattr(imported, "__version__", None)
  32. if version_found is None:
  33. try:
  34. version_found = get_version(module_name)
  35. except PackageNotFoundError:
  36. version_found = None
  37. if version_found is not None:
  38. break # Stop if we successfully got a version.
  39. except ImportError:
  40. # Also try replacing hyphens with underscores as a fallback.
  41. alt_module = module_name.replace("-", "_")
  42. if alt_module != module_name:
  43. try:
  44. imported = import_module(alt_module)
  45. version_found = getattr(imported, "__version__", None)
  46. if version_found is None:
  47. try:
  48. version_found = get_version(alt_module)
  49. except PackageNotFoundError:
  50. version_found = None
  51. if version_found is not None:
  52. break
  53. except ImportError:
  54. continue
  55. continue
  56. if version_found is None:
  57. version_found = "0.0"
  58. result[p.lower()] = version_found
  59. return result
  60. def get_requirements_dict():
  61. """
  62. Parses requirements.txt and returns a dictionary mapping package names (in lowercase)
  63. to specifier strings (e.g. ">=2.18.0,<3.0"). It uses the Requirement class from
  64. packaging.requirements to properly handle environment markers, and converts each object's
  65. specifier to a string.
  66. """
  67. PROJECT_ROOT = dirname(realpath(__file__))
  68. PROJECT_ROOT_UP_TWO = dirname(dirname(PROJECT_ROOT))
  69. REQUIREMENTS_FILE = join(PROJECT_ROOT_UP_TWO, "requirements.txt")
  70. if not exists(REQUIREMENTS_FILE):
  71. REQUIREMENTS_FILE = join(PROJECT_ROOT, "requirements.txt")
  72. reqs = {}
  73. with open(REQUIREMENTS_FILE) as f:
  74. for line in f:
  75. # Remove inline comments and trailing whitespace.
  76. # This splits on the first '#' and takes the part before it.
  77. line = line.split("#", 1)[0].strip()
  78. if not line:
  79. continue
  80. try:
  81. req = Requirement(line)
  82. except Exception as e:
  83. print(f"Skipping line due to parsing error: {line} ({e})")
  84. continue
  85. # Evaluate the marker if present.
  86. if req.marker is not None and not req.marker.evaluate():
  87. continue
  88. # Store the package name and its version specifier.
  89. spec = str(req.specifier) if req.specifier else ">=0"
  90. reqs[req.name.lower()] = spec
  91. return reqs
  92. def check_packages(reqs):
  93. """
  94. Checks the installed versions of packages against the requirements.
  95. """
  96. installed = get_packages(reqs.keys())
  97. for pkg_name, spec_str in reqs.items():
  98. spec_set = SpecifierSet(spec_str)
  99. actual_ver = installed.get(pkg_name, "0.0")
  100. if actual_ver == "N/A":
  101. continue
  102. actual_ver_parsed = version_parse(actual_ver)
  103. # If the installed version is a pre-release, allow pre-releases in the specifier.
  104. if actual_ver_parsed.is_prerelease:
  105. spec_set.prereleases = True
  106. if actual_ver_parsed not in spec_set:
  107. print(f"[FAIL] {pkg_name} {actual_ver_parsed}, please install a version matching {spec_set}")
  108. else:
  109. print(f"[OK] {pkg_name} {actual_ver_parsed}")
  110. def main():
  111. reqs = get_requirements_dict()
  112. check_packages(reqs)
  113. if __name__ == "__main__":
  114. main()