Explorar o código

MoE Nb readability improvements (#761)

Sebastian Raschka hai 3 meses
pai
achega
5febcf8a1b

+ 24 - 11
ch05/11_qwen3/standalone-qwen3-moe-plus-kvcache.ipynb

@@ -152,13 +152,28 @@
     "        self.num_experts = cfg[\"num_experts\"]\n",
     "        self.gate = nn.Linear(cfg[\"emb_dim\"], cfg[\"num_experts\"], bias=False, dtype=cfg[\"dtype\"])\n",
     "\n",
-    "        meta_device = torch.device(\"meta\")  # to reduce memory pressure and only load them when used (trades compute for memory)\n",
-    "        self.fc1 = nn.ModuleList([nn.Linear(cfg[\"emb_dim\"], cfg[\"moe_intermediate_size\"], bias=False, dtype=cfg[\"dtype\"], device=meta_device)\n",
-    "                                  for _ in range(cfg[\"num_experts\"])])\n",
-    "        self.fc2 = nn.ModuleList([nn.Linear(cfg[\"emb_dim\"], cfg[\"moe_intermediate_size\"], bias=False, dtype=cfg[\"dtype\"], device=meta_device)\n",
-    "                                  for _ in range(cfg[\"num_experts\"])])\n",
-    "        self.fc3 = nn.ModuleList([nn.Linear(cfg[\"moe_intermediate_size\"], cfg[\"emb_dim\"], bias=False, dtype=cfg[\"dtype\"], device=meta_device)\n",
-    "                                  for _ in range(cfg[\"num_experts\"])])\n",
+    "        # meta device to reduce memory pressure when initializing the model before loading weights\n",
+    "        meta_device = torch.device(\"meta\")\n",
+    "        self.fc1 = nn.ModuleList([\n",
+    "            nn.Linear(\n",
+    "                cfg[\"emb_dim\"], cfg[\"moe_intermediate_size\"],\n",
+    "                bias=False, dtype=cfg[\"dtype\"], device=meta_device)\n",
+    "            for _ in range(cfg[\"num_experts\"])]\n",
+    "        )\n",
+    "        self.fc2 = nn.ModuleList([\n",
+    "            nn.Linear(\n",
+    "                cfg[\"emb_dim\"], cfg[\"moe_intermediate_size\"],\n",
+    "                bias=False, dtype=cfg[\"dtype\"], device=meta_device\n",
+    "                )\n",
+    "            for _ in range(cfg[\"num_experts\"])]\n",
+    "        )\n",
+    "        self.fc3 = nn.ModuleList([\n",
+    "            nn.Linear(\n",
+    "                cfg[\"moe_intermediate_size\"], cfg[\"emb_dim\"],\n",
+    "                bias=False, dtype=cfg[\"dtype\"], device=meta_device\n",
+    "                )\n",
+    "            for _ in range(cfg[\"num_experts\"])]\n",
+    "        )\n",
     "\n",
     "    def forward(self, x):\n",
     "        b, seq_len, embed_dim = x.shape\n",
@@ -194,20 +209,18 @@
     "        #     topk_scores, topk_indices = torch.topk(scores, self.num_experts_per_tok, dim=-1)\n",
     "        #     topk_probs = torch.softmax(topk_scores, dim=-1)\n",
     "        #     y = torch.zeros_like(x)\n",
-    "\n",
+    "        #\n",
     "        #     for i in range(self.num_experts_per_tok):\n",
     "        #         # expert_indices is (b, seq_len) with values in [0, num_experts)\n",
     "        #         expert_indices = topk_indices[..., i]\n",
     "        #         prob = topk_probs[..., i].unsqueeze(-1)  # (b, seq_len, 1)\n",
-    "\n",
+    "        #\n",
     "        #         # For each expert, process only the tokens assigned to it\n",
     "        #         for e in range(self.num_experts):\n",
     "        #             mask = (expert_indices == e)  # (b, seq_len) boolean mask\n",
     "        #             if mask.any():\n",
     "        #                 selected = x[mask]  # (num_tokens_e, emb_dim)\n",
-    "        #                 # Compute FF for expert e\n",
     "        #                 out = self.fc3[e](torch.nn.functional.silu(self.fc1[e](selected)) * self.fc2[e](selected))\n",
-    "        #                 # Scale by gating prob and scatter back\n",
     "        #                 y[mask] += prob[mask] * out\n",
     "        #     return y"
    ]

