@@ -27,7 +27,7 @@ class MultiHeadAttention(nn.Module):
         self.dropout = nn.Dropout(dropout)
         self.register_buffer(
             "mask",
-            torch.triu(torch.ones(context_length, context_length),diagonal=1),
+            torch.triu(torch.ones(context_length, context_length), diagonal=1),
             persistent=False
         )
@@ -35,6 +35,7 @@ class MultiHeadAttention(nn.Module):
         # NEW
         self.register_buffer("cache_k", None, persistent=False)
         self.register_buffer("cache_v", None, persistent=False)
+        self.ptr_current_pos = 0
         ####################################################

     def forward(self, x, use_cache=False):
@@ -71,8 +72,19 @@ class MultiHeadAttention(nn.Module):
         # Compute scaled dot-product attention (aka self-attention) with a causal mask
         attn_scores = queries @ keys.transpose(2, 3)  # Dot product for each head

+        ####################################################
+        # NEW
+        num_tokens_Q = queries.shape[-2]
+        num_tokens_K = keys.shape[-2]
+        if use_cache:
+            mask_bool = self.mask.bool()[
+                self.ptr_current_pos:self.ptr_current_pos + num_tokens_Q, :num_tokens_K
+            ]
+            self.ptr_current_pos += num_tokens_Q
+        ####################################################
         # Original mask truncated to the number of tokens and converted to boolean
-        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]
+        else:
+            mask_bool = self.mask.bool()[:num_tokens_Q, :num_tokens_K]

         # Use the mask to fill attention scores
         attn_scores.masked_fill_(mask_bool, -torch.inf)
@@ -93,6 +105,7 @@ class MultiHeadAttention(nn.Module):
     # NEW
     def reset_cache(self):
         self.cache_k, self.cache_v = None, None
+        self.ptr_current_pos = 0
     ####################################################
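The hunks above add a running position pointer, ptr_current_pos, so that when the KV cache is active and only the newly generated token is passed in, the matching row(s) of the precomputed causal mask can still be selected; reset_cache() rewinds the pointer together with the cache. The following is a minimal standalone sketch (toy sizes, not the repository's class) of how that sliding slice is assumed to behave:

import torch

context_length = 8
mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)

ptr_current_pos = 0
for num_tokens_Q in (4, 1, 1):  # a 4-token prompt, then two single-token decode steps
    num_tokens_K = ptr_current_pos + num_tokens_Q  # cache length after appending the new keys
    mask_bool = mask.bool()[
        ptr_current_pos:ptr_current_pos + num_tokens_Q, :num_tokens_K
    ]
    print(mask_bool.shape)  # torch.Size([4, 4]), then [1, 5], then [1, 6]
    ptr_current_pos += num_tokens_Q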
@@ -264,30 +277,29 @@ def generate_text_simple(model, idx, max_new_tokens, context_size):

 ####################################################
 # NEW
-def generate_text_simple_cached(model, idx, max_new_tokens, use_cache=True):
+def generate_text_simple_cached(model, idx, max_new_tokens,
+                                context_size=None, use_cache=True):
     model.eval()
+    ctx_len = context_size or model.pos_emb.num_embeddings

-    ctx_len = model.pos_emb.num_embeddings  # max supported length, e.g. 1024
-    if use_cache:
-        # Init cache with full prompt
-        model.reset_kv_cache()
-        with torch.no_grad():
+    with torch.no_grad():
+        if use_cache:
+            # Init cache with full prompt
+            model.reset_kv_cache()
             logits = model(idx[:, -ctx_len:], use_cache=True)

-        for _ in range(max_new_tokens):
-            # a) pick the token with the highest log-probability (greedy sampling)
-            next_idx = logits[:, -1].argmax(dim=-1, keepdim=True)
-            # b) append it to the running sequence
-            idx = torch.cat([idx, next_idx], dim=1)
-            # c) feed model only the new token
-            with torch.no_grad():
+            for _ in range(max_new_tokens):
+                # a) pick the token with the highest log-probability (greedy sampling)
+                next_idx = logits[:, -1].argmax(dim=-1, keepdim=True)
+                # b) append it to the running sequence
+                idx = torch.cat([idx, next_idx], dim=1)
+                # c) feed model only the new token
                 logits = model(next_idx, use_cache=True)
-    else:
-        for _ in range(max_new_tokens):
-            with torch.no_grad():
+        else:
+            for _ in range(max_new_tokens):
                 logits = model(idx[:, -ctx_len:], use_cache=False)
-            next_idx = logits[:, -1].argmax(dim=-1, keepdim=True)
-            idx = torch.cat([idx, next_idx], dim=1)
+                next_idx = logits[:, -1].argmax(dim=-1, keepdim=True)
+                idx = torch.cat([idx, next_idx], dim=1)

     return idx
 ####################################################
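For reference, a hypothetical usage sketch of the updated generate_text_simple_cached is given below; GPTModel, GPT_CONFIG_124M, and the tiktoken tokenizer are assumed to come from the surrounding project and are not redefined here:

import tiktoken
import torch

tokenizer = tiktoken.get_encoding("gpt2")
model = GPTModel(GPT_CONFIG_124M)  # assumed to be defined alongside the class above

idx = torch.tensor(tokenizer.encode("Hello, I am"), dtype=torch.long).unsqueeze(0)

token_ids = generate_text_simple_cached(
    model,
    idx,
    max_new_tokens=30,
    # context_size is optional; when omitted it falls back to model.pos_emb.num_embeddings
    context_size=GPT_CONFIG_124M["context_length"],
)
print(tokenizer.decode(token_ids.squeeze(0).tolist()))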