瀏覽代碼

Update dependency checker (#218)

* update environment checker

* update requirements.txt
Sebastian Raschka 1 年之前
父節點
當前提交
e3cd400e5f
共有 2 個文件被更改,包括 47 次插入25 次删除
  1. 5 5
      requirements.txt
  2. 42 20
      setup/02_installing-python-libraries/python_environment_check.py

+ 5 - 5
requirements.txt

@@ -1,10 +1,10 @@
 torch >= 2.0.1        # all
 jupyterlab >= 4.0     # all
-tiktoken >= 0.5.1     # ch02, ch04, ch05
-matplotlib >= 3.7.1   # ch04, ch05
+tiktoken >= 0.5.1     # ch02; ch04; ch05
+matplotlib >= 3.7.1   # ch04; ch05
 numpy >= 1.24.3       # ch05
 tensorflow >= 2.15.0  # ch05
-tqdm >= 4.66.1        # ch05, ch07
-numpy < 2.0           # dependency of several other libraries like torch and pandas
+tqdm >= 4.66.1        # ch05; ch07
+numpy >= 1.25, < 2.0  # dependency of several other libraries like torch and pandas
 pandas >= 2.2.1       # ch06
-psutil >= 5.9.5       # ch07, already installed automatically as dependency of torch
+psutil >= 5.9.5       # ch07; already installed automatically as dependency of torch

+ 42 - 20
setup/02_installing-python-libraries/python_environment_check.py

@@ -10,11 +10,11 @@ from packaging.version import parse as version_parse
 import platform
 import sys
 
-if version_parse(platform.python_version()) < version_parse('3.9'):
-    print('[FAIL] We recommend Python 3.9 or newer but'
-          ' found version %s' % (sys.version))
+if version_parse(platform.python_version()) < version_parse("3.9"):
+    print("[FAIL] We recommend Python 3.9 or newer but"
+          " found version %s" % (sys.version))
 else:
-    print('[OK] Your Python version is %s' % (platform.python_version()))
+    print("[OK] Your Python version is %s" % (platform.python_version()))
 
 
 def get_packages(pkgs):
@@ -23,19 +23,19 @@ def get_packages(pkgs):
         try:
             imported = import_module(p)
             try:
-                version = (getattr(imported, '__version__', None) or
-                           getattr(imported, 'version', None) or
-                           getattr(imported, 'version_info', None))
+                version = (getattr(imported, "__version__", None) or
+                           getattr(imported, "version", None) or
+                           getattr(imported, "version_info", None))
                 if version is None:
-                    # If common attributes don't exist, use importlib.metadata
+                    # If common attributes don"t exist, use importlib.metadata
                     version = importlib.metadata.version(p)
                 versions.append(version)
             except PackageNotFoundError:
                 # Handle case where package is not installed
-                versions.append('0.0')
+                versions.append("0.0")
         except ImportError:
             # Fallback if importlib.import_module fails for unexpected reasons
-            versions.append('0.0')
+            versions.append("0.0")
     return versions
 
 
@@ -48,10 +48,20 @@ def get_requirements_dict():
         for line in f:
             if not line.strip():
                 continue
-            line = line.split("#")[0].strip()
-            line = line.split(" ")
-            line = [l.strip() for l in line]
-            d[line[0]] = line[-1]
+            if "," in line:
+                left, right = line.split(",")
+                lower = right.split("#")[0].strip()
+                package, _, upper = left.split(" ")
+                package = package.strip()
+                _, lower = lower.split(" ")
+                lower = lower.strip()
+                upper = upper.strip()
+                d[package] = (upper, lower)
+            else:
+                line = line.split("#")[0].strip()
+                line = line.split(" ")
+                line = [ln.strip() for ln in line]
+                d[line[0]] = line[-1]
     return d
 
 
@@ -59,13 +69,25 @@ def check_packages(d):
     versions = get_packages(d.keys())
 
     for (pkg_name, suggested_ver), actual_ver in zip(d.items(), versions):
-        if actual_ver == 'N/A':
+        if isinstance(suggested_ver, tuple):
+            lower, upper = suggested_ver[0], suggested_ver[1]
+        else:
+            lower = suggested_ver
+            upper = None
+        if actual_ver == "N/A":
             continue
-        actual_ver, suggested_ver = version_parse(actual_ver), version_parse(suggested_ver)
-        if actual_ver < suggested_ver:
-            print(f'[FAIL] {pkg_name} {actual_ver}, please upgrade to >= {suggested_ver}')
+        actual_ver = version_parse(actual_ver)
+        lower = version_parse(lower)
+        if upper is not None:
+            upper = version_parse(upper)
+        if actual_ver < lower and upper is None:
+            print(f"[FAIL] {pkg_name} {actual_ver}, please upgrade to >= {lower}")
+        elif actual_ver < lower:
+            print(f"[FAIL] {pkg_name} {actual_ver}, please upgrade to >= {lower} and < {upper}")
+        elif upper is not None and actual_ver >= upper:
+            print(f"[FAIL] {pkg_name} {actual_ver}, please downgrade to >= {lower} and < {upper}")
         else:
-            print(f'[OK] {pkg_name} {actual_ver}')
+            print(f"[OK] {pkg_name} {actual_ver}")
 
 
 def main():
@@ -73,5 +95,5 @@ def main():
     check_packages(d)
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     main()