|
@@ -53,19 +53,27 @@
|
|
|
"output_type": "stream",
|
|
"output_type": "stream",
|
|
|
"text": [
|
|
"text": [
|
|
|
"thop version: 0.1.1-2209072238\n",
|
|
"thop version: 0.1.1-2209072238\n",
|
|
|
- "torch version: 2.2.2\n",
|
|
|
|
|
- "tiktoken version: 0.5.1\n"
|
|
|
|
|
|
|
+ "torch version: 2.2.1+cu121\n"
|
|
|
]
|
|
]
|
|
|
}
|
|
}
|
|
|
],
|
|
],
|
|
|
"source": [
|
|
"source": [
|
|
|
"from importlib.metadata import version\n",
|
|
"from importlib.metadata import version\n",
|
|
|
"\n",
|
|
"\n",
|
|
|
- "import matplotlib\n",
|
|
|
|
|
- "import torch\n",
|
|
|
|
|
- "\n",
|
|
|
|
|
- "print(\"thop version:\", version(\"thop\"))\n",
|
|
|
|
|
- "print(\"torch version:\", version(\"torch\"))"
|
|
|
|
|
|
|
+ "pkgs = [\n",
|
|
|
|
|
+ " \"thop\",\n",
|
|
|
|
|
+ " \"torch\",\n",
|
|
|
|
|
+ "]\n",
|
|
|
|
|
+ "for p in pkgs:\n",
|
|
|
|
|
+ " print(f\"{p} version: {version(p)}\")"
|
|
|
|
|
+ ]
|
|
|
|
|
+ },
|
|
|
|
|
+ {
|
|
|
|
|
+ "cell_type": "markdown",
|
|
|
|
|
+ "metadata": {},
|
|
|
|
|
+ "source": [
|
|
|
|
|
+ " \n",
|
|
|
|
|
+ "# Simple benchmark with fixed batch size"
|
|
|
]
|
|
]
|
|
|
},
|
|
},
|
|
|
{
|
|
{
|
|
@@ -112,7 +120,8 @@
|
|
|
"}\n",
|
|
"}\n",
|
|
|
"\n",
|
|
"\n",
|
|
|
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
|
|
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
|
|
|
- "input_tensor = torch.randint(0, 50257, (2, 1024)).to(device)\n",
|
|
|
|
|
|
|
+ "batch_size = 2\n",
|
|
|
|
|
+ "input_tensor = torch.randint(0, 50257, (batch_size, 1024)).to(device)\n",
|
|
|
"\n",
|
|
"\n",
|
|
|
"for size in model_configs:\n",
|
|
"for size in model_configs:\n",
|
|
|
" BASE_CONFIG.update(model_configs[size])\n",
|
|
" BASE_CONFIG.update(model_configs[size])\n",
|
|
@@ -129,6 +138,343 @@
|
|
|
" del model\n",
|
|
" del model\n",
|
|
|
" torch.cuda.empty_cache()"
|
|
" torch.cuda.empty_cache()"
|
|
|
]
|
|
]
|
|
|
|
|
+ },
|
|
|
|
|
+ {
|
|
|
|
|
+ "cell_type": "markdown",
|
|
|
|
|
+ "metadata": {},
|
|
|
|
|
+ "source": [
|
|
|
|
|
+ " \n",
|
|
|
|
|
+ "# Simple benchmark with automatic batch size finding"
|
|
|
|
|
+ ]
|
|
|
|
|
+ },
|
|
|
|
|
+ {
|
|
|
|
|
+ "cell_type": "code",
|
|
|
|
|
+ "execution_count": 4,
|
|
|
|
|
+ "metadata": {},
|
|
|
|
|
+ "outputs": [
|
|
|
|
|
+ {
|
|
|
|
|
+ "name": "stdout",
|
|
|
|
|
+ "output_type": "stream",
|
|
|
|
|
+ "text": [
|
|
|
|
|
+ "\n",
|
|
|
|
|
+ "Processing gpt-small (124M)\n",
|
|
|
|
|
+ " Batch size 128: 3.2e+13 FLOPS\n",
|
|
|
|
|
+ " Batch size 160: 4.0e+13 FLOPS\n",
|
|
|
|
|
+ " Batch size 176: 4.5e+13 FLOPS\n",
|
|
|
|
|
+ " Batch size 184: 4.7e+13 FLOPS\n",
|
|
|
|
|
+ " Batch size 186: 4.7e+13 FLOPS\n",
|
|
|
|
|
+ "\n",
|
|
|
|
|
+ "Processing gpt-medium (355M)\n",
|
|
|
|
|
+ " Batch size 128: 9.3e+13 FLOPS\n",
|
|
|
|
|
+ " Batch size 136: 9.8e+13 FLOPS\n",
|
|
|
|
|
+ " Batch size 140: 1.0e+14 FLOPS\n",
|
|
|
|
|
+ " Batch size 142: 1.0e+14 FLOPS\n",
|
|
|
|
|
+ " Batch size 143: 1.0e+14 FLOPS\n",
|
|
|
|
|
+ "\n",
|
|
|
|
|
+ "Processing gpt-large (774M)\n",
|
|
|
|
|
+ " Batch size 128: 2.0e+14 FLOPS\n",
|
|
|
|
|
+ "\n",
|
|
|
|
|
+ "Processing gpt-xl (1558M)\n",
|
|
|
|
|
+ " Batch size 64: 2.0e+14 FLOPS\n",
|
|
|
|
|
+ " Batch size 96: 3.1e+14 FLOPS\n"
|
|
|
|
|
+ ]
|
|
|
|
|
+ }
|
|
|
|
|
+ ],
|
|
|
|
|
+ "source": [
|
|
|
|
|
+ "for size in model_configs:\n",
|
|
|
|
|
+ " print(f\"\\nProcessing {size}\")\n",
|
|
|
|
|
+ " config = BASE_CONFIG.copy()\n",
|
|
|
|
|
+ " config.update(model_configs[size])\n",
|
|
|
|
|
+ "\n",
|
|
|
|
|
+ " min_batch_size = 1\n",
|
|
|
|
|
+ " max_batch_size = None\n",
|
|
|
|
|
+ " max_possible_batch_size = 4096\n",
|
|
|
|
|
+ "\n",
|
|
|
|
|
+ " while min_batch_size <= max_possible_batch_size:\n",
|
|
|
|
|
+ " batch_size = (min_batch_size + max_possible_batch_size) // 2\n",
|
|
|
|
|
+ " try:\n",
|
|
|
|
|
+ " input_tensor = torch.randint(\n",
|
|
|
|
|
+ " 0, config[\"vocab_size\"],\n",
|
|
|
|
|
+ " (batch_size, config[\"context_length\"]),\n",
|
|
|
|
|
+ " device=device\n",
|
|
|
|
|
+ " )\n",
|
|
|
|
|
+ "\n",
|
|
|
|
|
+ " model = GPTModel(config).bfloat16().to(device)\n",
|
|
|
|
|
+ "\n",
|
|
|
|
|
+ " # MACS = multiply-accumulate operations\n",
|
|
|
|
|
+ " # MACS are typically counted as two FLOPS (one multiply and one accumulate)\n",
|
|
|
|
|
+ " macs, params = profile(model, inputs=(input_tensor,), verbose=False)\n",
|
|
|
|
|
+ " flops = 2 * macs\n",
|
|
|
|
|
+ " print(f\" Batch size {batch_size}: {flops:.1e} FLOPS\")\n",
|
|
|
|
|
+ "\n",
|
|
|
|
|
+ " # If successful, try a larger batch size\n",
|
|
|
|
|
+ " min_batch_size = batch_size + 1\n",
|
|
|
|
|
+ " max_batch_size = batch_size\n",
|
|
|
|
|
+ "\n",
|
|
|
|
|
+ " # Clean up\n",
|
|
|
|
|
+ " del model, input_tensor\n",
|
|
|
|
|
+ " torch.cuda.empty_cache()\n",
|
|
|
|
|
+ "\n",
|
|
|
|
|
+ " except RuntimeError as e:\n",
|
|
|
|
|
+ " if \"out of memory\" in str(e):\n",
|
|
|
|
|
+ " # Try smaller batch size\n",
|
|
|
|
|
+ " max_possible_batch_size = batch_size - 1\n",
|
|
|
|
|
+ "\n",
|
|
|
|
|
+ " # Clean up\n",
|
|
|
|
|
+ " try:\n",
|
|
|
|
|
+ " del model, input_tensor\n",
|
|
|
|
|
+ " torch.cuda.empty_cache()\n",
|
|
|
|
|
+ " except NameError:\n",
|
|
|
|
|
+ " pass\n",
|
|
|
|
|
+ " else:\n",
|
|
|
|
|
+ " raise e"
|
|
|
|
|
+ ]
|
|
|
|
|
+ },
|
|
|
|
|
+ {
|
|
|
|
|
+ "cell_type": "markdown",
|
|
|
|
|
+ "metadata": {},
|
|
|
|
|
+ "source": [
|
|
|
|
|
+ " \n",
|
|
|
|
|
+ "# Benchmark with automatic batch size finding and Model FLOP Utilization (MFU)"
|
|
|
|
|
+ ]
|
|
|
|
|
+ },
|
|
|
|
|
+ {
|
|
|
|
|
+ "cell_type": "markdown",
|
|
|
|
|
+ "metadata": {},
|
|
|
|
|
+ "source": [
|
|
|
|
|
+ "- Model FLOPs Utilization (MFU) explanation from the [PaLM paper](https://arxiv.org/abs/2204.02311)\n",
|
|
|
|
|
+ "\n",
|
|
|
|
|
+ "> We propose a new metric for efficiency that is implementation-independent and permits a cleaner comparison of system efficiency, called model FLOPs utilization (MFU). This is the ratio of the observed throughput (tokens-per-second) relative to the theoretical maximum throughput of a system operating at peak FLOPs. Crucially, the “theoretical maximum” throughput only accounts for the required operations to compute the forward+backward passes, and not rematerialization.\n",
|
|
|
|
|
+ "\n",
|
|
|
|
|
+ "\n",
|
|
|
|
|
+ "$$\\text{MFU} = \\frac{\\text{Observed Tokens per Second}}{\\text{Theoretical Max Tokens per Second}}$$\n",
|
|
|
|
|
+ "\n",
|
|
|
|
|
+ "where \n",
|
|
|
|
|
+ "\n",
|
|
|
|
|
+ "$$\\text{Theoretical Max Tokens per Second} = \\frac{\\text{Max FLOPs per Second}}{\\text{Total FLOPs per Token}}$$\n",
|
|
|
|
|
+ "\n",
|
|
|
|
|
+ "and\n",
|
|
|
|
|
+ "\n",
|
|
|
|
|
+ "$$\\text{Tokens per Second} = \\frac{\\text{Batch Size} \\times \\text{Sequence Length}}{\\text{Total Time}}$$"
|
|
|
|
|
+ ]
|
|
|
|
|
+ },
|
|
|
|
|
+ {
|
|
|
|
|
+ "cell_type": "code",
|
|
|
|
|
+ "execution_count": 5,
|
|
|
|
|
+ "metadata": {},
|
|
|
|
|
+ "outputs": [],
|
|
|
|
|
+ "source": [
|
|
|
|
|
+ "# Max flops per second provided by the GPU manufacturer\n",
|
|
|
|
|
+ "\n",
|
|
|
|
|
+ "flops_per_second = {\n",
|
|
|
|
|
+ " \"H100\": {\n",
|
|
|
|
|
+ " torch.float32: 60e12, # 60 TFLOPs for FP32 on NVIDIA H100\n",
|
|
|
|
|
+ " torch.float16: 1.979e15, # 1979 TFLOPs for FP16 on NVIDIA H100\n",
|
|
|
|
|
+ " torch.bfloat16: 1.979e15\n",
|
|
|
|
|
+ " },\n",
|
|
|
|
|
+ " \"L4\": {\n",
|
|
|
|
|
+ " torch.float32: 15e12, # 15 TFLOPs for FP32 on NVIDIA L4\n",
|
|
|
|
|
+ " torch.float16: 30e12, # 30 TFLOPs for FP16 on NVIDIA L4\n",
|
|
|
|
|
+ " torch.bfloat16: 30e12 \n",
|
|
|
|
|
+ " },\n",
|
|
|
|
|
+ " \"T4\": {\n",
|
|
|
|
|
+ " torch.float32: 8.1e12, # 8.1 TFLOPs for FP32 on NVIDIA T4\n",
|
|
|
|
|
+ " torch.float16: 130e12, # 130 TFLOPs for FP16 on NVIDIA T4\n",
|
|
|
|
|
+ " torch.bfloat16: 130e12\n",
|
|
|
|
|
+ " },\n",
|
|
|
|
|
+ " \"A10G\": {\n",
|
|
|
|
|
+ " torch.float32: 15.6e12, # 15.6 TFLOPs for FP32 on NVIDIA A10G\n",
|
|
|
|
|
+ " torch.float16: 78e12, # 78 TFLOPs for FP16 on NVIDIA A10G\n",
|
|
|
|
|
+ " torch.bfloat16: 78e12\n",
|
|
|
|
|
+ " },\n",
|
|
|
|
|
+ " \"A100\": {\n",
|
|
|
|
|
+ " torch.float32: 19.5e12, # 19.5 TFLOPs for FP32 on NVIDIA A100\n",
|
|
|
|
|
+ " torch.float16: 1.248e15, # 1248 TFLOPs for FP16 on NVIDIA A100\n",
|
|
|
|
|
+ " torch.bfloat16: 1.248e15\n",
|
|
|
|
|
+ " },\n",
|
|
|
|
|
+ " \"H200\": {\n",
|
|
|
|
|
+ " torch.float32: 70e12, # 70 TFLOPs for FP32 on NVIDIA H200\n",
|
|
|
|
|
+ " torch.float16: 1.2e15, # Assuming 1200 TFLOPs for FP16 on NVIDIA H200\n",
|
|
|
|
|
+ " torch.bfloat16: 1.2e15\n",
|
|
|
|
|
+ " },\n",
|
|
|
|
|
+ " \"RTX_3080\": {\n",
|
|
|
|
|
+ " torch.float32: 29.8e12, # 29.8 TFLOPs for FP32 on NVIDIA RTX 3080\n",
|
|
|
|
|
+ " torch.float16: 59.6e12, # 59.6 TFLOPs for FP16 on NVIDIA RTX 3080\n",
|
|
|
|
|
+ " torch.bfloat16: 59.6e12\n",
|
|
|
|
|
+ " },\n",
|
|
|
|
|
+ " \"RTX_3090\": {\n",
|
|
|
|
|
+ " torch.float32: 35.6e12, # 35.6 TFLOPs for FP32 on NVIDIA RTX 3090\n",
|
|
|
|
|
+ " torch.float16: 71.2e12, # 71.2 TFLOPs for FP16 on NVIDIA RTX 3090\n",
|
|
|
|
|
+ " torch.bfloat16: 71.2e12\n",
|
|
|
|
|
+ " },\n",
|
|
|
|
|
+ " \"GTX_1080\": {\n",
|
|
|
|
|
+ " torch.float32: 8.9e12, # 8.9 TFLOPs for FP32 on NVIDIA GTX 1080\n",
|
|
|
|
|
+ " torch.float16: 8.9e12, # No dedicated FP16 performance; using FP32 value\n",
|
|
|
|
|
+ " torch.bfloat16: 8.9e12\n",
|
|
|
|
|
+ " },\n",
|
|
|
|
|
+ " \"GTX_1080Ti\": {\n",
|
|
|
|
|
+ " torch.float32: 11.3e12, # 11.3 TFLOPs for FP32 on NVIDIA GTX 1080Ti\n",
|
|
|
|
|
+ " torch.float16: 11.3e12, # No dedicated FP16 performance; using FP32 value\n",
|
|
|
|
|
+ " torch.bfloat16: 11.3e12\n",
|
|
|
|
|
+ " },\n",
|
|
|
|
|
+ " \"GTX_1660\": {\n",
|
|
|
|
|
+ " torch.float32: 5e12, # 5 TFLOPs for FP32 on NVIDIA GTX 1660\n",
|
|
|
|
|
+ " torch.float16: 5e12, # No dedicated FP16 performance; using FP32 value\n",
|
|
|
|
|
+ " torch.bfloat16: 5e12\n",
|
|
|
|
|
+ " },\n",
|
|
|
|
|
+ " \"GTX_1660Ti\": {\n",
|
|
|
|
|
+ " torch.float32: 5.5e12, # 5.5 TFLOPs for FP32 on NVIDIA GTX 1660Ti\n",
|
|
|
|
|
+ " torch.float16: 5.5e12, # No dedicated FP16 performance; using FP32 value\n",
|
|
|
|
|
+ " torch.bfloat16: 5.5e12\n",
|
|
|
|
|
+ " }\n",
|
|
|
|
|
+ "}\n"
|
|
|
|
|
+ ]
|
|
|
|
|
+ },
|
|
|
|
|
+ {
|
|
|
|
|
+ "cell_type": "code",
|
|
|
|
|
+ "execution_count": 10,
|
|
|
|
|
+ "metadata": {},
|
|
|
|
|
+ "outputs": [
|
|
|
|
|
+ {
|
|
|
|
|
+ "name": "stdout",
|
|
|
|
|
+ "output_type": "stream",
|
|
|
|
|
+ "text": [
|
|
|
|
|
+ "GPU Model: L4\n",
|
|
|
|
|
+ "\n",
|
|
|
|
|
+ "Processing gpt-small (124M)\n",
|
|
|
|
|
+ " Batch size 8: Tokens/sec: 14488.21, MFU: 0.3580\n",
|
|
|
|
|
+ " Batch size 12: Tokens/sec: 15378.16, MFU: 0.3799\n",
|
|
|
|
|
+ "\n",
|
|
|
|
|
+ "Processing gpt-medium (355M)\n",
|
|
|
|
|
+ " Batch size 2: Tokens/sec: 6493.81, MFU: 0.4591\n",
|
|
|
|
|
+ " Batch size 3: Tokens/sec: 6328.82, MFU: 0.4474\n",
|
|
|
|
|
+ "\n",
|
|
|
|
|
+ "Processing gpt-large (774M)\n",
|
|
|
|
|
+ " Batch size 4: Tokens/sec: 3130.38, MFU: 0.4834\n",
|
|
|
|
|
+ "\n",
|
|
|
|
|
+ "Processing gpt-xl (1558M)\n",
|
|
|
|
|
+ " Batch size 2: Tokens/sec: 1896.17, MFU: 0.5897\n"
|
|
|
|
|
+ ]
|
|
|
|
|
+ }
|
|
|
|
|
+ ],
|
|
|
|
|
+ "source": [
|
|
|
|
|
+ "import time\n",
|
|
|
|
|
+ "\n",
|
|
|
|
|
+ "def get_gpu_model(flops_per_second_dict):\n",
|
|
|
|
|
+ " device_name = torch.cuda.get_device_name(0)\n",
|
|
|
|
|
+ " for model in flops_per_second_dict.keys():\n",
|
|
|
|
|
+ " if model in device_name:\n",
|
|
|
|
|
+ " return model\n",
|
|
|
|
|
+ " return \"Unknown\" # Default if no matching model is found\n",
|
|
|
|
|
+ "\n",
|
|
|
|
|
+ "\n",
|
|
|
|
|
+ "gpu_model = get_gpu_model(flops_per_second)\n",
|
|
|
|
|
+ "print(\"GPU Model:\", gpu_model)\n",
|
|
|
|
|
+ "\n",
|
|
|
|
|
+ "if gpu_model != \"Unknown\":\n",
|
|
|
|
|
+ "\n",
|
|
|
|
|
+ " for size in model_configs:\n",
|
|
|
|
|
+ " print(f\"\\nProcessing {size}\")\n",
|
|
|
|
|
+ " config = BASE_CONFIG.copy()\n",
|
|
|
|
|
+ " config.update(model_configs[size])\n",
|
|
|
|
|
+ "\n",
|
|
|
|
|
+ " min_batch_size = 1\n",
|
|
|
|
|
+ " max_batch_size = None\n",
|
|
|
|
|
+ " max_possible_batch_size = 4096\n",
|
|
|
|
|
+ "\n",
|
|
|
|
|
+ " while min_batch_size <= max_possible_batch_size:\n",
|
|
|
|
|
+ " batch_size = (min_batch_size + max_possible_batch_size) // 2\n",
|
|
|
|
|
+ " try:\n",
|
|
|
|
|
+ " input_tensor = torch.randint(\n",
|
|
|
|
|
+ " 0, config[\"vocab_size\"],\n",
|
|
|
|
|
+ " (batch_size, config[\"context_length\"]),\n",
|
|
|
|
|
+ " device=device\n",
|
|
|
|
|
+ " )\n",
|
|
|
|
|
+ "\n",
|
|
|
|
|
+ " model = GPTModel(config).bfloat16().to(device)\n",
|
|
|
|
|
+ " model.train()\n",
|
|
|
|
|
+ "\n",
|
|
|
|
|
+ " # Start timing\n",
|
|
|
|
|
+ " torch.cuda.synchronize()\n",
|
|
|
|
|
+ " start_time = time.time()\n",
|
|
|
|
|
+ "\n",
|
|
|
|
|
+ " # Forward & backward pass\n",
|
|
|
|
|
+ " output = model(input_tensor)\n",
|
|
|
|
|
+ " loss = output.sum() # Compute a dummy loss \n",
|
|
|
|
|
+ " loss.backward()\n",
|
|
|
|
|
+ "\n",
|
|
|
|
|
+ " # End timing\n",
|
|
|
|
|
+ " torch.cuda.synchronize()\n",
|
|
|
|
|
+ " end_time = time.time()\n",
|
|
|
|
|
+ "\n",
|
|
|
|
|
+ " total_time_seconds = end_time - start_time\n",
|
|
|
|
|
+ "\n",
|
|
|
|
|
+ " # Calculate FLOPs for forward pass\n",
|
|
|
|
|
+ " macs, params = profile(model, inputs=(input_tensor,), verbose=False)\n",
|
|
|
|
|
+ " flops_forward = 2 * macs # Assuming one MAC equals two FLOPs\n",
|
|
|
|
|
+ "\n",
|
|
|
|
|
+ " # Estimate FLOPs for backward pass (typically 2x forward FLOPs)\n",
|
|
|
|
|
+ " flops_backward = 2 * flops_forward\n",
|
|
|
|
|
+ "\n",
|
|
|
|
|
+ " # Total FLOPs for forward + backward passes\n",
|
|
|
|
|
+ " total_flops = flops_forward + flops_backward # Or total_flops = flops_forward * 3\n",
|
|
|
|
|
+ "\n",
|
|
|
|
|
+ " data_type = next(model.parameters()).dtype\n",
|
|
|
|
|
+ " max_flops_per_second = flops_per_second[gpu_model].get(data_type, 0)\n",
|
|
|
|
|
+ "\n",
|
|
|
|
|
+ " # Compute tokens per second\n",
|
|
|
|
|
+ " tokens_processed = batch_size * config[\"context_length\"]\n",
|
|
|
|
|
+ " tokens_per_second = tokens_processed / total_time_seconds\n",
|
|
|
|
|
+ "\n",
|
|
|
|
|
+ " # Compute FLOPs per token\n",
|
|
|
|
|
+ " flops_per_token = total_flops / tokens_processed\n",
|
|
|
|
|
+ "\n",
|
|
|
|
|
+ " # Compute theoretical max tokens per second\n",
|
|
|
|
|
+ " if flops_per_token > 0:\n",
|
|
|
|
|
+ " theoretical_max_tokens_per_second = max_flops_per_second / flops_per_token\n",
|
|
|
|
|
+ " else:\n",
|
|
|
|
|
+ " theoretical_max_tokens_per_second = 0 # Avoid division by zero\n",
|
|
|
|
|
+ "\n",
|
|
|
|
|
+ " # Compute MFU\n",
|
|
|
|
|
+ " if theoretical_max_tokens_per_second > 0:\n",
|
|
|
|
|
+ " mfu = tokens_per_second / theoretical_max_tokens_per_second\n",
|
|
|
|
|
+ " else:\n",
|
|
|
|
|
+ " mfu = 0 # Avoid division by zero\n",
|
|
|
|
|
+ "\n",
|
|
|
|
|
+ " print(f\" Batch size {batch_size}: Tokens/sec: {tokens_per_second:.2f}, MFU: {mfu:.4f}\")\n",
|
|
|
|
|
+ "\n",
|
|
|
|
|
+ " # If successful, try a larger batch size\n",
|
|
|
|
|
+ " min_batch_size = batch_size + 1\n",
|
|
|
|
|
+ " max_batch_size = batch_size\n",
|
|
|
|
|
+ "\n",
|
|
|
|
|
+ " # Clean up\n",
|
|
|
|
|
+ " del model, input_tensor, output, loss\n",
|
|
|
|
|
+ " torch.cuda.empty_cache()\n",
|
|
|
|
|
+ "\n",
|
|
|
|
|
+ " except RuntimeError as e:\n",
|
|
|
|
|
+ " if \"out of memory\" in str(e).lower():\n",
|
|
|
|
|
+ " # Try smaller batch size\n",
|
|
|
|
|
+ " max_possible_batch_size = batch_size - 1\n",
|
|
|
|
|
+ "\n",
|
|
|
|
|
+ " # Clean up\n",
|
|
|
|
|
+ " try:\n",
|
|
|
|
|
+ " del model, input_tensor\n",
|
|
|
|
|
+ " torch.cuda.empty_cache()\n",
|
|
|
|
|
+ " except NameError:\n",
|
|
|
|
|
+ " pass\n",
|
|
|
|
|
+ " else:\n",
|
|
|
|
|
+ " raise e\n",
|
|
|
|
|
+ "\n",
|
|
|
|
|
+ "else:\n",
|
|
|
|
|
+ " print(\"Unknown GPU model. Please update the flops_per_second dictionary with your GPU information.\")"
|
|
|
|
|
+ ]
|
|
|
|
|
+ },
|
|
|
|
|
+ {
|
|
|
|
|
+ "cell_type": "markdown",
|
|
|
|
|
+ "metadata": {},
|
|
|
|
|
+ "source": [
|
|
|
|
|
+ "- Note that the batch sizes are smaller than previously because we also carry out the backward pass here, which is more memory-intensive"
|
|
|
|
|
+ ]
|
|
|
}
|
|
}
|
|
|
],
|
|
],
|
|
|
"metadata": {
|
|
"metadata": {
|