@@ -497,3 +497,77 @@ class Llama3ModelFast(nn.Module):
         x = self.final_norm(x)
         logits = self.out_head(x.to(self.cfg["dtype"]))
         return logits
+
+
+def assign(left, right, tensor_name="unknown"):
+    if left.shape != right.shape:
+        raise ValueError(f"Shape mismatch in tensor '{tensor_name}'. Left: {left.shape}, Right: {right.shape}")
+
+    if isinstance(right, torch.Tensor):
+        return torch.nn.Parameter(right.clone().detach())
+    else:
+        return torch.nn.Parameter(torch.tensor(right))
+
+
+def load_weights_into_llama(model, param_config, params):
+    model.tok_emb.weight = assign(model.tok_emb.weight, params["model.embed_tokens.weight"], "model.embed_tokens.weight")
+
+    for l in range(param_config["n_layers"]):
+
+        # Load attention weights
+        model.trf_blocks[l].att.W_query.weight = assign(
+            model.trf_blocks[l].att.W_query.weight,
+            params[f"model.layers.{l}.self_attn.q_proj.weight"],
+            f"model.layers.{l}.self_attn.q_proj.weight"
+        )
+        model.trf_blocks[l].att.W_key.weight = assign(
+            model.trf_blocks[l].att.W_key.weight,
+            params[f"model.layers.{l}.self_attn.k_proj.weight"],
+            f"model.layers.{l}.self_attn.k_proj.weight"
+        )
+        model.trf_blocks[l].att.W_value.weight = assign(
+            model.trf_blocks[l].att.W_value.weight,
+            params[f"model.layers.{l}.self_attn.v_proj.weight"],
+            f"model.layers.{l}.self_attn.v_proj.weight"
+        )
+        model.trf_blocks[l].att.out_proj.weight = assign(
+            model.trf_blocks[l].att.out_proj.weight,
+            params[f"model.layers.{l}.self_attn.o_proj.weight"],
+            f"model.layers.{l}.self_attn.o_proj.weight"
+        )
+        model.trf_blocks[l].norm1.weight = assign(
+            model.trf_blocks[l].norm1.weight,
+            params[f"model.layers.{l}.input_layernorm.weight"],
+            f"model.layers.{l}.input_layernorm.weight"
+        )
+
+        # Load FeedForward weights
+        model.trf_blocks[l].ff.fc1.weight = assign(
+            model.trf_blocks[l].ff.fc1.weight,
+            params[f"model.layers.{l}.mlp.gate_proj.weight"],
+            f"model.layers.{l}.mlp.gate_proj.weight"
+        )
+        model.trf_blocks[l].ff.fc2.weight = assign(
+            model.trf_blocks[l].ff.fc2.weight,
+            params[f"model.layers.{l}.mlp.up_proj.weight"],
+            f"model.layers.{l}.mlp.up_proj.weight"
+        )
+        model.trf_blocks[l].ff.fc3.weight = assign(
+            model.trf_blocks[l].ff.fc3.weight,
+            params[f"model.layers.{l}.mlp.down_proj.weight"],
+            f"model.layers.{l}.mlp.down_proj.weight"
+        )
+        model.trf_blocks[l].norm2.weight = assign(
+            model.trf_blocks[l].norm2.weight,
+            params[f"model.layers.{l}.post_attention_layernorm.weight"],
+            f"model.layers.{l}.post_attention_layernorm.weight"
+        )
+
+    # Load output layer weights
+    model.final_norm.weight = assign(model.final_norm.weight, params["model.norm.weight"], "model.norm.weight")
+
+    if "lm_head.weight" in params.keys():
+        model.out_head.weight = assign(model.out_head.weight, params["lm_head.weight"], "lm_head.weight")
+    else:
+        model.out_head.weight = assign(model.out_head.weight, params["model.embed_tokens.weight"], "model.embed_tokens.weight")
+        print("Model uses weight tying.")
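
For reference, a minimal usage sketch for `load_weights_into_llama`. It assumes the `Llama3ModelFast` constructor accepts the config dict, that the checkpoint is available as a safetensors file, and that `LLAMA3_CONFIG_8B` and the file path are illustrative placeholders for names defined elsewhere in the surrounding code:

from safetensors.torch import load_file

# Illustrative path; for a sharded checkpoint, call load_file() on each shard
# and merge the resulting dicts into a single `params` mapping first.
params = load_file("path/to/model.safetensors")

# LLAMA3_CONFIG_8B is an assumed name for the model config dict;
# load_weights_into_llama itself only relies on its "n_layers" entry.
model = Llama3ModelFast(LLAMA3_CONFIG_8B)
load_weights_into_llama(model, LLAMA3_CONFIG_8B, params)
model.eval()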