# python_environment_check.py (2.6 KB)
  1. # Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
  2. # Source for "Build a Large Language Model From Scratch"
  3. # - https://www.manning.com/books/build-a-large-language-model-from-scratch
  4. # Code: https://github.com/rasbt/LLMs-from-scratch
  5. import importlib
  6. from os.path import dirname, join, realpath
  7. from packaging.version import parse as version_parse
  8. import platform
  9. import sys
  10. if version_parse(platform.python_version()) < version_parse('3.9'):
  11. print('[FAIL] We recommend Python 3.9 or newer but'
  12. ' found version %s' % (sys.version))
  13. else:
  14. print('[OK] Your Python version is %s' % (platform.python_version()))
  15. def get_packages(pkgs):
  16. versions = []
  17. for p in pkgs:
  18. try:
  19. imported = importlib.import_module(p)
  20. try:
  21. version = (getattr(imported, '__version__', None) or
  22. getattr(imported, 'version', None) or
  23. getattr(imported, 'version_info', None))
  24. if version is None:
  25. # If common attributes don't exist, use importlib.metadata
  26. version = importlib.metadata.version(p)
  27. versions.append(version)
  28. except importlib.metadata.PackageNotFoundError:
  29. # Handle case where package is not installed
  30. versions.append('0.0')
  31. except ImportError:
  32. # Fallback if importlib.import_module fails for unexpected reasons
  33. versions.append('0.0')
  34. return versions
  35. def get_requirements_dict():
  36. PROJECT_ROOT = dirname(realpath(__file__))
  37. PROJECT_ROOT_UP_TWO = dirname(dirname(PROJECT_ROOT))
  38. REQUIREMENTS_FILE = join(PROJECT_ROOT_UP_TWO, "requirements.txt")
  39. d = {}
  40. with open(REQUIREMENTS_FILE) as f:
  41. for line in f:
  42. if not line.strip():
  43. continue
  44. line = line.split("#")[0].strip()
  45. line = line.split(" ")
  46. line = [l.strip() for l in line]
  47. d[line[0]] = line[-1]
  48. return d
  49. def check_packages(d):
  50. versions = get_packages(d.keys())
  51. for (pkg_name, suggested_ver), actual_ver in zip(d.items(), versions):
  52. if actual_ver == 'N/A':
  53. continue
  54. actual_ver, suggested_ver = version_parse(actual_ver), version_parse(suggested_ver)
  55. if actual_ver < suggested_ver:
  56. print(f'[FAIL] {pkg_name} {actual_ver}, please upgrade to >= {suggested_ver}')
  57. else:
  58. print(f'[OK] {pkg_name} {actual_ver}')
  59. def main():
  60. d = get_requirements_dict()
  61. check_packages(d)
  62. if __name__ == '__main__':
  63. main()