@@ -83,8 +83,8 @@
},
"source": [
"- To run all the code in this notebook, please ensure you update to at least PyTorch 2.5 (FlexAttention is not included in earlier PyTorch releases)\n",
- "If the code cell above shows a PyTorch version lower than 2.5, you can upgrade your PyTorch installation by uncommenting and running the following code cell (Please note that PyTorch 2.5 requires Python 3.9 or later)\n",
- "- For more specific instructions and CUDA versions, please refer to the official installation guide at https://pytorch.org."
+ "- If the code cell above shows a PyTorch version lower than 2.5, you can upgrade your PyTorch installation by uncommenting and running the following code cell (Please note that PyTorch 2.5 requires Python 3.9 or later)\n",
+ "- For more specific instructions and CUDA versions, please refer to the official installation guide at https://pytorch.org"
]
},
{
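
As a side note on the hunk above: the actual upgrade cell it refers to is not part of this diff. A minimal sketch of what such a version check and upgrade could look like (assuming a pip-based environment; this is not necessarily the notebook's own cell):

```python
# Minimal sketch (assumption: pip-based environment; the notebook's actual
# upgrade cell is not shown in this diff).
import torch

print(torch.__version__)  # FlexAttention requires PyTorch 2.5 or later

# Uncomment to upgrade (note: PyTorch 2.5 requires Python 3.9 or later):
# !pip install --upgrade "torch>=2.5"
```
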
@@ -886,12 +886,14 @@
"id": "d2164859-31a0-4537-b4fb-27d57675ba77"
},
"source": [
- "- Set `need_weights` (default `True`) to need_weights=False so that `MultiheadAttention` uses `scaled_dot_product_attention` [according to the documentation](https://github.com/pytorch/pytorch/blob/71d020262793542974cf13b30f2a9099773f015c/torch/nn/modules/activation.py#L1096)\n",
+ "- Set `need_weights` (default `True`) to `False` so that `MultiheadAttention` uses `scaled_dot_product_attention` [according to the documentation](https://github.com/pytorch/pytorch/blob/71d020262793542974cf13b30f2a9099773f015c/torch/nn/modules/activation.py#L1096)\n",
"\n",
- "> need_weights: If specified, returns ``attn_output_weights`` in addition to ``attn_outputs``.\n",
- " Set ``need_weights=False`` to use the optimized ``scaled_dot_product_attention``\n",
- " and achieve the best performance for MHA.\n",
- " Default: ``True``."
+ "```markdown\n",
+ "need_weights: If specified, returns `attn_output_weights` in addition to `attn_outputs`.\n",
+ " Set `need_weights=False` to use the optimized `scaled_dot_product_attention`\n",
+ " and achieve the best performance for MHA.\n",
+ " Default: `True`\n",
+ "```"
]
},
{
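
For context on the `need_weights` change above, a minimal sketch of calling `MultiheadAttention` with `need_weights=False` so it can dispatch to the optimized `scaled_dot_product_attention` path (the tensor shapes and hyperparameters below are illustrative assumptions, not taken from the notebook):

```python
# Minimal sketch (shapes and hyperparameters are assumptions, not the notebook's).
import torch
import torch.nn as nn

embed_dim, num_heads = 768, 12
mha = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)

x = torch.randn(8, 1024, embed_dim)  # (batch, seq_len, embed_dim)

# need_weights=False lets MultiheadAttention use the optimized
# scaled_dot_product_attention kernel; attn_weights is then returned as None.
attn_output, attn_weights = mha(x, x, x, need_weights=False)
print(attn_output.shape)  # torch.Size([8, 1024, 768])
print(attn_weights)       # None
```
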
@@ -1965,7 +1967,7 @@
"provenance": []
},
"kernelspec": {
- "display_name": "Python 3 (ipykernel)",
+ "display_name": "pt",
"language": "python",
"name": "python3"
},
@@ -1979,7 +1981,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.11.4"
+ "version": "3.11.9"
}
},
"nbformat": 4,