generate.py

# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
# Source for "Build a Large Language Model From Scratch"
# - https://www.manning.com/books/build-a-large-language-model-from-scratch
# Code: https://github.com/rasbt/LLMs-from-scratch

import torch

from .utils import KVCache


def generate_text_simple(model, idx, max_new_tokens, context_size=None, use_cache=True):
    """Greedily generate max_new_tokens tokens from the prompt idx.

    idx is a (batch_size, seq_len) tensor of token IDs; generated tokens are
    appended along dim 1 and the extended tensor is returned.
    """
    model.eval()  # disable dropout etc. for deterministic inference
    ctx_len = context_size or model.cfg["context_length"]
    batch_size = idx.size(0)

    with torch.no_grad():
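        # Two decoding paths: with a KV cache, the prompt is encoded once and
        # each later step feeds only the newly sampled token; without it, the
        # full (context-truncated) sequence is re-encoded at every step.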
        if use_cache:
            # Initialize the cache and the model's position bookkeeping
            cache = KVCache(n_layers=model.cfg["n_layers"], batch_size=batch_size)
            model.reset_kv_cache(batch_size=batch_size, device=idx.device)

            # Initial full-context pass over the (truncated) prompt
            input_ids = idx[:, -ctx_len:]
            seq_len = input_ids.size(1)
            start_pos = model.current_pos.clone()
            logits = model(input_ids, cache=cache, start_pos=start_pos)
            model.current_pos += seq_len
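
            # model.current_pos tracks how many positions have been consumed;
            # it is cloned before each forward pass so the in-place "+=" update
            # cannot alias the start offset handed to the model.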

            # Iterative generation: feed only the newest token each step
            for _ in range(max_new_tokens):
                next_token = logits[:, -1].argmax(dim=-1, keepdim=True)  # (batch_size, 1)
                logits = model(next_token, cache=cache, start_pos=model.current_pos.clone())
                model.current_pos += 1
                idx = torch.cat([idx, next_token], dim=1)
        else:
            # No cache: re-encode the whole (truncated) sequence every step
            for _ in range(max_new_tokens):
                input_ids = idx[:, -ctx_len:]
                logits = model(input_ids, cache=None, start_pos=None)
                next_token = logits[:, -1].argmax(dim=-1, keepdim=True)
                idx = torch.cat([idx, next_token], dim=1)

    return idx
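

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original file): a minimal smoke test of the
# uncached path. The toy model below only mimics the interface that
# generate_text_simple expects (a `cfg` dict and a forward pass accepting
# `cache` and `start_pos`); the class, vocabulary size, and config values are
# illustrative, not the book's actual model. Because of the relative import
# above, run this as a module inside its package (python -m ...), not as a
# standalone script.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    class _ToyModel(torch.nn.Module):
        """Stand-in model returning random logits over a 16-token vocabulary."""

        def __init__(self):
            super().__init__()
            self.cfg = {"context_length": 8, "n_layers": 1}

        def forward(self, input_ids, cache=None, start_pos=None):
            b, t = input_ids.shape
            return torch.randn(b, t, 16)  # (batch_size, seq_len, vocab_size)

    torch.manual_seed(123)
    prompt = torch.tensor([[1, 2, 3]])  # (batch_size=1, seq_len=3)
    out = generate_text_simple(_ToyModel(), prompt, max_new_tokens=4, use_cache=False)
    print(out.shape)  # torch.Size([1, 7]): 3 prompt tokens + 4 generated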