+ 24 - 11
ch05/11_qwen3/standalone-qwen3-moe.ipynb

@@ -152,13 +152,28 @@
     "        self.num_experts = cfg[\"num_experts\"]\n",
     "        self.gate = nn.Linear(cfg[\"emb_dim\"], cfg[\"num_experts\"], bias=False, dtype=cfg[\"dtype\"])\n",
     "\n",
-    "        meta_device = torch.device(\"meta\")  # to reduce memory pressure and only load them when used (trades compute for memory)\n",
-    "        self.fc1 = nn.ModuleList([nn.Linear(cfg[\"emb_dim\"], cfg[\"moe_intermediate_size\"], bias=False, dtype=cfg[\"dtype\"], device=meta_device)\n",
-    "                                  for _ in range(cfg[\"num_experts\"])])\n",
-    "        self.fc2 = nn.ModuleList([nn.Linear(cfg[\"emb_dim\"], cfg[\"moe_intermediate_size\"], bias=False, dtype=cfg[\"dtype\"], device=meta_device)\n",
-    "                                  for _ in range(cfg[\"num_experts\"])])\n",
-    "        self.fc3 = nn.ModuleList([nn.Linear(cfg[\"moe_intermediate_size\"], cfg[\"emb_dim\"], bias=False, dtype=cfg[\"dtype\"], device=meta_device)\n",
-    "                                  for _ in range(cfg[\"num_experts\"])])\n",
+    "        # meta device to reduce memory pressure when initializing the model before loading weights\n",
+    "        meta_device = torch.device(\"meta\")\n",
+    "        self.fc1 = nn.ModuleList([\n",
+    "            nn.Linear(\n",
+    "                cfg[\"emb_dim\"], cfg[\"moe_intermediate_size\"],\n",
+    "                bias=False, dtype=cfg[\"dtype\"], device=meta_device)\n",
+    "            for _ in range(cfg[\"num_experts\"])]\n",
+    "        )\n",
+    "        self.fc2 = nn.ModuleList([\n",
+    "            nn.Linear(\n",
+    "                cfg[\"emb_dim\"], cfg[\"moe_intermediate_size\"],\n",
+    "                bias=False, dtype=cfg[\"dtype\"], device=meta_device\n",
+    "                )\n",
+    "            for _ in range(cfg[\"num_experts\"])]\n",
+    "        )\n",
+    "        self.fc3 = nn.ModuleList([\n",
+    "            nn.Linear(\n",
+    "                cfg[\"moe_intermediate_size\"], cfg[\"emb_dim\"],\n",
+    "                bias=False, dtype=cfg[\"dtype\"], device=meta_device\n",
+    "                )\n",
+    "            for _ in range(cfg[\"num_experts\"])]\n",
+    "        )\n",
     "\n",
     "    def forward(self, x):\n",
     "        b, seq_len, embed_dim = x.shape\n",
@@ -194,20 +209,18 @@
     "        #     topk_scores, topk_indices = torch.topk(scores, self.num_experts_per_tok, dim=-1)\n",
     "        #     topk_probs = torch.softmax(topk_scores, dim=-1)\n",
     "        #     y = torch.zeros_like(x)\n",
-    "\n",
+    "        #\n",
     "        #     for i in range(self.num_experts_per_tok):\n",
     "        #         # expert_indices is (b, seq_len) with values in [0, num_experts)\n",
     "        #         expert_indices = topk_indices[..., i]\n",
     "        #         prob = topk_probs[..., i].unsqueeze(-1)  # (b, seq_len, 1)\n",
-    "\n",
+    "        #\n",
     "        #         # For each expert, process only the tokens assigned to it\n",
     "        #         for e in range(self.num_experts):\n",
     "        #             mask = (expert_indices == e)  # (b, seq_len) boolean mask\n",
     "        #             if mask.any():\n",
     "        #                 selected = x[mask]  # (num_tokens_e, emb_dim)\n",
-    "        #                 # Compute FF for expert e\n",
     "        #                 out = self.fc3[e](torch.nn.functional.silu(self.fc1[e](selected)) * self.fc2[e](selected))\n",
-    "        #                 # Scale by gating prob and scatter back\n",
     "        #                 y[mask] += prob[mask] * out\n",
     "        #     return y"
    ]