Ver código fonte

Llama3 from scratch improvements (#621)

* Llama3 from scratch improvements

* restore
Sebastian Raschka 7 meses atrás
pai
commit
47c036058d
1 arquivos alterados com 38 adições e 11 exclusões
  1. 38 11
      ch05/07_gpt_to_llama/README.md

+ 38 - 11
ch05/07_gpt_to_llama/README.md

@@ -17,13 +17,16 @@ This folder contains code for converting the GPT implementation from chapter 4 a
 For an easy way to use the Llama 3.2 1B and 3B models, you can also use the `llms-from-scratch` PyPI package based on the source code in this repository at [pkg/llms_from_scratch](../../pkg/llms_from_scratch).
 
  
-##### 1) Installation
+#### 1) Installation
 
 ```bash
 pip install llms_from_scratch blobfile
 ```
+
+(Note that `blobfile` is needed to load the tokenizer.)
+
  
-##### 2) Model and text generation settings
+#### 2) Model and text generation settings
 
 Specify which model to use:
 
@@ -51,7 +54,7 @@ TOP_K = 1
 ```
 
  
-##### 3) Weight download and loading
+#### 3) Weight download and loading
 
 This automatically downloads the weight file based on the model choice above:
 
@@ -82,7 +85,7 @@ else:
 LLAMA32_CONFIG["context_length"] = MODEL_CONTEXT_LENGTH
 
 model = Llama3Model(LLAMA32_CONFIG)
-model.load_state_dict(torch.load(MODEL_FILE, weights_only=True))
+model.load_state_dict(torch.load(MODEL_FILE, weights_only=True, map_location="cpu"))
 
 device = (
     torch.device("cuda") if torch.cuda.is_available() else
@@ -93,7 +96,7 @@ model.to(device)
 ```
 
  
-##### 4) Initialize tokenizer
+#### 4) Initialize tokenizer
 
 The following code downloads and initializes the tokenizer:
 
@@ -115,14 +118,14 @@ if "instruct" in MODEL_FILE:
 ```
 
  
-##### 5) Generating text
+#### 5) Generating text
 
 Lastly, we can generate text via the following code:
 
 ```python
 import time
 
-from llms_from_scratch.ch05 import (
+from ch05 import (
     generate,
     text_to_token_ids,
     token_ids_to_text
@@ -141,7 +144,9 @@ token_ids = generate(
     temperature=TEMPERATURE
 )
 
-print(f"Time: {time.time() - start:.2f} sec")
+total_time = time.time() - start
+print(f"Time: {total_time:.2f} sec")
+print(f"{int(len(token_ids[0])/total_time)} tokens/sec")
 
 if torch.cuda.is_available():
     max_mem_bytes = torch.cuda.max_memory_allocated()
@@ -159,7 +164,8 @@ print("\n\nOutput text:\n\n", output_text)
 When using the Llama 3.2 1B Instruct model, the output should look similar to the one shown below:
 
 ```
-Time: 4.12 sec
+Time: 3.17 sec
+50 tokens/sec
 Max memory allocated: 2.91 GB
 
 
@@ -176,7 +182,22 @@ It's worth noting that the specific diet of llamas can vary depending on factors
 ```
 
  
-**Pro tip**
+#### Pro tip 1: speed up inference with FlashAttention
+
+Instead of using `Llama3Model`, you can use `Llama3ModelFast` as a drop-in replacement. For more information, I encourage you to inspect the [pkg/llms_from_scratch/llama3.py](../../pkg/llms_from_scratch/llama3.py) code.
+
+The `Llama3ModelFast` replaces my from-scratch scaled dot-product code in the `GroupedQueryAttention` module with PyTorch's `scaled_dot_product` function, which uses `FlashAttention` on Ampere GPUs or newer.
+
+The following table shows a performance comparison on an A100:
+
+|                 | Tokens/sec | Memory  |
+| --------------- | ---------- | ------- |
+| Llama3Model     | 50         | 2.91 GB |
+| Llama3ModelFast | 58         | 2.85 GB |
+
+ 
+#### Pro tip 2: speed up inference with compilation
+
 
 For up to a 4× speed-up, replace
 
@@ -191,5 +212,11 @@ model = torch.compile(model)
 model.to(device)
 ```
 
-Note: the speed-up takes effect after the first `generate` call.
+Note: There is a significant multi-minute upfront cost when compiling, and the speed-up takes effect after the first `generate` call. 
+
+The following table shows a performance comparison on an A100 for consequent `generate` calls:
 
+|                 | Tokens/sec | Memory  |
+| --------------- | ---------- | ------- |
+| Llama3Model     | 156        | 3.12 GB |
+| Llama3ModelFast | 159        | 2.84 GB |