Explorar o código

add weight sizes

rasbt hai 1 ano
pai
achega
83adc4a2ac

A diferenza do arquivo foi suprimida porque é demasiado grande
+ 31 - 31
ch05/01_main-chapter-code/ch05.ipynb


+ 7 - 12
ch05/01_main-chapter-code/gpt_generate.py

@@ -219,7 +219,7 @@ if __name__ == "__main__":
 
     torch.manual_seed(123)
 
-    CHOOSE_MODEL = "gpt2-small"
+    CHOOSE_MODEL = "gpt2-small (124M)"
     INPUT_PROMPT = "Every effort moves"
 
     BASE_CONFIG = {
@@ -230,19 +230,14 @@ if __name__ == "__main__":
     }
 
     model_configs = {
-        "gpt2-small": {"emb_dim": 768, "n_layers": 12, "n_heads": 12},
-        "gpt2-medium": {"emb_dim": 1024, "n_layers": 24, "n_heads": 16},
-        "gpt2-large": {"emb_dim": 1280, "n_layers": 36, "n_heads": 20},
-        "gpt2-xl": {"emb_dim": 1600, "n_layers": 48, "n_heads": 25},
+        "gpt2-small (124M)": {"emb_dim": 768, "n_layers": 12, "n_heads": 12},
+        "gpt2-medium (355M)": {"emb_dim": 1024, "n_layers": 24, "n_heads": 16},
+        "gpt2-large (774M)": {"emb_dim": 1280, "n_layers": 36, "n_heads": 20},
+        "gpt2-xl (1558M)": {"emb_dim": 1600, "n_layers": 48, "n_heads": 25},
     }
 
-    model_sizes = {
-        "gpt2-small": "124M",
-        "gpt2-medium": "355M",
-        "gpt2-large": "774M",
-        "gpt2-xl": "1558"
-    }
+    model_size = CHOOSE_MODEL.split(" ")[-1].lstrip("(").rstrip(")")
 
     BASE_CONFIG.update(model_configs[CHOOSE_MODEL])
 
-    main(BASE_CONFIG, INPUT_PROMPT, model_sizes[CHOOSE_MODEL])
+    main(BASE_CONFIG, INPUT_PROMPT, model_size)

+ 5 - 5
ch05/02_alternative_weight_loading/weight-loading-hf-transformers.ipynb

@@ -126,13 +126,13 @@
     "\n",
     "# allowed model names\n",
     "model_names = {\n",
-    "    \"gpt2-small\": \"openai-community/gpt2\",         # 124M\n",
-    "    \"gpt2-medium\": \"openai-community/gpt2-medium\", # 355M\n",
-    "    \"gpt2-large\": \"openai-community/gpt2-large\",   # 774M\n",
-    "    \"gpt2-xl\": \"openai-community/gpt2-xl\"          # 1558M\n",
+    "    \"gpt2-small (124M)\": \"openai-community/gpt2\",\n",
+    "    \"gpt2-medium (355M)\": \"openai-community/gpt2-medium\",\n",
+    "    \"gpt2-large (774M)\": \"openai-community/gpt2-large\",\n",
+    "    \"gpt2-xl (1558M)\": \"openai-community/gpt2-xl\"\n",
     "}\n",
     "\n",
-    "CHOOSE_MODEL = \"gpt2-small\"\n",
+    "CHOOSE_MODEL = \"gpt2-small (124M)\"\n",
     "\n",
     "gpt_hf = GPT2Model.from_pretrained(model_names[CHOOSE_MODEL], cache_dir=\"checkpoints\")\n",
     "gpt_hf.eval()"

Algúns arquivos non se mostraron porque demasiados arquivos cambiaron neste cambio