浏览代码

check gpt files (#208)

Sebastian Raschka 1 年之前
父节点
当前提交
bcccda728b
共有 1 个文件被更改,包括 58 次插入0 次删除
  1. 58 0
      ch05/01_main-chapter-code/tests.py

+ 58 - 0
ch05/01_main-chapter-code/tests.py

@@ -7,6 +7,8 @@
 
 import pytest
 from gpt_train import main
+import http.client
+from urllib.parse import urlparse
 
 
 @pytest.fixture
@@ -38,3 +40,59 @@ def test_main(gpt_config, other_settings):
     assert len(train_losses) == 39, "Unexpected number of training losses"
     assert len(val_losses) == 39, "Unexpected number of validation losses"
     assert len(tokens_seen) == 39, "Unexpected number of tokens seen"
+
+
+def check_file_size(url, expected_size):
+    parsed_url = urlparse(url)
+    if parsed_url.scheme == "https":
+        conn = http.client.HTTPSConnection(parsed_url.netloc)
+    else:
+        conn = http.client.HTTPConnection(parsed_url.netloc)
+
+    conn.request("HEAD", parsed_url.path)
+    response = conn.getresponse()
+    if response.status != 200:
+        return False, f"{url} not accessible"
+    size = response.getheader("Content-Length")
+    if size is None:
+        return False, "Content-Length header is missing"
+    size = int(size)
+    if size != expected_size:
+        return False, f"{url} file has expected size {expected_size}, but got {size}"
+    return True, f"{url} file size is correct"
+
+
+def test_model_files():
+    base_url = "https://openaipublic.blob.core.windows.net/gpt-2/models"
+
+    model_size = "124M"
+    files = {
+        "checkpoint": 77,
+        "encoder.json": 1042301,
+        "hparams.json": 90,
+        "model.ckpt.data-00000-of-00001": 497759232,
+        "model.ckpt.index": 5215,
+        "model.ckpt.meta": 471155,
+        "vocab.bpe": 456318
+    }
+
+    for file_name, expected_size in files.items():
+        url = f"{base_url}/{model_size}/{file_name}"
+        valid, message = check_file_size(url, expected_size)
+        assert valid, message
+
+    model_size = "355M"
+    files = {
+        "checkpoint": 77,
+        "encoder.json": 1042301,
+        "hparams.json": 91,
+        "model.ckpt.data-00000-of-00001": 1419292672,
+        "model.ckpt.index": 10399,
+        "model.ckpt.meta": 926519,
+        "vocab.bpe": 456318
+    }
+
+    for file_name, expected_size in files.items():
+        url = f"{base_url}/{model_size}/{file_name}"
+        valid, message = check_file_size(url, expected_size)
+        assert valid, message