|
|
@@ -18,7 +18,7 @@
|
|
|
"id": "6f678e62-7bcb-4405-86ae-dce94f494303"
|
|
|
},
|
|
|
"source": [
|
|
|
- "# Efficient Multi-Head Attention Implementations"
|
|
|
+ "# Comparing Efficient Multi-Head Attention Implementations"
|
|
|
]
|
|
|
},
|
|
|
{
|
|
|
@@ -73,6 +73,9 @@
|
|
|
"id": "2f9bb1b6-a1e5-4e0a-884d-0f31b374a8d6"
|
|
|
},
|
|
|
"source": [
|
|
|
+ "<br>\n",
|
|
|
+ " \n",
|
|
|
+ "\n",
|
|
|
"## 1) CausalAttention MHA wrapper class from chapter 3"
|
|
|
]
|
|
|
},
|
|
|
@@ -119,6 +122,9 @@
|
|
|
"id": "21930804-b327-40b1-8e63-94dcad39ce7b"
|
|
|
},
|
|
|
"source": [
|
|
|
+ "<br>\n",
|
|
|
+ " \n",
|
|
|
+ "\n",
|
|
|
"## 2) The multi-head attention class from chapter 3"
|
|
|
]
|
|
|
},
|
|
|
@@ -165,6 +171,9 @@
|
|
|
"id": "73cd11da-ea3b-4081-b483-c4965dfefbc4"
|
|
|
},
|
|
|
"source": [
|
|
|
+ "<br>\n",
|
|
|
+ " \n",
|
|
|
+ "\n",
|
|
|
"## 3) An alternative multi-head attention with combined weights"
|
|
|
]
|
|
|
},
|
|
|
@@ -286,6 +295,9 @@
|
|
|
"id": "48a042d3-ee78-4c29-bf63-d92fe6706632"
|
|
|
},
|
|
|
"source": [
|
|
|
+ "<br>\n",
|
|
|
+ " \n",
|
|
|
+ "\n",
|
|
|
"## 4) Multihead attention with PyTorch's scaled dot product attention"
|
|
|
]
|
|
|
},
|
|
|
@@ -393,6 +405,9 @@
|
|
|
"id": "351c318f-4835-4d74-8d58-a070222447c4"
|
|
|
},
|
|
|
"source": [
|
|
|
+ "<br>\n",
|
|
|
+ " \n",
|
|
|
+ "\n",
|
|
|
"## 5) Using PyTorch's torch.nn.MultiheadAttention"
|
|
|
]
|
|
|
},
|
|
|
@@ -488,6 +503,9 @@
|
|
|
"id": "a3953bff-1056-4de2-bfd1-dfccf659eee4"
|
|
|
},
|
|
|
"source": [
|
|
|
+ "<br>\n",
|
|
|
+ " \n",
|
|
|
+ "\n",
|
|
|
"## 6) Using PyTorch's torch.nn.MultiheadAttention with `scaled_dot_product_attention`"
|
|
|
]
|
|
|
},
|
|
|
@@ -548,6 +566,9 @@
|
|
|
"id": "8877de71-f84f-4f6d-bc87-7552013b6301"
|
|
|
},
|
|
|
"source": [
|
|
|
+ "<br>\n",
|
|
|
+ " \n",
|
|
|
+ "\n",
|
|
|
"## Quick speed comparison (M3 Macbook Air CPU)"
|
|
|
]
|
|
|
},
|
|
|
@@ -706,6 +727,9 @@
|
|
|
"id": "a78ff594-6cc2-496d-a302-789fa104c3c9"
|
|
|
},
|
|
|
"source": [
|
|
|
+ "<br>\n",
|
|
|
+ " \n",
|
|
|
+ "\n",
|
|
|
"## Quick speed comparison (Nvidia A100 GPU)"
|
|
|
]
|
|
|
},
|
|
|
@@ -866,6 +890,10 @@
|
|
|
"id": "dabc6575-0316-4640-a729-e616d5c17b73"
|
|
|
},
|
|
|
"source": [
|
|
|
+ "<br>\n",
|
|
|
+ " \n",
|
|
|
+ "\n",
|
|
|
+ "\n",
|
|
|
"## Speed comparison (Nvidia A100 GPU) with warmup"
|
|
|
]
|
|
|
},
|
|
|
@@ -1003,7 +1031,7 @@
|
|
|
"name": "python",
|
|
|
"nbconvert_exporter": "python",
|
|
|
"pygments_lexer": "ipython3",
|
|
|
- "version": "3.10.6"
|
|
|
+ "version": "3.10.12"
|
|
|
}
|
|
|
},
|
|
|
"nbformat": 4,
|