@@ -481,7 +481,7 @@
     "    ):\n",
     "        super().__init__()\n",
     "        assert d_out % num_heads == 0, \"d_out must be divisible by num_heads\"\n",
-    "        assert num_heads % num_kv_groups == 0, \"num_heads must be divisible by num_kv_groups\"\n",
+    "        assert num_heads % num_kv_groups == 0, \"num_heads must be divisible by num_kv_groups\"  # NEW\n",
     "\n",
     "        self.d_out = d_out\n",
     "        self.num_heads = num_heads\n",
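For context, the assertion touched in this hunk guards the head-grouping arithmetic of grouped-query attention: every query head must map to exactly one key/value group. A minimal shape sketch of that constraint, with hypothetical sizes rather than the notebook's full `GroupedQueryAttention` class:

```python
import torch

# Hypothetical sizes chosen for illustration
b, num_tokens, num_heads, num_kv_groups, head_dim = 1, 6, 32, 8, 16
assert num_heads % num_kv_groups == 0, "num_heads must be divisible by num_kv_groups"
group_size = num_heads // num_kv_groups  # here, 4 query heads share each K/V group

queries = torch.randn(b, num_heads, num_tokens, head_dim)
keys = torch.randn(b, num_kv_groups, num_tokens, head_dim)

# Expand K from num_kv_groups up to num_heads so every query head
# attends over its group's shared keys
keys = keys.repeat_interleave(group_size, dim=1)
attn_scores = queries @ keys.transpose(2, 3)
print(attn_scores.shape)  # torch.Size([1, 32, 6, 6])
```

Without the divisibility check, `group_size * num_kv_groups` would no longer equal `num_heads`, and the expanded keys would not line up with the query heads.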
@@ -886,7 +886,7 @@
     "    \"n_heads\": 32,            # Number of attention heads\n",
     "    \"n_layers\": 32,           # Number of layers\n",
     "    \"hidden_dim\": 11_008,     # Size of the intermediate dimension in FeedForward\n",
-    "    \"dtype\": torch.bfloat16   # Lower-precision dtype to save memory\n",
+    "    \"dtype\": torch.bfloat16   # Lower-precision dtype to reduce memory usage\n",
     "}"
    ]
   },
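The comment reworded in this hunk concerns memory: storing weights in bfloat16 takes 2 bytes per parameter instead of float32's 4. A rough back-of-the-envelope check; the parameter count is the commonly cited ~6.74B figure for Llama 2 7B, assumed here for illustration:

```python
import torch

num_params = 6.74e9  # assumed: ~6.74B parameters for Llama 2 7B
bytes_bf16 = torch.finfo(torch.bfloat16).bits // 8  # 2 bytes per parameter
bytes_fp32 = torch.finfo(torch.float32).bits // 8   # 4 bytes per parameter

print(f"bfloat16 weights: ~{num_params * bytes_bf16 / 1e9:.0f} GB")  # ~13 GB
print(f"float32 weights:  ~{num_params * bytes_fp32 / 1e9:.0f} GB")  # ~27 GB
```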
@@ -909,7 +909,7 @@
     "    \"n_kv_groups\": 8,         # NEW: Key-Value groups for grouped-query attention\n",
     "    \"rope_base\": 500_000.0,   # NEW: The base in RoPE's \"theta\" was increased to 500_000\n",
     "    \"rope_freq\": None,        # NEW: Additional configuration for adjusting the RoPE frequencies\n",
-    "    \"dtype\": torch.bfloat16   # Lower-precision dtype to save memory\n",
+    "    \"dtype\": torch.bfloat16   # Lower-precision dtype to reduce memory usage\n",
     "}"
    ]
   },
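On the `rope_base` setting this hunk annotates: RoPE's inverse frequencies are computed as `1 / base**(2i/head_dim)`, so raising the base from the conventional 10_000 to 500_000 stretches the longest positional wavelengths, which suits the larger context window. A minimal sketch, where `head_dim` is an assumed example value:

```python
import torch

head_dim = 128  # assumed per-head dimension for illustration

for base in (10_000.0, 500_000.0):  # conventional default vs. the value in this config
    inv_freq = 1.0 / base ** (torch.arange(0, head_dim, 2).float() / head_dim)
    # A larger base shrinks the smallest frequency, i.e. stretches the
    # longest positional wavelength
    print(f"base={base:>9}: slowest frequency {inv_freq.min().item():.2e}")
```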
@@ -2062,7 +2062,7 @@
     "    \"n_kv_groups\": 8,         # Key-Value groups for grouped-query attention\n",
     "    \"rope_base\": 500_000.0,   # The base in RoPE's \"theta\"\n",
     "    \"rope_freq\": None,        # Additional configuration for adjusting the RoPE frequencies\n",
-    "    \"dtype\": torch.bfloat16   # Lower-precision dtype to save memory\n",
+    "    \"dtype\": torch.bfloat16   # Lower-precision dtype to reduce memory usage\n",
     "}\n",
     "\n",
     "LLAMA31_CONFIG_8B = {\n",
@@ -2074,7 +2074,7 @@
     "    \"hidden_dim\": 14_336,     # Size of the intermediate dimension in FeedForward\n",
     "    \"n_kv_groups\": 8,         # Key-Value groups for grouped-query attention\n",
     "    \"rope_base\": 500_000.0,   # The base in RoPE's \"theta\"\n",
-    "    \"dtype\": torch.bfloat16,  # Lower-precision dtype to save memory\n",
+    "    \"dtype\": torch.bfloat16,  # Lower-precision dtype to reduce memory usage\n",
     "    \"rope_freq\": {            # NEW: RoPE frequency scaling\n",
     "        \"factor\": 8.0,\n",
     "        \"low_freq_factor\": 1.0,\n",
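The `rope_freq` dict introduced in this hunk feeds a frequency-rescaling step. Below is a sketch of the wavelength-based smoothing used in Llama-3.1-style RoPE scaling; only `factor` and `low_freq_factor` are visible in these hunks, so `high_freq_factor` and `original_context_length` are assumed defaults, not values confirmed by the diff:

```python
import torch

def rescale_inv_freq(inv_freq, factor=8.0, low_freq_factor=1.0,
                     high_freq_factor=4.0, original_context_length=8192):
    # high_freq_factor and original_context_length are assumed values
    wavelen = 2 * torch.pi / inv_freq
    low_freq_wavelen = original_context_length / low_freq_factor
    high_freq_wavelen = original_context_length / high_freq_factor

    # Wavelengths longer than low_freq_wavelen get divided by `factor`
    scaled = torch.where(wavelen > low_freq_wavelen, inv_freq / factor, inv_freq)

    # Medium wavelengths interpolate smoothly between scaled and unscaled
    smooth = (original_context_length / wavelen - low_freq_factor) / (
        high_freq_factor - low_freq_factor
    )
    smoothed = (1 - smooth) / factor * inv_freq + smooth * inv_freq
    is_medium = (wavelen <= low_freq_wavelen) & (wavelen >= high_freq_wavelen)
    return torch.where(is_medium, smoothed, scaled)

head_dim = 128  # assumed per-head dimension
inv_freq = 1.0 / 500_000.0 ** (torch.arange(0, head_dim, 2).float() / head_dim)
print(rescale_inv_freq(inv_freq)[-4:])  # lowest frequencies end up divided by 8
```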
@@ -2448,7 +2448,7 @@
     "    \"hidden_dim\": 14_336,     # Size of the intermediate dimension in FeedForward\n",
     "    \"n_kv_groups\": 8,         # Key-Value groups for grouped-query attention\n",
     "    \"rope_base\": 500_000.0,   # The base in RoPE's \"theta\"\n",
-    "    \"dtype\": torch.bfloat16,  # Lower-precision dtype to save memory\n",
+    "    \"dtype\": torch.bfloat16,  # Lower-precision dtype to reduce memory usage\n",
     "    \"rope_freq\": {            # NEW: RoPE frequency scaling\n",
     "        \"factor\": 8.0,\n",
     "        \"low_freq_factor\": 1.0,\n",
@@ -2467,7 +2467,7 @@
     "    \"hidden_dim\": 8192,       # NEW: Almost half the size of the intermediate dimension in FeedForward\n",
     "    \"n_kv_groups\": 8,         # Key-Value groups for grouped-query attention\n",
     "    \"rope_base\": 500_000.0,   # The base in RoPE's \"theta\"\n",
-    "    \"dtype\": torch.bfloat16,  # Lower-precision dtype to save memory\n",
+    "    \"dtype\": torch.bfloat16,  # Lower-precision dtype to reduce memory usage\n",
     "    \"rope_freq\": {            # RoPE frequency scaling\n",
     "        \"factor\": 32.0,       # NEW: Adjustment of the rescaling factor\n",
     "        \"low_freq_factor\": 1.0,\n",
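As a quick illustration of the `factor` change this last hunk highlights (8.0 → 32.0): the low-frequency RoPE components are divided by `factor`, so the larger value stretches the longest wavelengths four times further. The comparison below shows only that low-frequency limit, not the full smoothing; `head_dim` is again an assumed example value:

```python
import torch

head_dim = 128  # assumed per-head dimension for illustration
inv_freq = 1.0 / 500_000.0 ** (torch.arange(0, head_dim, 2).float() / head_dim)

for factor in (8.0, 32.0):  # the two settings from the configs above
    print(f"factor={factor:>4}: slowest frequency {(inv_freq / factor).min().item():.2e}")
```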