ai-edge-torch-nightly 0.3.0.dev20240909__py3-none-any.whl → 0.3.0.dev20240913__py3-none-any.whl

This diff compares two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only and reflects only the changes between those published versions.
Files changed (50)
  1. ai_edge_torch/_convert/test/test_convert.py +35 -13
  2. ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +31 -12
  3. ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +25 -6
  4. ai_edge_torch/generative/examples/gemma/gemma.py +50 -30
  5. ai_edge_torch/generative/examples/gemma/gemma2.py +85 -58
  6. ai_edge_torch/generative/examples/{experimental/phi → phi}/convert_to_tflite.py +11 -12
  7. ai_edge_torch/generative/examples/{experimental/phi → phi}/phi2.py +46 -43
  8. ai_edge_torch/generative/examples/{experimental/gemma → smallm}/convert_to_tflite.py +12 -14
  9. ai_edge_torch/generative/examples/smallm/smallm.py +122 -0
  10. ai_edge_torch/generative/examples/stable_diffusion/clip.py +11 -5
  11. ai_edge_torch/generative/examples/t5/t5.py +35 -22
  12. ai_edge_torch/generative/examples/t5/t5_attention.py +18 -13
  13. ai_edge_torch/generative/examples/test_models/toy_model.py +15 -13
  14. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +74 -33
  15. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +25 -6
  16. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +55 -34
  17. ai_edge_torch/generative/layers/attention.py +77 -73
  18. ai_edge_torch/generative/layers/builder.py +5 -3
  19. ai_edge_torch/generative/layers/kv_cache.py +163 -51
  20. ai_edge_torch/generative/layers/model_config.py +38 -19
  21. ai_edge_torch/generative/layers/normalization.py +158 -0
  22. ai_edge_torch/generative/layers/unet/blocks_2d.py +0 -2
  23. ai_edge_torch/generative/test/{test_experimental_ekv.py → test_kv_cache.py} +12 -24
  24. ai_edge_torch/generative/test/test_loader.py +1 -1
  25. ai_edge_torch/generative/test/test_model_conversion.py +72 -34
  26. ai_edge_torch/generative/test/test_model_conversion_large.py +51 -23
  27. ai_edge_torch/generative/test/utils.py +54 -0
  28. ai_edge_torch/generative/utilities/loader.py +15 -15
  29. ai_edge_torch/generative/utilities/t5_loader.py +21 -20
  30. ai_edge_torch/odml_torch/lowerings/__init__.py +1 -0
  31. ai_edge_torch/odml_torch/lowerings/_convolution.py +196 -74
  32. ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +0 -2
  33. ai_edge_torch/odml_torch/lowerings/_layer_norm.py +78 -0
  34. ai_edge_torch/version.py +1 -1
  35. {ai_edge_torch_nightly-0.3.0.dev20240909.dist-info → ai_edge_torch_nightly-0.3.0.dev20240913.dist-info}/METADATA +1 -1
  36. {ai_edge_torch_nightly-0.3.0.dev20240909.dist-info → ai_edge_torch_nightly-0.3.0.dev20240913.dist-info}/RECORD +41 -47
  37. ai_edge_torch/generative/examples/experimental/gemma/gemma.py +0 -219
  38. ai_edge_torch/generative/examples/experimental/phi/__init__.py +0 -14
  39. ai_edge_torch/generative/examples/experimental/tiny_llama/__init__.py +0 -14
  40. ai_edge_torch/generative/examples/experimental/tiny_llama/convert_to_tflite.py +0 -87
  41. ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py +0 -205
  42. ai_edge_torch/generative/examples/phi2/__init__.py +0 -14
  43. ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +0 -67
  44. ai_edge_torch/generative/examples/phi2/phi2.py +0 -189
  45. ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +0 -176
  46. /ai_edge_torch/generative/examples/{experimental → phi}/__init__.py +0 -0
  47. /ai_edge_torch/generative/examples/{experimental/gemma → smallm}/__init__.py +0 -0
  48. {ai_edge_torch_nightly-0.3.0.dev20240909.dist-info → ai_edge_torch_nightly-0.3.0.dev20240913.dist-info}/LICENSE +0 -0
  49. {ai_edge_torch_nightly-0.3.0.dev20240909.dist-info → ai_edge_torch_nightly-0.3.0.dev20240913.dist-info}/WHEEL +0 -0
  50. {ai_edge_torch_nightly-0.3.0.dev20240909.dist-info → ai_edge_torch_nightly-0.3.0.dev20240913.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py
@@ -13,11 +13,14 @@
  # limitations under the License.
  # ==============================================================================

+ """Example of converting TinyLlama model to multi-signature tflite model."""
+
  import os
- from pathlib import Path
+ import pathlib

  import ai_edge_torch
  from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
+ from ai_edge_torch.generative.layers import kv_cache as kv_utils
  from ai_edge_torch.generative.quantize import quant_recipes
  import torch

@@ -48,20 +51,36 @@ def convert_tiny_llama_to_tflite(
  prefill_input_pos = torch.arange(0, prefill_seq_len)
  decode_token = torch.tensor([[0]], dtype=torch.long)
  decode_input_pos = torch.tensor([0], dtype=torch.int64)
+ kv = kv_utils.KVCache.from_model_config(pytorch_model.config)

  quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
  edge_model = (
  ai_edge_torch.signature(
- 'prefill', pytorch_model, (prefill_tokens, prefill_input_pos)
+ 'prefill',
+ pytorch_model,
+ sample_kwargs={
+ 'tokens': prefill_tokens,
+ 'input_pos': prefill_input_pos,
+ 'kv_cache': kv,
+ },
+ )
+ .signature(
+ 'decode',
+ pytorch_model,
+ sample_kwargs={
+ 'tokens': decode_token,
+ 'input_pos': decode_input_pos,
+ 'kv_cache': kv,
+ },
  )
- .signature('decode', pytorch_model, (decode_token, decode_input_pos))
  .convert(quant_config=quant_config)
  )
+ quant_suffix = 'q8' if quantize else 'f32'
  edge_model.export(
- f'/tmp/tiny_llama_seq{prefill_seq_len}_kv{kv_cache_max_len}.tflite'
+ f'/tmp/tiny_llama_{quant_suffix}_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
  )


  if __name__ == '__main__':
- checkpoint_path = os.path.join(Path.home(), 'Downloads/llm_data/tiny_llama')
- convert_tiny_llama_to_tflite(checkpoint_path)
+ path = os.path.join(pathlib.Path.home(), 'Downloads/llm_data/tiny_llama')
+ convert_tiny_llama_to_tflite(path)
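The conversion script above now exports two signatures (`prefill` and `decode`) that share an externalized `kv_utils.KVCache` built via `KVCache.from_model_config`, passes sample inputs through `sample_kwargs`, and encodes the quantization choice (`q8` vs `f32`) in the output file name. A minimal driver sketch; `quantize` is assumed to remain a keyword argument of `convert_tiny_llama_to_tflite` (the hunk uses it but does not show the full signature), and the checkpoint path is the one from the script's `__main__` block:

import os
import pathlib

from ai_edge_torch.generative.examples.tiny_llama import convert_to_tflite

# Checkpoint location taken from the script's __main__ block above.
ckpt = os.path.join(pathlib.Path.home(), 'Downloads/llm_data/tiny_llama')

# Quantized export -> /tmp/tiny_llama_q8_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite
convert_to_tflite.convert_tiny_llama_to_tflite(ckpt, quantize=True)

# Float export -> /tmp/tiny_llama_f32_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite
convert_to_tflite.convert_tiny_llama_to_tflite(ckpt, quantize=False)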
ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py
@@ -12,13 +12,15 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.
  # ==============================================================================
- # Example of building a TinyLlama model from the Edge Generative API layers.
+
+ """Example of building a TinyLlama model."""

  import os
- from pathlib import Path
+ import pathlib

  from ai_edge_torch.generative.layers import attention
  from ai_edge_torch.generative.layers import builder
+ from ai_edge_torch.generative.layers import kv_cache as kv_utils
  import ai_edge_torch.generative.layers.attention_utils as attn_utils
  import ai_edge_torch.generative.layers.model_config as cfg
  import ai_edge_torch.generative.utilities.loader as loading_utils
@@ -42,13 +44,12 @@ TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
  )


- class TinyLLamma(nn.Module):
+ class TinyLlama(nn.Module):
  """A TinyLlama model built from the Edge Generative API layers."""

  def __init__(self, config: cfg.ModelConfig):
  super().__init__()

- self.config = config
  # Construct model layers.
  self.lm_head = nn.Linear(
  config.embedding_dim, config.vocab_size, bias=config.lm_head_use_bias
@@ -56,18 +57,20 @@ class TinyLLamma(nn.Module):
  self.tok_embedding = nn.Embedding(
  config.vocab_size, config.embedding_dim, padding_idx=0
  )
+ # TinyLlama has only one block config.
+ block_config = config.block_config(0)
  self.transformer_blocks = nn.ModuleList(
- attention.TransformerBlock(config) for _ in range(config.num_layers)
+ attention.TransformerBlock(block_config, config)
+ for _ in range(config.num_layers)
  )
  self.final_norm = builder.build_norm(
  config.embedding_dim,
  config.final_norm_config,
  )
+ attn_config = block_config.attn_config
  self.rope_cache = attn_utils.build_rope_cache(
  size=config.kv_cache_max,
- dim=int(
- config.attn_config.rotary_percentage * config.attn_config.head_dim
- ),
+ dim=int(attn_config.rotary_percentage * attn_config.head_dim),
  base=10_000,
  condense_ratio=1,
  dtype=torch.float32,
@@ -80,16 +83,22 @@ class TinyLLamma(nn.Module):
  )
  self.config = config

- # The model's forward function takes in additional k/v cache tensors
- # and returns the updated k/v cache tensors to the caller.
- # This can be eliminated if we handle k/v cache updates inside the model itself.
  @torch.inference_mode
- def forward(self, idx: torch.Tensor, input_pos: torch.Tensor) -> torch.Tensor:
- _, seq_len = idx.size()
+ def forward(
+ self,
+ tokens: torch.Tensor,
+ input_pos: torch.Tensor,
+ kv_cache: kv_utils.KVCache,
+ ) -> dict[torch.Tensor, kv_utils.KVCache]:
+ _, seq_len = tokens.size()
  assert self.config.max_seq_len >= seq_len, (
  f"Cannot forward sequence of length {seq_len}, max seq length is only"
  f" {self.config.max_seq_len}"
  )
+ assert len(self.transformer_blocks) == len(kv_cache.caches), (
+ "The number of transformer blocks and the number of KV cache entries"
+ " must be the same."
+ )

  cos, sin = self.rope_cache
  cos = cos.index_select(0, input_pos)
@@ -97,16 +106,20 @@ class TinyLLamma(nn.Module):
  mask = self.mask_cache.index_select(2, input_pos)
  mask = mask[:, :, :, : self.config.kv_cache_max]

- # forward the model itself
- x = self.tok_embedding(idx) # token embeddings of shape (b, t, n_embd)
+ # token embeddings of shape (b, t, n_embd)
+ x = self.tok_embedding(tokens)

- for _, block in enumerate(self.transformer_blocks):
- x = block(x, (cos, sin), mask, input_pos)
+ updated_kv_entires = []
+ for i, block in enumerate(self.transformer_blocks):
+ kv_entry = kv_cache.caches[i] if kv_cache else None
+ x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
+ if kv_entry:
+ updated_kv_entires.append(kv_entry)
+ updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))

  x = self.final_norm(x)
-
- res = self.lm_head(x) # (b, t, vocab_size)
- return res
+ logits = self.lm_head(x) # (b, t, vocab_size)
+ return {"logits": logits, "kv_cache": updated_kv_cache}


  def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
@@ -131,55 +144,63 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
  intermediate_size=5632,
  )
  norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM)
+ block_config = cfg.TransformerBlockConfig(
+ attn_config=attn_config,
+ ff_config=ff_config,
+ pre_attention_norm_config=norm_config,
+ post_attention_norm_config=norm_config,
+ )
  config = cfg.ModelConfig(
  vocab_size=32000,
  num_layers=22,
  max_seq_len=2048,
  embedding_dim=2048,
  kv_cache_max_len=kv_cache_max_len,
- attn_config=attn_config,
- ff_config=ff_config,
- pre_attention_norm_config=norm_config,
- post_attention_norm_config=norm_config,
+ block_configs=block_config,
  final_norm_config=norm_config,
  enable_hlfb=True,
  )
  return config


- def get_fake_model_config() -> cfg.ModelConfig:
- config = get_model_config()
+ def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
+ config = get_model_config(**kwargs)
  config.vocab_size = 128
  config.num_layers = 2
- config.ff_config.intermediate_size = 64
+ # TinyLlama has only one block config.
+ config.block_config(0).ff_config.intermediate_size = 64
  return config


  def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
  config = get_model_config(**kwargs)
- model = TinyLLamma(config)
+ model = TinyLlama(config)
  loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
  loader.load(model)
+ model.eval()
  return model


- def define_and_run() -> None:
+ def define_and_run(checkpoint_path: str) -> None:
  """Instantiates and runs a TinyLlama model."""

- current_dir = Path(__file__).parent.resolve()
+ current_dir = pathlib.Path(__file__).parent.resolve()
  tiny_llama_goldens = torch.load(current_dir / "tiny_llama_lm_logits.pt")
  kv_cache_max_len = 1024
- checkpoint_path = os.path.join(Path.home(), "Downloads/llm_data/tiny_llama")
  model = build_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
  idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
  tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.long, device="cpu")
  tokens[0, :4] = idx
  input_pos = torch.arange(0, kv_cache_max_len)
- lm_logits = model.forward(tokens, input_pos)
+ kv = kv_utils.KVCache.from_model_config(model.config)
+ output = model.forward(tokens, input_pos, kv)
  assert torch.allclose(
- tiny_llama_goldens, lm_logits[0, idx.shape[1] - 1, :], atol=1e-05
+ tiny_llama_goldens, output["logits"][0, idx.shape[1] - 1, :], atol=1e-02
  )


  if __name__ == "__main__":
- define_and_run()
+ input_checkpoint_path = os.path.join(
+ pathlib.Path.home(), "Downloads/llm_data/tiny_llama"
+ )
+ define_and_run(input_checkpoint_path)
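With the hunks above, the model's forward contract changes from `forward(idx, input_pos) -> logits` to an explicit, caller-owned KV cache: the caller builds a `kv_utils.KVCache`, passes it in, and gets back a dict with `logits` and the updated `kv_cache`. A minimal prefill-then-decode sketch under that contract (the checkpoint path and token values are placeholders, not from the package):

import torch

from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
from ai_edge_torch.generative.layers import kv_cache as kv_utils

# Placeholder checkpoint path; build_model and KVCache.from_model_config are
# the APIs shown in the diff above.
model = tiny_llama.build_model('/path/to/tiny_llama', kv_cache_max_len=1024)
kv = kv_utils.KVCache.from_model_config(model.config)

# Prefill: run the prompt once and keep the returned cache.
prompt = torch.tensor([[1, 2, 3, 4]], dtype=torch.long)
out = model.forward(prompt, torch.arange(0, 4), kv)
kv = out['kv_cache']
next_token = out['logits'][0, -1].argmax().reshape(1, 1)

# Decode: one token at a time, threading the updated cache back in.
out = model.forward(next_token, torch.tensor([4], dtype=torch.int64), kv)
kv = out['kv_cache']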
ai_edge_torch/generative/layers/attention.py
@@ -12,16 +12,16 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.
  # ==============================================================================
- # Common building blocks for Attention layer.

- from typing import Optional, Tuple
+ """Common building blocks for Attention layer."""

- import ai_edge_torch.generative.layers.builder as builder
- from ai_edge_torch.generative.layers.kv_cache import KVCache
+ from typing import Optional, Tuple, Union
+
+ from ai_edge_torch.generative.layers import builder
+ from ai_edge_torch.generative.layers import kv_cache as kv_utils
+ from ai_edge_torch.generative.layers import scaled_dot_product_attention as sdpa
  import ai_edge_torch.generative.layers.model_config as cfg
  import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb
- from ai_edge_torch.generative.layers.scaled_dot_product_attention import scaled_dot_product_attention # NOQA
- from ai_edge_torch.generative.layers.scaled_dot_product_attention import scaled_dot_product_attention_with_hlfb # NOQA
  import torch
  from torch import nn

@@ -55,29 +55,35 @@ def _embed_rope(

  class TransformerBlock(nn.Module):

- def __init__(self, config: cfg.ModelConfig) -> None:
+ def __init__(
+ self,
+ config: cfg.TransformerBlockConfig,
+ model_config: cfg.ModelConfig,
+ ) -> None:
  """Initialize an instance of the TransformerBlock.

  Args:
- config (cfg.ModelConfig): the configuration object for this transformer
- block.
+ config (cfg.TransformerBlockConfig): the configuration object for this
+ transformer block.
+ model_config (cfg.ModelConfig): the configuration object for the model
+ this transformer block belongs to.
  """
-
  super().__init__()
  self.pre_atten_norm = builder.build_norm(
- config.embedding_dim, config.pre_attention_norm_config
+ model_config.embedding_dim,
+ config.pre_attention_norm_config,
  )
  self.atten_func = CausalSelfAttention(
- config.batch_size,
- config.embedding_dim,
+ model_config.batch_size,
+ model_config.embedding_dim,
  config.attn_config,
- config.kv_cache_max,
- config.enable_hlfb,
+ model_config.enable_hlfb,
  )
  self.post_atten_norm = builder.build_norm(
- config.embedding_dim, config.post_attention_norm_config
+ model_config.embedding_dim,
+ config.post_attention_norm_config,
  )
- self.ff = builder.build_ff(config.embedding_dim, config.ff_config)
+ self.ff = builder.build_ff(model_config.embedding_dim, config.ff_config)
  self.config = config

  def forward(
@@ -86,7 +92,8 @@ class TransformerBlock(nn.Module):
  rope: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
  mask: Optional[torch.Tensor] = None,
  input_pos: Optional[torch.Tensor] = None,
- ) -> torch.Tensor:
+ kv_cache: kv_utils.KVCacheEntry = None,
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, kv_utils.KVCacheEntry]]:
  """Forward function of the TransformerBlock.

  Args:
@@ -94,24 +101,34 @@ class TransformerBlock(nn.Module):
  rope (Tuple[torch.Tensor, torch.Tensor]): the input rope tensor.
  mask (torch.Tensor): the optional mask tensor.
  input_pos (torch.Tensor): the optional input position tensor.
+ kv_cache (KVCacheEntry): the optional kv cache entry.

  Returns:
- output activation from this transformer block.
+ output activation from this transformer block, and updated kv cache (if
+ passed in).
  """
-
+ kv = None
  if self.config.parallel_residual:
  x_norm = self.pre_atten_norm(x)
- attn_out = self.atten_func(x_norm, rope, mask, input_pos)
+ atten_func_out = self.atten_func(x_norm, rope, mask, input_pos, kv_cache)
+ if kv_cache is None:
+ attn_out = atten_func_out
+ else:
+ attn_out, kv = atten_func_out
  ff_out = self.ff(x_norm)
  output = x + attn_out + ff_out
  else:
  x_norm = self.pre_atten_norm(x)
- attn_out = self.atten_func(x_norm, rope, mask, input_pos)
+ atten_func_out = self.atten_func(x_norm, rope, mask, input_pos, kv_cache)
+ if kv_cache is None:
+ attn_out = atten_func_out
+ else:
+ attn_out, kv = atten_func_out
  x = x + attn_out
  x_norm = self.post_atten_norm(x)
  output = x + self.ff(x_norm)

- return output
+ return output if kv is None else (output, kv)


  class CausalSelfAttention(nn.Module):
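`TransformerBlock` is now constructed from its own `cfg.TransformerBlockConfig` plus the parent `cfg.ModelConfig`, and its forward pass is cache-optional: without a `kv_cache` argument it still returns a plain tensor, while passing a `KVCacheEntry` yields an `(output, updated_entry)` tuple. A small helper sketch (my illustration, not package code) that normalizes this `Union` return for callers:

from typing import Optional, Tuple, Union

import torch

from ai_edge_torch.generative.layers import kv_cache as kv_utils

# The block/attention forward now returns either a Tensor or a
# (Tensor, KVCacheEntry) tuple, depending on whether a cache entry was passed.
BlockOutput = Union[torch.Tensor, Tuple[torch.Tensor, kv_utils.KVCacheEntry]]

def split_block_output(
    out: BlockOutput,
) -> Tuple[torch.Tensor, Optional[kv_utils.KVCacheEntry]]:
  """Returns (activation, updated_entry_or_None) for either return shape."""
  if isinstance(out, tuple):
    return out
  return out, None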
@@ -121,7 +138,6 @@ class CausalSelfAttention(nn.Module):
  batch_size: int,
  dim: int,
  config: cfg.AttentionConfig,
- kv_cache_max: int,
  enable_hlfb: bool,
  ) -> None:
  """Initialize an instance of CausalSelfAttention.
@@ -130,12 +146,9 @@ class CausalSelfAttention(nn.Module):
  batch_size (int): batch size of the input tensor.
  dim (int): causal attention's input/output dimmension.
  config (cfg.AttentionConfig): attention specific configurations.
- kv_cache_max (int): determines the size of the KV Cache buffer, if
- enabled.
  enable_hlfb (bool): whether hlfb is enabled or not.
  """
  super().__init__()
- self.config = config
  self.kv_cache = None
  self.batch_size = batch_size
  qkv_shape = (
@@ -147,21 +160,13 @@ class CausalSelfAttention(nn.Module):
  self.output_projection = nn.Linear(
  output_shape, dim, bias=config.output_proj_use_bias
  )
-
- # Build a k/v cache with size (batch_size, kv_cache_max, n_heads, head_dim).
- if config.enable_kv_cache:
- self.kv_cache = KVCache(
- batch_size,
- kv_cache_max,
- config.num_query_groups,
- config.head_dim,
- enable_hlfb,
- )
-
- if enable_hlfb:
- self.sdpa_func = scaled_dot_product_attention_with_hlfb
- else:
- self.sdpa_func = scaled_dot_product_attention
+ self.config = config
+ self.enable_hlfb = enable_hlfb
+ self.sdpa_func = (
+ sdpa.scaled_dot_product_attention_with_hlfb
+ if enable_hlfb
+ else sdpa.scaled_dot_product_attention
+ )

  def forward(
  self,
@@ -169,7 +174,8 @@ class CausalSelfAttention(nn.Module):
  rope: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
  mask: Optional[torch.Tensor] = None,
  input_pos: Optional[torch.Tensor] = None,
- ) -> torch.Tensor:
+ kv_cache: Optional[kv_utils.KVCacheEntry] = None,
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, kv_utils.KVCacheEntry]]:
  """Forward function of the CausalSelfAttention layer, which can support

  MQA, GQA and MHA.
@@ -179,9 +185,11 @@ class CausalSelfAttention(nn.Module):
  rope (Tuple[torch.Tensor, torch.Tensor]): the input rope tensor.
  mask (torch.Tensor): the optional mask tensor.
  input_pos (torch.Tensor): the optional input position tensor.
+ kv_cache (KVCacheEntry): The KV cache entry corresponding to this module.

  Returns:
- output activation from this self attention layer.
+ output activation from this self attention layer, and the updated
+ KV Cach Entry (if passed in).
  """
  # Batch size, sequence length, embedding dimensionality.
  B, T, E = x.size()
@@ -224,9 +232,11 @@ class CausalSelfAttention(nn.Module):
  n_elem = int(self.config.rotary_percentage * self.config.head_dim)
  q, k = _embed_rope(q, k, n_elem, rope)

- if self.kv_cache is not None:
- # TODO(haoliang): Handle when execeeding max sequence length.
- k, v = self.kv_cache.update_cache(input_pos, k, v)
+ if kv_cache is not None:
+ kv_cache = kv_utils.update(
+ kv_cache, input_pos, k, v, enable_hlfb=self.enable_hlfb
+ )
+ k, v = kv_cache.k_cache, kv_cache.v_cache

  y = self.sdpa_func(
  q,
@@ -240,7 +250,7 @@ class CausalSelfAttention(nn.Module):

  # Compute the output projection.
  y = self.output_projection(y)
- return y
+ return y if kv_cache is None else (y, kv_cache)


  class SelfAttention(CausalSelfAttention):
@@ -251,16 +261,19 @@ class SelfAttention(CausalSelfAttention):
  x: torch.Tensor,
  rope: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
  input_pos: Optional[torch.Tensor] = None,
- ) -> torch.Tensor:
+ kv_cache: Optional[kv_utils.KVCacheEntry] = None,
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, kv_utils.KVCacheEntry]]:
  """Forward function of the SelfAttention layer, which can support MQA, GQA and MHA.

  Args:
  x (torch.Tensor): the input tensor.
  rope (Tuple[torch.Tensor, torch.Tensor]): the input rope tensor.
  input_pos (torch.Tensor): the optional input position tensor.
+ kv_cache (KVCacheEntry): The KV cache entry corresponding to this module.

  Returns:
- output activation from this self attention layer.
+ output activation from this self attention layer, and the updated
+ KV Cach Entry (if passed in).
  """
  B, T, _ = x.size()
  return super().forward(
@@ -279,9 +292,8 @@ class CrossAttention(nn.Module):
  query_dim: int,
  cross_dim: int,
  config: cfg.AttentionConfig,
- kv_cache_max: int,
  enable_hlfb: bool,
- ) -> None:
+ ):
  """Initialize an instance of CrossAttention.

  Args:
@@ -289,8 +301,6 @@ class CrossAttention(nn.Module):
  query_dim (int): query tensor's dimension.
  cross_dim (int): cross attention's dimensions, for key and value tensors.
  config (cfg.AttentionConfig): attention specific configurations.
- kv_cache_max (int): determines the size of the KV Cache buffer, if
- enabled.
  enable_hlfb (bool): whether hlfb is enabled or not.
  """
  super().__init__()
@@ -309,21 +319,11 @@ class CrossAttention(nn.Module):
  query_dim, query_dim, bias=config.output_proj_use_bias
  )

- self.kv_cache = None
- # Build a k/v cache with size (batch_size, kv_cache_max, n_heads, head_dim).
- if config.enable_kv_cache:
- self.kv_cache = KVCache(
- batch_size,
- kv_cache_max,
- config.num_query_groups,
- self.config.head_dim,
- enable_hlfb,
- )
-
- if enable_hlfb:
- self.sdpa_func = scaled_dot_product_attention_with_hlfb
- else:
- self.sdpa_func = scaled_dot_product_attention
+ self.sdpa_func = (
+ sdpa.scaled_dot_product_attention_with_hlfb
+ if enable_hlfb
+ else sdpa.scaled_dot_product_attention
+ )

  def forward(
  self,
@@ -332,6 +332,7 @@ class CrossAttention(nn.Module):
  rope: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
  mask: Optional[torch.Tensor] = None,
  input_pos: Optional[torch.Tensor] = None,
+ kv_cache: Optional[kv_utils.KVCacheEntry] = None,
  ):
  """Forward function of the CrossAttention layer.

@@ -342,6 +343,7 @@ class CrossAttention(nn.Module):
  mask (torch.Tensor): the optional mask tensor can be broadcaseted to shape
  [B, n_heads, target_seq_len, source_seq_len].
  input_pos (torch.Tensor): the optional input position tensor.
+ kv_cache (KVCacheEntry): The KV cache entry corresponding to this module.

  Returns:
  output activation from this cross attention layer.
@@ -363,9 +365,11 @@ class CrossAttention(nn.Module):
  n_elem = int(self.config.rotary_percentage * self.config.head_dim)
  q, k = _embed_rope(q, k, n_elem, rope)

- if self.kv_cache is not None:
- # TODO(haoliang): Handle when execeeding max sequence length.
- k, v = self.kv_cache.update_cache(input_pos, k, v)
+ if kv_cache is not None:
+ kv_cache = kv_utils.update(
+ kv_cache, input_pos, k, v, enable_hlfb=self.enable_hlfb
+ )
+ k, v = kv_cache.k_cache, kv_cache.v_cache
  if mask is None:
  mask = torch.zeros(
  (batch_size, 1, target_seq_len, source_seq_len), dtype=torch.float32
@@ -375,4 +379,4 @@ class CrossAttention(nn.Module):

  # Compute the output projection.
  y = self.output_projection(y)
- return y
+ return y if kv_cache is None else (y, kv_cache)
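Throughout attention.py, the per-module `KVCache` buffers and `update_cache` calls are removed; the cache entry is supplied by the caller and updated functionally through `kv_utils.update(kv_cache, input_pos, k, v, enable_hlfb=...)`, whose returned entry exposes `k_cache`/`v_cache` for the attention computation. A rough, illustrative-only re-implementation of that write semantics (the real `kv_utils.update` lives in kv_cache.py and may differ, for example in its HLFB path); the (batch, seq, heads, head_dim) layout comes from the comment deleted above:

import torch

def functional_kv_write(
    k_cache: torch.Tensor,
    v_cache: torch.Tensor,
    input_pos: torch.Tensor,
    k_slice: torch.Tensor,
    v_slice: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
  """Writes new key/value slices at input_pos without mutating the caller's tensors."""
  # dim=1 is the sequence axis for a (batch_size, kv_cache_max, n_heads, head_dim) cache.
  new_k = k_cache.index_copy(1, input_pos, k_slice)
  new_v = v_cache.index_copy(1, input_pos, v_slice)
  return new_k, new_v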
ai_edge_torch/generative/layers/builder.py
@@ -59,9 +59,11 @@ def build_norm(dim: int, config: cfg.NormalizationConfig):
  zero_centered_gamma=config.zero_centered,
  )
  elif config.type == cfg.NormalizationType.LAYER_NORM:
- return nn.LayerNorm(dim, eps=config.epsilon)
+ return normalization.LayerNorm(dim, config.epsilon, config.enable_hlfb)
  elif config.type == cfg.NormalizationType.GROUP_NORM:
- return nn.GroupNorm(config.group_num, dim, config.epsilon)
+ return normalization.GroupNorm(
+ config.group_num, dim, config.epsilon, config.enable_hlfb
+ )
  else:
  raise ValueError("Unsupported norm type.")

@@ -71,7 +73,7 @@ def build_ff(dim: int, config: cfg.FeedForwardConfig):

  Args:
  dim (int): dimension of the input tensor.
- config (`ModelConfig` object): the model configuration.
+ config (`FeedForwardConfig` object): the model configuration.

  Returns:
  The constructed `nn.Module` feedforward layer.
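`build_norm` now returns the package's own HLFB-aware `normalization.LayerNorm` / `normalization.GroupNorm` (added in normalization.py, see the file list) instead of the stock `nn.LayerNorm` / `nn.GroupNorm`, driven by `config.enable_hlfb`. A usage sketch; whether `NormalizationConfig` accepts `enable_hlfb` as a constructor argument is an assumption, since the diff only shows `build_norm` reading that field:

import torch

import ai_edge_torch.generative.layers.builder as builder
import ai_edge_torch.generative.layers.model_config as cfg

# enable_hlfb as a constructor kwarg is assumed; the hunk only reads
# config.enable_hlfb inside build_norm.
norm_cfg = cfg.NormalizationConfig(
    type=cfg.NormalizationType.LAYER_NORM,
    enable_hlfb=True,
)
norm = builder.build_norm(dim=2048, config=norm_cfg)
y = norm(torch.randn(1, 16, 2048))  # illustrative: normalizes over the last dimension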