ai-edge-torch-nightly 0.3.0.dev20240910__py3-none-any.whl → 0.3.0.dev20240911__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33)
  1. ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +31 -12
  2. ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +25 -6
  3. ai_edge_torch/generative/examples/gemma/gemma.py +34 -18
  4. ai_edge_torch/generative/examples/gemma/gemma2.py +38 -17
  5. ai_edge_torch/generative/examples/{experimental/phi → phi}/convert_to_tflite.py +11 -12
  6. ai_edge_torch/generative/examples/{experimental/phi → phi}/phi2.py +31 -33
  7. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +58 -25
  8. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +25 -6
  9. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +38 -22
  10. ai_edge_torch/generative/layers/attention.py +60 -63
  11. ai_edge_torch/generative/layers/kv_cache.py +160 -51
  12. ai_edge_torch/generative/test/{test_experimental_ekv.py → test_kv_cache.py} +8 -22
  13. ai_edge_torch/generative/test/test_model_conversion.py +71 -33
  14. ai_edge_torch/generative/test/test_model_conversion_large.py +51 -23
  15. ai_edge_torch/generative/test/utils.py +54 -0
  16. ai_edge_torch/version.py +1 -1
  17. {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240911.dist-info}/METADATA +1 -1
  18. {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240911.dist-info}/RECORD +22 -32
  19. ai_edge_torch/generative/examples/experimental/gemma/__init__.py +0 -14
  20. ai_edge_torch/generative/examples/experimental/gemma/convert_to_tflite.py +0 -88
  21. ai_edge_torch/generative/examples/experimental/gemma/gemma.py +0 -219
  22. ai_edge_torch/generative/examples/experimental/phi/__init__.py +0 -14
  23. ai_edge_torch/generative/examples/experimental/tiny_llama/__init__.py +0 -14
  24. ai_edge_torch/generative/examples/experimental/tiny_llama/convert_to_tflite.py +0 -87
  25. ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py +0 -205
  26. ai_edge_torch/generative/examples/phi2/__init__.py +0 -14
  27. ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +0 -67
  28. ai_edge_torch/generative/examples/phi2/phi2.py +0 -189
  29. ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +0 -176
  30. /ai_edge_torch/generative/examples/{experimental → phi}/__init__.py +0 -0
  31. {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240911.dist-info}/LICENSE +0 -0
  32. {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240911.dist-info}/WHEEL +0 -0
  33. {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240911.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py (+58 -25)

@@ -12,14 +12,17 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.
  # ==============================================================================
- # A toy example which has basic transformer block (w/ KV-Cache).
+
+ """A toy example which has basic transformer block (w/ externalized KV-Cache)."""
+
  from typing import Tuple

  import ai_edge_torch
  from ai_edge_torch import lowertools
- from ai_edge_torch.generative.layers.attention import TransformerBlock
+ from ai_edge_torch.generative.layers import attention
+ from ai_edge_torch.generative.layers import builder
+ from ai_edge_torch.generative.layers import kv_cache as kv_utils
  import ai_edge_torch.generative.layers.attention_utils as attn_utils
- import ai_edge_torch.generative.layers.builder as builder
  import ai_edge_torch.generative.layers.model_config as cfg
  import torch
  import torch.nn as nn
@@ -27,7 +30,7 @@ import torch.nn as nn
  RoPECache = Tuple[torch.Tensor, torch.Tensor]


- class ToyModelWithKV(torch.nn.Module):
+ class ToyModelWithKVCache(torch.nn.Module):

  def __init__(self, config: cfg.ModelConfig) -> None:
  super().__init__()
@@ -36,7 +39,7 @@ class ToyModelWithKV(torch.nn.Module):
  )
  self.tok_embedding = nn.Embedding(config.vocab_size, config.embedding_dim)
  self.transformer_blocks = nn.ModuleList(
- TransformerBlock(config) for _ in range(config.num_layers)
+ attention.TransformerBlock(config) for _ in range(config.num_layers)
  )
  self.final_norm = builder.build_norm(
  config.embedding_dim,
@@ -57,18 +60,29 @@ class ToyModelWithKV(torch.nn.Module):
  )
  self.config = config

- @torch.inference_mode
- def forward(self, idx: torch.Tensor, input_pos: torch.Tensor) -> torch.Tensor:
- x = self.tok_embedding(idx)
+ def forward(
+ self,
+ tokens: torch.Tensor,
+ input_pos: torch.Tensor,
+ kv_cache: kv_utils.KVCache,
+ ) -> Tuple[torch.Tensor, kv_utils.KVCache]:
+ x = self.tok_embedding(tokens)
  cos, sin = self.rope_cache
  cos = cos.index_select(0, input_pos)
  sin = sin.index_select(0, input_pos)
  mask = self.mask_cache.index_select(2, input_pos)
  mask = mask[:, :, :, : self.config.max_seq_len]
+
+ updated_kv_entires = []
  for i, block in enumerate(self.transformer_blocks):
- x = block(x, (cos, sin), mask, input_pos)
+ kv_entry = kv_cache.caches[i] if kv_cache else None
+ x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
+ if kv_entry:
+ updated_kv_entires.append(kv_entry)
+
  x = self.final_norm(x)
- return self.lm_head(x)
+ updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
+ return {'logits': self.lm_head(x), 'kv_cache': updated_kv_cache}


  def _export_stablehlo_mlir(model, args):
@@ -89,7 +103,7 @@ def get_model_config() -> cfg.ModelConfig:
  config = cfg.ModelConfig(
  vocab_size=150,
  num_layers=2,
- max_seq_len=500,
+ max_seq_len=100,
  embedding_dim=128,
  attn_config=attn_config,
  ff_config=ff_config,
@@ -102,40 +116,59 @@ def get_model_config() -> cfg.ModelConfig:


  def get_sample_prefill_inputs() -> Tuple[torch.Tensor, torch.Tensor]:
- idx = torch.unsqueeze(torch.arange(0, 100), 0)
+ tokens = torch.unsqueeze(torch.arange(0, 100), 0)
  input_pos = torch.arange(0, 100)
- return idx, input_pos
+ return tokens, input_pos


  def get_sample_decode_inputs() -> Tuple[torch.Tensor, torch.Tensor]:
- idx = torch.tensor([[1]], dtype=torch.long)
- input_pos = torch.tensor([10], dtype=torch.int64)
- return idx, input_pos
+ tokens = torch.tensor([[1]], dtype=torch.long)
+ input_pos = torch.tensor([10])
+ return tokens, input_pos


  def define_and_run() -> None:
  dump_mlir = False

  config = get_model_config()
- model = ToyModelWithKV(config)
+ model = ToyModelWithExternalKV(config)
+ model.eval()
  print('running an inference')
- idx, input_pos = get_sample_prefill_inputs()
- decode_idx, decode_input_pos = get_sample_decode_inputs()
- print(model.forward(idx, input_pos))
+ kv = kv_utils.KVCache.from_model_config(config)
+
+ tokens, input_pos = get_sample_prefill_inputs()
+ decode_token, decode_input_pos = get_sample_decode_inputs()
+ print(model.forward(tokens, input_pos, kv))

  if dump_mlir:
- mlir_text = _export_stablehlo_mlir(model, (idx, input_pos))
- with open('/tmp/toy_model_with_kv.stablehlo.mlir', 'w') as f:
+ mlir_text = _export_stablehlo_mlir(model, (tokens, input_pos, kv))
+ with open('/tmp/toy_model_with_external_kv.stablehlo.mlir', 'w') as f:
  f.write(mlir_text)

  # Convert model to tflite with 2 signatures (prefill + decode).
  print('converting toy model to tflite with 2 signatures (prefill + decode)')
  edge_model = (
- ai_edge_torch.signature('prefill', model, (idx, input_pos))
- .signature('decode', model, (decode_idx, decode_input_pos))
+ ai_edge_torch.signature(
+ 'prefill',
+ model,
+ sample_kwargs={
+ 'tokens': tokens,
+ 'input_pos': input_pos,
+ 'kv_cache': kv,
+ },
+ )
+ .signature(
+ 'decode',
+ model,
+ sample_kwargs={
+ 'tokens': decode_token,
+ 'input_pos': decode_input_pos,
+ 'kv_cache': kv,
+ },
+ )
  .convert()
  )
- edge_model.export('/tmp/toy_kv_cache.tflite')
+ edge_model.export('/tmp/toy_external_kv_cache.tflite')


  if __name__ == '__main__':
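Taken together, these hunks invert the toy model's contract: the KV cache is no longer owned by the attention layers but is built by the caller, passed into forward, and returned (updated) next to the logits. A rough sketch of that round trip, using only names that appear in the hunks above; treat it as an illustration of the new calling convention, not as an excerpt from the file:

    import torch

    from ai_edge_torch.generative.examples.test_models import toy_model_with_kv_cache as toy
    from ai_edge_torch.generative.layers import kv_cache as kv_utils

    config = toy.get_model_config()
    model = toy.ToyModelWithKVCache(config)
    model.eval()

    # The cache is created outside the model, from the same config.
    kv = kv_utils.KVCache.from_model_config(config)

    # Prefill: the updated cache comes back in the output dict.
    tokens, input_pos = toy.get_sample_prefill_inputs()
    output = model.forward(tokens, input_pos, kv)

    # Decode: thread the returned cache into the next call.
    decode_token, decode_pos = toy.get_sample_decode_inputs()
    output = model.forward(decode_token, decode_pos, output['kv_cache'])
    print(output['logits'].shape)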
ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py (+25 -6)

@@ -13,11 +13,14 @@
  # limitations under the License.
  # ==============================================================================

+ """Example of converting TinyLlama model to multi-signature tflite model."""
+
  import os
- from pathlib import Path
+ import pathlib

  import ai_edge_torch
  from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
+ from ai_edge_torch.generative.layers import kv_cache as kv_utils
  from ai_edge_torch.generative.quantize import quant_recipes
  import torch

@@ -48,20 +51,36 @@ def convert_tiny_llama_to_tflite(
  prefill_input_pos = torch.arange(0, prefill_seq_len)
  decode_token = torch.tensor([[0]], dtype=torch.long)
  decode_input_pos = torch.tensor([0], dtype=torch.int64)
+ kv = kv_utils.KVCache.from_model_config(pytorch_model.config)

  quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
  edge_model = (
  ai_edge_torch.signature(
- 'prefill', pytorch_model, (prefill_tokens, prefill_input_pos)
+ 'prefill',
+ pytorch_model,
+ sample_kwargs={
+ 'tokens': prefill_tokens,
+ 'input_pos': prefill_input_pos,
+ 'kv_cache': kv,
+ },
+ )
+ .signature(
+ 'decode',
+ pytorch_model,
+ sample_kwargs={
+ 'tokens': decode_token,
+ 'input_pos': decode_input_pos,
+ 'kv_cache': kv,
+ },
  )
- .signature('decode', pytorch_model, (decode_token, decode_input_pos))
  .convert(quant_config=quant_config)
  )
+ quant_suffix = 'q8' if quantize else 'f32'
  edge_model.export(
- f'/tmp/tiny_llama_seq{prefill_seq_len}_kv{kv_cache_max_len}.tflite'
+ f'/tmp/tiny_llama_{quant_suffix}_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
  )


  if __name__ == '__main__':
- checkpoint_path = os.path.join(Path.home(), 'Downloads/llm_data/tiny_llama')
- convert_tiny_llama_to_tflite(checkpoint_path)
+ path = os.path.join(pathlib.Path.home(), 'Downloads/llm_data/tiny_llama')
+ convert_tiny_llama_to_tflite(path)
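For reference, a typical invocation of the rewritten converter. prefill_seq_len, kv_cache_max_len and quantize are referenced inside the hunk but their defaults are not shown there, so the keyword arguments and values below are assumptions, not an excerpt:

    import os
    import pathlib

    from ai_edge_torch.generative.examples.tiny_llama import convert_to_tflite

    checkpoint = os.path.join(pathlib.Path.home(), 'Downloads/llm_data/tiny_llama')

    # With these (assumed) values, the naming scheme introduced above would
    # produce /tmp/tiny_llama_q8_seq512_ekv1024.tflite.
    convert_to_tflite.convert_tiny_llama_to_tflite(
        checkpoint,
        prefill_seq_len=512,
        kv_cache_max_len=1024,
        quantize=True,
    )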
ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py (+38 -22)

@@ -12,13 +12,15 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.
  # ==============================================================================
- # Example of building a TinyLlama model from the Edge Generative API layers.
+
+ """Example of building a TinyLlama model."""

  import os
- from pathlib import Path
+ import pathlib

  from ai_edge_torch.generative.layers import attention
  from ai_edge_torch.generative.layers import builder
+ from ai_edge_torch.generative.layers import kv_cache as kv_utils
  import ai_edge_torch.generative.layers.attention_utils as attn_utils
  import ai_edge_torch.generative.layers.model_config as cfg
  import ai_edge_torch.generative.utilities.loader as loading_utils
@@ -80,16 +82,22 @@ class TinyLLamma(nn.Module):
  )
  self.config = config

- # The model's forward function takes in additional k/v cache tensors
- # and returns the updated k/v cache tensors to the caller.
- # This can be eliminated if we handle k/v cache updates inside the model itself.
  @torch.inference_mode
- def forward(self, idx: torch.Tensor, input_pos: torch.Tensor) -> torch.Tensor:
- _, seq_len = idx.size()
+ def forward(
+ self,
+ tokens: torch.Tensor,
+ input_pos: torch.Tensor,
+ kv_cache: kv_utils.KVCache,
+ ) -> dict[torch.Tensor, kv_utils.KVCache]:
+ _, seq_len = tokens.size()
  assert self.config.max_seq_len >= seq_len, (
  f"Cannot forward sequence of length {seq_len}, max seq length is only"
  f" {self.config.max_seq_len}"
  )
+ assert len(self.transformer_blocks) == len(kv_cache.caches), (
+ "The number of transformer blocks and the number of KV cache entries"
+ " must be the same."
+ )

  cos, sin = self.rope_cache
  cos = cos.index_select(0, input_pos)
@@ -97,16 +105,20 @@ class TinyLLamma(nn.Module):
  mask = self.mask_cache.index_select(2, input_pos)
  mask = mask[:, :, :, : self.config.kv_cache_max]

- # forward the model itself
- x = self.tok_embedding(idx) # token embeddings of shape (b, t, n_embd)
+ # token embeddings of shape (b, t, n_embd)
+ x = self.tok_embedding(tokens)

- for _, block in enumerate(self.transformer_blocks):
- x = block(x, (cos, sin), mask, input_pos)
+ updated_kv_entires = []
+ for i, block in enumerate(self.transformer_blocks):
+ kv_entry = kv_cache.caches[i] if kv_cache else None
+ x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
+ if kv_entry:
+ updated_kv_entires.append(kv_entry)
+ updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))

  x = self.final_norm(x)
-
- res = self.lm_head(x) # (b, t, vocab_size)
- return res
+ logits = self.lm_head(x) # (b, t, vocab_size)
+ return {"logits": logits, "kv_cache": updated_kv_cache}


  def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
@@ -147,8 +159,8 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
  return config


- def get_fake_model_config() -> cfg.ModelConfig:
- config = get_model_config()
+ def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
+ config = get_model_config(**kwargs)
  config.vocab_size = 128
  config.num_layers = 2
  config.ff_config.intermediate_size = 64
@@ -160,26 +172,30 @@ def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
  model = TinyLLamma(config)
  loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
  loader.load(model)
+ model.eval()
  return model


- def define_and_run() -> None:
+ def define_and_run(checkpoint_path: str) -> None:
  """Instantiates and runs a TinyLlama model."""

- current_dir = Path(__file__).parent.resolve()
+ current_dir = pathlib.Path(__file__).parent.resolve()
  tiny_llama_goldens = torch.load(current_dir / "tiny_llama_lm_logits.pt")
  kv_cache_max_len = 1024
- checkpoint_path = os.path.join(Path.home(), "Downloads/llm_data/tiny_llama")
  model = build_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
  idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
  tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.long, device="cpu")
  tokens[0, :4] = idx
  input_pos = torch.arange(0, kv_cache_max_len)
- lm_logits = model.forward(tokens, input_pos)
+ kv = kv_utils.KVCache.from_model_config(model.config)
+ output = model.forward(tokens, input_pos, kv)
  assert torch.allclose(
- tiny_llama_goldens, lm_logits[0, idx.shape[1] - 1, :], atol=1e-05
+ tiny_llama_goldens, output["logits"][0, idx.shape[1] - 1, :], atol=1e-02
  )


  if __name__ == "__main__":
- define_and_run()
+ input_checkpoint_path = os.path.join(
+ pathlib.Path.home(), "Downloads/llm_data/tiny_llama"
+ )
+ define_and_run(input_checkpoint_path)
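The dict return value is what makes a stateless decode loop possible: each step consumes the cache produced by the previous step instead of relying on buffers hidden inside the attention modules. A hedged sketch of such a loop on top of the forward shown above, using the small test config from this file (greedy sampling here is only for illustration):

    import torch

    from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
    from ai_edge_torch.generative.layers import kv_cache as kv_utils

    model = tiny_llama.TinyLLamma(tiny_llama.get_fake_model_config())
    model.eval()

    prompt = torch.tensor([[1, 2, 3, 4]], dtype=torch.long)
    kv = kv_utils.KVCache.from_model_config(model.config)

    # Prefill the whole prompt once; keep the logits of its last position.
    out = model.forward(prompt, torch.arange(0, prompt.shape[1]), kv)
    kv = out["kv_cache"]
    next_token = out["logits"][:, -1, :].argmax(dim=-1, keepdim=True)

    # Decode a few tokens, threading the returned cache back in each time.
    for step in range(4):
        pos = torch.tensor([prompt.shape[1] + step])
        out = model.forward(next_token, pos, kv)
        kv = out["kv_cache"]
        next_token = out["logits"][:, -1, :].argmax(dim=-1, keepdim=True)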
ai_edge_torch/generative/layers/attention.py (+60 -63)

@@ -12,16 +12,16 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.
  # ==============================================================================
- # Common building blocks for Attention layer.

- from typing import Optional, Tuple
+ """Common building blocks for Attention layer."""

- import ai_edge_torch.generative.layers.builder as builder
- from ai_edge_torch.generative.layers.kv_cache import KVCache
+ from typing import Optional, Tuple, Union
+
+ from ai_edge_torch.generative.layers import builder
+ from ai_edge_torch.generative.layers import kv_cache as kv_utils
+ from ai_edge_torch.generative.layers import scaled_dot_product_attention as sdpa
  import ai_edge_torch.generative.layers.model_config as cfg
  import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb
- from ai_edge_torch.generative.layers.scaled_dot_product_attention import scaled_dot_product_attention # NOQA
- from ai_edge_torch.generative.layers.scaled_dot_product_attention import scaled_dot_product_attention_with_hlfb # NOQA
  import torch
  from torch import nn

@@ -62,7 +62,6 @@ class TransformerBlock(nn.Module):
  config (cfg.ModelConfig): the configuration object for this transformer
  block.
  """
-
  super().__init__()
  self.pre_atten_norm = builder.build_norm(
  config.embedding_dim, config.pre_attention_norm_config
@@ -71,7 +70,6 @@ class TransformerBlock(nn.Module):
  config.batch_size,
  config.embedding_dim,
  config.attn_config,
- config.kv_cache_max,
  config.enable_hlfb,
  )
  self.post_atten_norm = builder.build_norm(
@@ -86,7 +84,8 @@ class TransformerBlock(nn.Module):
  rope: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
  mask: Optional[torch.Tensor] = None,
  input_pos: Optional[torch.Tensor] = None,
- ) -> torch.Tensor:
+ kv_cache: kv_utils.KVCacheEntry = None,
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, kv_utils.KVCacheEntry]]:
  """Forward function of the TransformerBlock.

  Args:
@@ -94,24 +93,34 @@ class TransformerBlock(nn.Module):
  rope (Tuple[torch.Tensor, torch.Tensor]): the input rope tensor.
  mask (torch.Tensor): the optional mask tensor.
  input_pos (torch.Tensor): the optional input position tensor.
+ kv_cache (KVCacheEntry): the optional kv cache entry.

  Returns:
- output activation from this transformer block.
+ output activation from this transformer block, and updated kv cache (if
+ passed in).
  """
-
+ kv = None
  if self.config.parallel_residual:
  x_norm = self.pre_atten_norm(x)
- attn_out = self.atten_func(x_norm, rope, mask, input_pos)
+ atten_func_out = self.atten_func(x_norm, rope, mask, input_pos, kv_cache)
+ if kv_cache is None:
+ attn_out = atten_func_out
+ else:
+ attn_out, kv = atten_func_out
  ff_out = self.ff(x_norm)
  output = x + attn_out + ff_out
  else:
  x_norm = self.pre_atten_norm(x)
- attn_out = self.atten_func(x_norm, rope, mask, input_pos)
+ atten_func_out = self.atten_func(x_norm, rope, mask, input_pos, kv_cache)
+ if kv_cache is None:
+ attn_out = atten_func_out
+ else:
+ attn_out, kv = atten_func_out
  x = x + attn_out
  x_norm = self.post_atten_norm(x)
  output = x + self.ff(x_norm)

- return output
+ return output if kv is None else (output, kv)


  class CausalSelfAttention(nn.Module):
@@ -121,7 +130,6 @@ class CausalSelfAttention(nn.Module):
  batch_size: int,
  dim: int,
  config: cfg.AttentionConfig,
- kv_cache_max: int,
  enable_hlfb: bool,
  ) -> None:
  """Initialize an instance of CausalSelfAttention.
@@ -130,8 +138,6 @@ class CausalSelfAttention(nn.Module):
  batch_size (int): batch size of the input tensor.
  dim (int): causal attention's input/output dimmension.
  config (cfg.AttentionConfig): attention specific configurations.
- kv_cache_max (int): determines the size of the KV Cache buffer, if
- enabled.
  enable_hlfb (bool): whether hlfb is enabled or not.
  """
  super().__init__()
@@ -147,21 +153,13 @@ class CausalSelfAttention(nn.Module):
  self.output_projection = nn.Linear(
  output_shape, dim, bias=config.output_proj_use_bias
  )
-
- # Build a k/v cache with size (batch_size, kv_cache_max, n_heads, head_dim).
- if config.enable_kv_cache:
- self.kv_cache = KVCache(
- batch_size,
- kv_cache_max,
- config.num_query_groups,
- config.head_dim,
- enable_hlfb,
- )
-
- if enable_hlfb:
- self.sdpa_func = scaled_dot_product_attention_with_hlfb
- else:
- self.sdpa_func = scaled_dot_product_attention
+ self.config = config
+ self.enable_hlfb = enable_hlfb
+ self.sdpa_func = (
+ sdpa.scaled_dot_product_attention_with_hlfb
+ if enable_hlfb
+ else sdpa.scaled_dot_product_attention
+ )

  def forward(
  self,
@@ -169,7 +167,8 @@ class CausalSelfAttention(nn.Module):
  rope: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
  mask: Optional[torch.Tensor] = None,
  input_pos: Optional[torch.Tensor] = None,
- ) -> torch.Tensor:
+ kv_cache: Optional[kv_utils.KVCacheEntry] = None,
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, kv_utils.KVCacheEntry]]:
  """Forward function of the CausalSelfAttention layer, which can support

  MQA, GQA and MHA.
@@ -179,9 +178,11 @@ class CausalSelfAttention(nn.Module):
  rope (Tuple[torch.Tensor, torch.Tensor]): the input rope tensor.
  mask (torch.Tensor): the optional mask tensor.
  input_pos (torch.Tensor): the optional input position tensor.
+ kv_cache (KVCacheEntry): The KV cache entry corresponding to this module.

  Returns:
- output activation from this self attention layer.
+ output activation from this self attention layer, and the updated
+ KV Cach Entry (if passed in).
  """
  # Batch size, sequence length, embedding dimensionality.
  B, T, E = x.size()
@@ -224,9 +225,11 @@ class CausalSelfAttention(nn.Module):
  n_elem = int(self.config.rotary_percentage * self.config.head_dim)
  q, k = _embed_rope(q, k, n_elem, rope)

- if self.kv_cache is not None:
- # TODO(haoliang): Handle when execeeding max sequence length.
- k, v = self.kv_cache.update_cache(input_pos, k, v)
+ if kv_cache is not None:
+ kv_cache = kv_utils.update(
+ kv_cache, input_pos, k, v, enable_hlfb=self.enable_hlfb
+ )
+ k, v = kv_cache.k_cache, kv_cache.v_cache

  y = self.sdpa_func(
  q,
@@ -240,7 +243,7 @@ class CausalSelfAttention(nn.Module):

  # Compute the output projection.
  y = self.output_projection(y)
- return y
+ return y if kv_cache is None else (y, kv_cache)


  class SelfAttention(CausalSelfAttention):
@@ -251,16 +254,19 @@ class SelfAttention(CausalSelfAttention):
  x: torch.Tensor,
  rope: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
  input_pos: Optional[torch.Tensor] = None,
- ) -> torch.Tensor:
+ kv_cache: Optional[kv_utils.KVCacheEntry] = None,
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, kv_utils.KVCacheEntry]]:
  """Forward function of the SelfAttention layer, which can support MQA, GQA and MHA.

  Args:
  x (torch.Tensor): the input tensor.
  rope (Tuple[torch.Tensor, torch.Tensor]): the input rope tensor.
  input_pos (torch.Tensor): the optional input position tensor.
+ kv_cache (KVCacheEntry): The KV cache entry corresponding to this module.

  Returns:
- output activation from this self attention layer.
+ output activation from this self attention layer, and the updated
+ KV Cach Entry (if passed in).
  """
  B, T, _ = x.size()
  return super().forward(
@@ -279,9 +285,8 @@ class CrossAttention(nn.Module):
  query_dim: int,
  cross_dim: int,
  config: cfg.AttentionConfig,
- kv_cache_max: int,
  enable_hlfb: bool,
- ) -> None:
+ ):
  """Initialize an instance of CrossAttention.

  Args:
@@ -289,8 +294,6 @@ class CrossAttention(nn.Module):
  query_dim (int): query tensor's dimension.
  cross_dim (int): cross attention's dimensions, for key and value tensors.
  config (cfg.AttentionConfig): attention specific configurations.
- kv_cache_max (int): determines the size of the KV Cache buffer, if
- enabled.
  enable_hlfb (bool): whether hlfb is enabled or not.
  """
  super().__init__()
@@ -309,21 +312,11 @@ class CrossAttention(nn.Module):
  query_dim, query_dim, bias=config.output_proj_use_bias
  )

- self.kv_cache = None
- # Build a k/v cache with size (batch_size, kv_cache_max, n_heads, head_dim).
- if config.enable_kv_cache:
- self.kv_cache = KVCache(
- batch_size,
- kv_cache_max,
- config.num_query_groups,
- self.config.head_dim,
- enable_hlfb,
- )
-
- if enable_hlfb:
- self.sdpa_func = scaled_dot_product_attention_with_hlfb
- else:
- self.sdpa_func = scaled_dot_product_attention
+ self.sdpa_func = (
+ sdpa.scaled_dot_product_attention_with_hlfb
+ if enable_hlfb
+ else sdpa.scaled_dot_product_attention
+ )

  def forward(
  self,
@@ -332,6 +325,7 @@ class CrossAttention(nn.Module):
  rope: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
  mask: Optional[torch.Tensor] = None,
  input_pos: Optional[torch.Tensor] = None,
+ kv_cache: Optional[kv_utils.KVCacheEntry] = None,
  ):
  """Forward function of the CrossAttention layer.

@@ -342,6 +336,7 @@ class CrossAttention(nn.Module):
  mask (torch.Tensor): the optional mask tensor can be broadcaseted to shape
  [B, n_heads, target_seq_len, source_seq_len].
  input_pos (torch.Tensor): the optional input position tensor.
+ kv_cache (KVCacheEntry): The KV cache entry corresponding to this module.

  Returns:
  output activation from this cross attention layer.
@@ -363,9 +358,11 @@ class CrossAttention(nn.Module):
  n_elem = int(self.config.rotary_percentage * self.config.head_dim)
  q, k = _embed_rope(q, k, n_elem, rope)

- if self.kv_cache is not None:
- # TODO(haoliang): Handle when execeeding max sequence length.
- k, v = self.kv_cache.update_cache(input_pos, k, v)
+ if kv_cache is not None:
+ kv_cache = kv_utils.update(
+ kv_cache, input_pos, k, v, enable_hlfb=self.enable_hlfb
+ )
+ k, v = kv_cache.k_cache, kv_cache.v_cache
  if mask is None:
  mask = torch.zeros(
  (batch_size, 1, target_seq_len, source_seq_len), dtype=torch.float32
@@ -375,4 +372,4 @@ class CrossAttention(nn.Module):

  # Compute the output projection.
  y = self.output_projection(y)
- return y
+ return y if kv_cache is None else (y, kv_cache)
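The kv_utils.update call that replaces self.kv_cache.update_cache comes from the reworked kv_cache.py, whose hunks are not reproduced in this diff. Conceptually it is a functional update: write the new k/v slices at input_pos and hand back a fresh KVCacheEntry instead of mutating module state; the real version also takes enable_hlfb to wrap the operation in a high-level function boundary. A rough, hypothetical illustration of the idea in plain PyTorch, not the library's actual implementation:

    import dataclasses

    import torch


    @dataclasses.dataclass
    class KVCacheEntry:
      # Hypothetical stand-in for one layer's cache, laid out as
      # [batch, kv_cache_max, num_query_groups, head_dim].
      k_cache: torch.Tensor
      v_cache: torch.Tensor


    def update(entry, input_pos, k_slice, v_slice):
      # Out-of-place scatter of the new keys/values at the given positions,
      # so the caller (the model) owns and threads the cache state.
      k = entry.k_cache.index_copy(1, input_pos, k_slice)
      v = entry.v_cache.index_copy(1, input_pos, v_slice)
      return KVCacheEntry(k, v)


    # One decode step for a single layer: write position 5 of a 16-slot cache.
    entry = KVCacheEntry(torch.zeros(1, 16, 4, 8), torch.zeros(1, 16, 4, 8))
    entry = update(entry, torch.tensor([5]), torch.ones(1, 1, 4, 8), torch.ones(1, 1, 4, 8))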