@@ -219,7 +219,7 @@ if __name__ == "__main__":

     torch.manual_seed(123)

-    CHOOSE_MODEL = "gpt2-small"
+    CHOOSE_MODEL = "gpt2-small (124M)"
     INPUT_PROMPT = "Every effort moves"

     BASE_CONFIG = {
@@ -230,19 +230,14 @@ if __name__ == "__main__":
     }

     model_configs = {
-        "gpt2-small": {"emb_dim": 768, "n_layers": 12, "n_heads": 12},
-        "gpt2-medium": {"emb_dim": 1024, "n_layers": 24, "n_heads": 16},
-        "gpt2-large": {"emb_dim": 1280, "n_layers": 36, "n_heads": 20},
-        "gpt2-xl": {"emb_dim": 1600, "n_layers": 48, "n_heads": 25},
+        "gpt2-small (124M)": {"emb_dim": 768, "n_layers": 12, "n_heads": 12},
+        "gpt2-medium (355M)": {"emb_dim": 1024, "n_layers": 24, "n_heads": 16},
+        "gpt2-large (774M)": {"emb_dim": 1280, "n_layers": 36, "n_heads": 20},
+        "gpt2-xl (1558M)": {"emb_dim": 1600, "n_layers": 48, "n_heads": 25},
     }

-    model_sizes = {
-        "gpt2-small": "124M",
-        "gpt2-medium": "355M",
-        "gpt2-large": "774M",
-        "gpt2-xl": "1558"
-    }
+    model_size = CHOOSE_MODEL.split(" ")[-1].lstrip("(").rstrip(")")

     BASE_CONFIG.update(model_configs[CHOOSE_MODEL])

-    main(BASE_CONFIG, INPUT_PROMPT, model_sizes[CHOOSE_MODEL])
+    main(BASE_CONFIG, INPUT_PROMPT, model_size)
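
Note: with this change the parameter count is parsed directly out of the model name instead of being looked up in the separate model_sizes dict (whose old "gpt2-xl" entry read "1558", apparently missing the trailing "M"). A minimal standalone sketch of that string handling, outside the patched script, just to illustrate what the new line yields:

    # Sketch of the size-parsing expression added above (not part of the PR itself).
    CHOOSE_MODEL = "gpt2-small (124M)"

    # Take the last whitespace-separated token, e.g. "(124M)", and strip the parentheses.
    model_size = CHOOSE_MODEL.split(" ")[-1].lstrip("(").rstrip(")")

    print(model_size)  # -> 124M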