ai-edge-torch-nightly 0.3.0.dev20240910__py3-none-any.whl → 0.3.0.dev20240914__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68)
  1. ai_edge_torch/_convert/conversion.py +2 -1
  2. ai_edge_torch/_convert/fx_passes/__init__.py +5 -41
  3. ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py +3 -4
  4. ai_edge_torch/_convert/fx_passes/build_interpolate_composite_pass.py +3 -4
  5. ai_edge_torch/_convert/fx_passes/inject_mlir_debuginfo_pass.py +3 -4
  6. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/pass_body.py +4 -5
  7. ai_edge_torch/config.py +4 -1
  8. ai_edge_torch/fx_pass_base.py +101 -0
  9. ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +35 -16
  10. ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +29 -10
  11. ai_edge_torch/generative/examples/gemma/gemma.py +52 -32
  12. ai_edge_torch/generative/examples/gemma/gemma2.py +87 -60
  13. ai_edge_torch/generative/examples/{experimental/gemma → openelm}/convert_to_tflite.py +16 -18
  14. ai_edge_torch/generative/examples/openelm/openelm.py +237 -0
  15. ai_edge_torch/generative/examples/{experimental/phi → phi}/convert_to_tflite.py +15 -16
  16. ai_edge_torch/generative/examples/{experimental/phi → phi}/phi2.py +48 -45
  17. ai_edge_torch/generative/examples/{experimental/tiny_llama → smollm}/convert_to_tflite.py +16 -17
  18. ai_edge_torch/generative/examples/smollm/smollm.py +131 -0
  19. ai_edge_torch/generative/examples/stable_diffusion/clip.py +12 -6
  20. ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +1 -1
  21. ai_edge_torch/generative/examples/t5/convert_to_tflite.py +20 -20
  22. ai_edge_torch/generative/examples/t5/t5.py +43 -30
  23. ai_edge_torch/generative/examples/t5/t5_attention.py +18 -13
  24. ai_edge_torch/generative/examples/test_models/toy_model.py +15 -13
  25. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +75 -34
  26. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +29 -10
  27. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +57 -36
  28. ai_edge_torch/generative/fx_passes/__init__.py +4 -4
  29. ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py +3 -4
  30. ai_edge_torch/generative/layers/attention.py +84 -73
  31. ai_edge_torch/generative/layers/builder.py +38 -14
  32. ai_edge_torch/generative/layers/feed_forward.py +26 -8
  33. ai_edge_torch/generative/layers/kv_cache.py +163 -51
  34. ai_edge_torch/generative/layers/model_config.py +61 -33
  35. ai_edge_torch/generative/layers/normalization.py +158 -0
  36. ai_edge_torch/generative/layers/unet/blocks_2d.py +0 -2
  37. ai_edge_torch/generative/quantize/example.py +2 -2
  38. ai_edge_torch/generative/test/{test_experimental_ekv.py → test_kv_cache.py} +12 -24
  39. ai_edge_torch/generative/test/test_loader.py +1 -1
  40. ai_edge_torch/generative/test/test_model_conversion.py +77 -62
  41. ai_edge_torch/generative/test/test_model_conversion_large.py +61 -68
  42. ai_edge_torch/generative/test/test_quantize.py +5 -5
  43. ai_edge_torch/generative/test/utils.py +54 -0
  44. ai_edge_torch/generative/utilities/loader.py +28 -15
  45. ai_edge_torch/generative/utilities/t5_loader.py +21 -20
  46. ai_edge_torch/odml_torch/export.py +40 -0
  47. ai_edge_torch/odml_torch/lowerings/__init__.py +1 -0
  48. ai_edge_torch/odml_torch/lowerings/_basic.py +44 -0
  49. ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +0 -2
  50. ai_edge_torch/odml_torch/lowerings/_layer_norm.py +78 -0
  51. ai_edge_torch/version.py +1 -1
  52. {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240914.dist-info}/METADATA +1 -1
  53. {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240914.dist-info}/RECORD +59 -63
  54. ai_edge_torch/_convert/fx_passes/_pass_base.py +0 -53
  55. ai_edge_torch/_convert/fx_passes/canonicalize_pass.py +0 -35
  56. ai_edge_torch/generative/examples/experimental/gemma/gemma.py +0 -219
  57. ai_edge_torch/generative/examples/experimental/tiny_llama/__init__.py +0 -14
  58. ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py +0 -205
  59. ai_edge_torch/generative/examples/phi2/__init__.py +0 -14
  60. ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +0 -67
  61. ai_edge_torch/generative/examples/phi2/phi2.py +0 -189
  62. ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +0 -176
  63. /ai_edge_torch/generative/examples/{experimental → openelm}/__init__.py +0 -0
  64. /ai_edge_torch/generative/examples/{experimental/gemma → phi}/__init__.py +0 -0
  65. /ai_edge_torch/generative/examples/{experimental/phi → smollm}/__init__.py +0 -0
  66. {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240914.dist-info}/LICENSE +0 -0
  67. {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240914.dist-info}/WHEEL +0 -0
  68. {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240914.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py

@@ -12,13 +12,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-# Example of building a TinyLlama model from the Edge Generative API layers.
+
+"""Example of building a TinyLlama model."""
 
 import os
-from pathlib import Path
+import pathlib
 
 from ai_edge_torch.generative.layers import attention
 from ai_edge_torch.generative.layers import builder
+from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
 import ai_edge_torch.generative.layers.model_config as cfg
 import ai_edge_torch.generative.utilities.loader as loading_utils
@@ -42,13 +44,12 @@ TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
 )
 
 
-class TinyLLamma(nn.Module):
+class TinyLlama(nn.Module):
   """A TinyLlama model built from the Edge Generative API layers."""
 
   def __init__(self, config: cfg.ModelConfig):
     super().__init__()
 
-    self.config = config
     # Construct model layers.
     self.lm_head = nn.Linear(
         config.embedding_dim, config.vocab_size, bias=config.lm_head_use_bias
@@ -56,18 +57,20 @@ class TinyLLamma(nn.Module):
     self.tok_embedding = nn.Embedding(
         config.vocab_size, config.embedding_dim, padding_idx=0
     )
+    # TinyLlama has only one block config.
+    block_config = config.block_config(0)
     self.transformer_blocks = nn.ModuleList(
-        attention.TransformerBlock(config) for _ in range(config.num_layers)
+        attention.TransformerBlock(block_config, config)
+        for _ in range(config.num_layers)
     )
     self.final_norm = builder.build_norm(
         config.embedding_dim,
         config.final_norm_config,
     )
+    attn_config = block_config.attn_config
     self.rope_cache = attn_utils.build_rope_cache(
         size=config.kv_cache_max,
-        dim=int(
-            config.attn_config.rotary_percentage * config.attn_config.head_dim
-        ),
+        dim=int(attn_config.rotary_percentage * attn_config.head_dim),
         base=10_000,
         condense_ratio=1,
         dtype=torch.float32,
@@ -80,16 +83,22 @@ class TinyLLamma(nn.Module):
     )
     self.config = config
 
-  # The model's forward function takes in additional k/v cache tensors
-  # and returns the updated k/v cache tensors to the caller.
-  # This can be eliminated if we handle k/v cache updates inside the model itself.
   @torch.inference_mode
-  def forward(self, idx: torch.Tensor, input_pos: torch.Tensor) -> torch.Tensor:
-    _, seq_len = idx.size()
+  def forward(
+      self,
+      tokens: torch.Tensor,
+      input_pos: torch.Tensor,
+      kv_cache: kv_utils.KVCache,
+  ) -> dict[torch.Tensor, kv_utils.KVCache]:
+    _, seq_len = tokens.size()
     assert self.config.max_seq_len >= seq_len, (
         f"Cannot forward sequence of length {seq_len}, max seq length is only"
        f" {self.config.max_seq_len}"
     )
+    assert len(self.transformer_blocks) == len(kv_cache.caches), (
+        "The number of transformer blocks and the number of KV cache entries"
+        " must be the same."
+    )
 
     cos, sin = self.rope_cache
     cos = cos.index_select(0, input_pos)
@@ -97,16 +106,20 @@ class TinyLLamma(nn.Module):
     mask = self.mask_cache.index_select(2, input_pos)
     mask = mask[:, :, :, : self.config.kv_cache_max]
 
-    # forward the model itself
-    x = self.tok_embedding(idx)  # token embeddings of shape (b, t, n_embd)
+    # token embeddings of shape (b, t, n_embd)
+    x = self.tok_embedding(tokens)
 
-    for _, block in enumerate(self.transformer_blocks):
-      x = block(x, (cos, sin), mask, input_pos)
+    updated_kv_entires = []
+    for i, block in enumerate(self.transformer_blocks):
+      kv_entry = kv_cache.caches[i] if kv_cache else None
+      x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
+      if kv_entry:
+        updated_kv_entires.append(kv_entry)
+    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
 
     x = self.final_norm(x)
-
-    res = self.lm_head(x)  # (b, t, vocab_size)
-    return res
+    logits = self.lm_head(x)  # (b, t, vocab_size)
+    return {"logits": logits, "kv_cache": updated_kv_cache}
 
 
 def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
@@ -131,55 +144,63 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       intermediate_size=5632,
   )
   norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM)
+  block_config = cfg.TransformerBlockConfig(
+      attn_config=attn_config,
+      ff_config=ff_config,
+      pre_attention_norm_config=norm_config,
+      post_attention_norm_config=norm_config,
+  )
   config = cfg.ModelConfig(
       vocab_size=32000,
       num_layers=22,
       max_seq_len=2048,
       embedding_dim=2048,
       kv_cache_max_len=kv_cache_max_len,
-      attn_config=attn_config,
-      ff_config=ff_config,
-      pre_attention_norm_config=norm_config,
-      post_attention_norm_config=norm_config,
+      block_configs=block_config,
      final_norm_config=norm_config,
       enable_hlfb=True,
   )
   return config
 
 
-def get_fake_model_config() -> cfg.ModelConfig:
-  config = get_model_config()
+def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
+  config = get_model_config(**kwargs)
   config.vocab_size = 128
   config.num_layers = 2
-  config.ff_config.intermediate_size = 64
+  # TinyLlama has only one block config.
+  config.block_config(0).ff_config.intermediate_size = 64
   return config
 
 
 def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
   config = get_model_config(**kwargs)
-  model = TinyLLamma(config)
+  model = TinyLlama(config)
   loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
   loader.load(model)
+  model.eval()
   return model
 
 
-def define_and_run() -> None:
+def define_and_run(checkpoint_path: str) -> None:
   """Instantiates and runs a TinyLlama model."""
 
-  current_dir = Path(__file__).parent.resolve()
+  current_dir = pathlib.Path(__file__).parent.resolve()
   tiny_llama_goldens = torch.load(current_dir / "tiny_llama_lm_logits.pt")
   kv_cache_max_len = 1024
-  checkpoint_path = os.path.join(Path.home(), "Downloads/llm_data/tiny_llama")
   model = build_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
   idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
-  tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.long, device="cpu")
+  tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.int, device="cpu")
  tokens[0, :4] = idx
-  input_pos = torch.arange(0, kv_cache_max_len)
-  lm_logits = model.forward(tokens, input_pos)
+  input_pos = torch.arange(0, kv_cache_max_len, dtype=torch.int)
+  kv = kv_utils.KVCache.from_model_config(model.config)
+  output = model.forward(tokens, input_pos, kv)
   assert torch.allclose(
-      tiny_llama_goldens, lm_logits[0, idx.shape[1] - 1, :], atol=1e-05
+      tiny_llama_goldens, output["logits"][0, idx.shape[1] - 1, :], atol=1e-02
  )
 
 
 if __name__ == "__main__":
-  define_and_run()
+  input_checkpoint_path = os.path.join(
+      pathlib.Path.home(), "Downloads/llm_data/tiny_llama"
+  )
+  define_and_run(input_checkpoint_path)
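
The forward signature above makes the KV cache an explicit input and output, so callers now thread the cache through every call instead of relying on buffers hidden inside the attention layers. Below is a minimal sketch of what a prefill-plus-decode loop looks like against this API; the checkpoint path and the greedy argmax sampling are placeholders for illustration, not part of the package.

import torch

from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
from ai_edge_torch.generative.layers import kv_cache as kv_utils

# Hypothetical checkpoint location; substitute a real TinyLlama checkpoint.
model = tiny_llama.build_model("/path/to/tiny_llama", kv_cache_max_len=1024)
kv = kv_utils.KVCache.from_model_config(model.config)

# Prefill: run the prompt once and keep the returned cache.
prompt = torch.tensor([[1, 2, 3, 4]], dtype=torch.int)
out = model.forward(prompt, torch.arange(0, 4, dtype=torch.int), kv)
kv = out["kv_cache"]

# Decode: feed one token at a time, carrying the cache forward.
token = out["logits"][:, -1, :].argmax(dim=-1, keepdim=True).to(torch.int)
for pos in range(4, 8):
  out = model.forward(token, torch.tensor([pos], dtype=torch.int), kv)
  kv = out["kv_cache"]
  token = out["logits"][:, -1, :].argmax(dim=-1, keepdim=True).to(torch.int)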
ai_edge_torch/generative/fx_passes/__init__.py

@@ -12,16 +12,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-from ai_edge_torch._convert.fx_passes import CanonicalizePass
-from ai_edge_torch._convert.fx_passes import run_passes
-from ai_edge_torch.generative.fx_passes.remove_sdpa_zero_mask_pass import RemoveSDPACompositeZeroMaskPass  # NOQA
+from ai_edge_torch import fx_pass_base
+from ai_edge_torch.fx_pass_base import CanonicalizePass
+from ai_edge_torch.generative.fx_passes.remove_sdpa_zero_mask_pass import RemoveSDPACompositeZeroMaskPass
 import torch
 
 
 def run_generative_passes(
     exported_program: torch.export.ExportedProgram,
 ) -> torch.export.ExportedProgram:
-  return run_passes(
+  return fx_pass_base.run_passes(
       exported_program,
       [
           RemoveSDPACompositeZeroMaskPass(),
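
For context, run_generative_passes still operates on a torch.export.ExportedProgram; only the pass infrastructure moved to the top-level fx_pass_base module. A hedged sketch of invoking it directly (the Toy module and example inputs are made up for illustration; the conversion pipeline normally calls this for you):

import torch

from ai_edge_torch.generative.fx_passes import run_generative_passes


class Toy(torch.nn.Module):

  def forward(self, x):
    return torch.nn.functional.softmax(x, dim=-1)


ep = torch.export.export(Toy(), (torch.randn(1, 8),))
ep = run_generative_passes(ep)  # returns a transformed ExportedProgram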
ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py

@@ -12,13 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
+from ai_edge_torch import fx_pass_base
 from ai_edge_torch import lowertools
-from ai_edge_torch._convert.fx_passes._pass_base import ExportedProgramPassBase
-from ai_edge_torch._convert.fx_passes._pass_base import ExportedProgramPassResult
 import torch
 
 
-class RemoveSDPACompositeZeroMaskPass(ExportedProgramPassBase):
+class RemoveSDPACompositeZeroMaskPass(fx_pass_base.ExportedProgramPassBase):
 
   def is_zero_tensor_node(self, node: torch.fx.Node):
     return node.target == torch.ops.aten.zeros.default
@@ -48,4 +47,4 @@ class RemoveSDPACompositeZeroMaskPass(ExportedProgramPassBase):
 
     exported_program.graph_module.graph.lint()
     exported_program.graph_module.recompile()
-    return ExportedProgramPassResult(exported_program, True)
+    return fx_pass_base.ExportedProgramPassResult(exported_program, True)
ai_edge_torch/generative/layers/attention.py

@@ -12,16 +12,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-# Common building blocks for Attention layer.
 
-from typing import Optional, Tuple
+"""Common building blocks for Attention layer."""
 
-import ai_edge_torch.generative.layers.builder as builder
-from ai_edge_torch.generative.layers.kv_cache import KVCache
+from typing import Optional, Tuple, Union
+
+from ai_edge_torch.generative.layers import builder
+from ai_edge_torch.generative.layers import kv_cache as kv_utils
+from ai_edge_torch.generative.layers import scaled_dot_product_attention as sdpa
 import ai_edge_torch.generative.layers.model_config as cfg
 import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb
-from ai_edge_torch.generative.layers.scaled_dot_product_attention import scaled_dot_product_attention  # NOQA
-from ai_edge_torch.generative.layers.scaled_dot_product_attention import scaled_dot_product_attention_with_hlfb  # NOQA
 import torch
 from torch import nn
 
@@ -55,29 +55,35 @@ def _embed_rope(
 
 class TransformerBlock(nn.Module):
 
-  def __init__(self, config: cfg.ModelConfig) -> None:
+  def __init__(
+      self,
+      config: cfg.TransformerBlockConfig,
+      model_config: cfg.ModelConfig,
+  ) -> None:
     """Initialize an instance of the TransformerBlock.
 
     Args:
-      config (cfg.ModelConfig): the configuration object for this transformer
-        block.
+      config (cfg.TransformerBlockConfig): the configuration object for this
+        transformer block.
+      model_config (cfg.ModelConfig): the configuration object for the model
+        this transformer block belongs to.
     """
-
     super().__init__()
     self.pre_atten_norm = builder.build_norm(
-        config.embedding_dim, config.pre_attention_norm_config
+        model_config.embedding_dim,
+        config.pre_attention_norm_config,
     )
     self.atten_func = CausalSelfAttention(
-        config.batch_size,
-        config.embedding_dim,
+        model_config.batch_size,
+        model_config.embedding_dim,
         config.attn_config,
-        config.kv_cache_max,
-        config.enable_hlfb,
+        model_config.enable_hlfb,
     )
     self.post_atten_norm = builder.build_norm(
-        config.embedding_dim, config.post_attention_norm_config
+        model_config.embedding_dim,
+        config.post_attention_norm_config,
     )
-    self.ff = builder.build_ff(config.embedding_dim, config.ff_config)
+    self.ff = builder.build_ff(model_config.embedding_dim, config.ff_config)
     self.config = config
 
   def forward(
@@ -86,7 +92,8 @@ class TransformerBlock(nn.Module):
       rope: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
      mask: Optional[torch.Tensor] = None,
       input_pos: Optional[torch.Tensor] = None,
-  ) -> torch.Tensor:
+      kv_cache: kv_utils.KVCacheEntry = None,
+  ) -> Union[torch.Tensor, Tuple[torch.Tensor, kv_utils.KVCacheEntry]]:
     """Forward function of the TransformerBlock.
 
     Args:
@@ -94,24 +101,34 @@ class TransformerBlock(nn.Module):
      rope (Tuple[torch.Tensor, torch.Tensor]): the input rope tensor.
       mask (torch.Tensor): the optional mask tensor.
       input_pos (torch.Tensor): the optional input position tensor.
+      kv_cache (KVCacheEntry): the optional kv cache entry.
 
     Returns:
-      output activation from this transformer block.
+      output activation from this transformer block, and updated kv cache (if
+      passed in).
     """
-
+    kv = None
    if self.config.parallel_residual:
       x_norm = self.pre_atten_norm(x)
-      attn_out = self.atten_func(x_norm, rope, mask, input_pos)
+      atten_func_out = self.atten_func(x_norm, rope, mask, input_pos, kv_cache)
+      if kv_cache is None:
+        attn_out = atten_func_out
+      else:
+        attn_out, kv = atten_func_out
       ff_out = self.ff(x_norm)
       output = x + attn_out + ff_out
     else:
       x_norm = self.pre_atten_norm(x)
-      attn_out = self.atten_func(x_norm, rope, mask, input_pos)
+      atten_func_out = self.atten_func(x_norm, rope, mask, input_pos, kv_cache)
+      if kv_cache is None:
+        attn_out = atten_func_out
+      else:
+        attn_out, kv = atten_func_out
       x = x + attn_out
       x_norm = self.post_atten_norm(x)
       output = x + self.ff(x_norm)
 
-    return output
+    return output if kv is None else (output, kv)
 
 
 class CausalSelfAttention(nn.Module):
@@ -121,7 +138,6 @@ class CausalSelfAttention(nn.Module):
       batch_size: int,
       dim: int,
       config: cfg.AttentionConfig,
-      kv_cache_max: int,
       enable_hlfb: bool,
   ) -> None:
     """Initialize an instance of CausalSelfAttention.
@@ -130,12 +146,9 @@ class CausalSelfAttention(nn.Module):
       batch_size (int): batch size of the input tensor.
       dim (int): causal attention's input/output dimmension.
       config (cfg.AttentionConfig): attention specific configurations.
-      kv_cache_max (int): determines the size of the KV Cache buffer, if
-        enabled.
       enable_hlfb (bool): whether hlfb is enabled or not.
     """
     super().__init__()
-    self.config = config
     self.kv_cache = None
     self.batch_size = batch_size
     qkv_shape = (
@@ -147,21 +160,17 @@ class CausalSelfAttention(nn.Module):
     self.output_projection = nn.Linear(
         output_shape, dim, bias=config.output_proj_use_bias
     )
-
-    # Build a k/v cache with size (batch_size, kv_cache_max, n_heads, head_dim).
-    if config.enable_kv_cache:
-      self.kv_cache = KVCache(
-          batch_size,
-          kv_cache_max,
-          config.num_query_groups,
-          config.head_dim,
-          enable_hlfb,
-      )
-
-    if enable_hlfb:
-      self.sdpa_func = scaled_dot_product_attention_with_hlfb
-    else:
-      self.sdpa_func = scaled_dot_product_attention
+    self.query_norm = builder.build_norm(
+        config.head_dim, config.query_norm_config
+    )
+    self.key_norm = builder.build_norm(config.head_dim, config.key_norm_config)
+    self.config = config
+    self.enable_hlfb = enable_hlfb
+    self.sdpa_func = (
+        sdpa.scaled_dot_product_attention_with_hlfb
+        if enable_hlfb
+        else sdpa.scaled_dot_product_attention
+    )
 
   def forward(
       self,
@@ -169,7 +178,8 @@ class CausalSelfAttention(nn.Module):
      rope: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
       mask: Optional[torch.Tensor] = None,
       input_pos: Optional[torch.Tensor] = None,
-  ) -> torch.Tensor:
+      kv_cache: Optional[kv_utils.KVCacheEntry] = None,
+  ) -> Union[torch.Tensor, Tuple[torch.Tensor, kv_utils.KVCacheEntry]]:
     """Forward function of the CausalSelfAttention layer, which can support
 
     MQA, GQA and MHA.
@@ -179,9 +189,11 @@ class CausalSelfAttention(nn.Module):
      rope (Tuple[torch.Tensor, torch.Tensor]): the input rope tensor.
       mask (torch.Tensor): the optional mask tensor.
       input_pos (torch.Tensor): the optional input position tensor.
+      kv_cache (KVCacheEntry): The KV cache entry corresponding to this module.
 
     Returns:
-      output activation from this self attention layer.
+      output activation from this self attention layer, and the updated
+      KV Cach Entry (if passed in).
     """
     # Batch size, sequence length, embedding dimensionality.
     B, T, E = x.size()
@@ -216,6 +228,9 @@ class CausalSelfAttention(nn.Module):
         dim=-1,
     )
 
+    q = self.query_norm(q)
+    k = self.key_norm(k)
+
     q = q.reshape(B, T, -1, self.config.head_dim)
     k = k.reshape(B, T, -1, self.config.head_dim)
     v = v.reshape(B, T, -1, self.config.head_dim)
@@ -224,9 +239,11 @@ class CausalSelfAttention(nn.Module):
       n_elem = int(self.config.rotary_percentage * self.config.head_dim)
       q, k = _embed_rope(q, k, n_elem, rope)
 
-    if self.kv_cache is not None:
-      # TODO(haoliang): Handle when execeeding max sequence length.
-      k, v = self.kv_cache.update_cache(input_pos, k, v)
+    if kv_cache is not None:
+      kv_cache = kv_utils.update(
+          kv_cache, input_pos, k, v, enable_hlfb=self.enable_hlfb
+      )
+      k, v = kv_cache.k_cache, kv_cache.v_cache
 
     y = self.sdpa_func(
         q,
@@ -240,7 +257,7 @@
 
     # Compute the output projection.
     y = self.output_projection(y)
-    return y
+    return y if kv_cache is None else (y, kv_cache)
 
 
 class SelfAttention(CausalSelfAttention):
@@ -251,16 +268,19 @@ class SelfAttention(CausalSelfAttention):
      x: torch.Tensor,
       rope: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
       input_pos: Optional[torch.Tensor] = None,
-  ) -> torch.Tensor:
+      kv_cache: Optional[kv_utils.KVCacheEntry] = None,
+  ) -> Union[torch.Tensor, Tuple[torch.Tensor, kv_utils.KVCacheEntry]]:
     """Forward function of the SelfAttention layer, which can support MQA, GQA and MHA.
 
     Args:
      x (torch.Tensor): the input tensor.
       rope (Tuple[torch.Tensor, torch.Tensor]): the input rope tensor.
       input_pos (torch.Tensor): the optional input position tensor.
+      kv_cache (KVCacheEntry): The KV cache entry corresponding to this module.
 
     Returns:
-      output activation from this self attention layer.
+      output activation from this self attention layer, and the updated
+      KV Cach Entry (if passed in).
     """
     B, T, _ = x.size()
     return super().forward(
@@ -279,9 +299,8 @@ class CrossAttention(nn.Module):
      query_dim: int,
       cross_dim: int,
       config: cfg.AttentionConfig,
-      kv_cache_max: int,
       enable_hlfb: bool,
-  ) -> None:
+  ):
     """Initialize an instance of CrossAttention.
 
     Args:
@@ -289,8 +308,6 @@ class CrossAttention(nn.Module):
      query_dim (int): query tensor's dimension.
       cross_dim (int): cross attention's dimensions, for key and value tensors.
       config (cfg.AttentionConfig): attention specific configurations.
-      kv_cache_max (int): determines the size of the KV Cache buffer, if
-        enabled.
       enable_hlfb (bool): whether hlfb is enabled or not.
     """
     super().__init__()
@@ -309,21 +326,11 @@ class CrossAttention(nn.Module):
        query_dim, query_dim, bias=config.output_proj_use_bias
     )
 
-    self.kv_cache = None
-    # Build a k/v cache with size (batch_size, kv_cache_max, n_heads, head_dim).
-    if config.enable_kv_cache:
-      self.kv_cache = KVCache(
-          batch_size,
-          kv_cache_max,
-          config.num_query_groups,
-          self.config.head_dim,
-          enable_hlfb,
-      )
-
-    if enable_hlfb:
-      self.sdpa_func = scaled_dot_product_attention_with_hlfb
-    else:
-      self.sdpa_func = scaled_dot_product_attention
+    self.sdpa_func = (
+        sdpa.scaled_dot_product_attention_with_hlfb
+        if enable_hlfb
+        else sdpa.scaled_dot_product_attention
+    )
 
   def forward(
       self,
@@ -332,6 +339,7 @@ class CrossAttention(nn.Module):
      rope: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
       mask: Optional[torch.Tensor] = None,
       input_pos: Optional[torch.Tensor] = None,
+      kv_cache: Optional[kv_utils.KVCacheEntry] = None,
   ):
     """Forward function of the CrossAttention layer.
 
@@ -342,6 +350,7 @@ class CrossAttention(nn.Module):
      mask (torch.Tensor): the optional mask tensor can be broadcaseted to shape
         [B, n_heads, target_seq_len, source_seq_len].
       input_pos (torch.Tensor): the optional input position tensor.
+      kv_cache (KVCacheEntry): The KV cache entry corresponding to this module.
 
     Returns:
       output activation from this cross attention layer.
@@ -363,9 +372,11 @@ class CrossAttention(nn.Module):
      n_elem = int(self.config.rotary_percentage * self.config.head_dim)
       q, k = _embed_rope(q, k, n_elem, rope)
 
-    if self.kv_cache is not None:
-      # TODO(haoliang): Handle when execeeding max sequence length.
-      k, v = self.kv_cache.update_cache(input_pos, k, v)
+    if kv_cache is not None:
+      kv_cache = kv_utils.update(
+          kv_cache, input_pos, k, v, enable_hlfb=self.enable_hlfb
+      )
+      k, v = kv_cache.k_cache, kv_cache.v_cache
     if mask is None:
       mask = torch.zeros(
           (batch_size, 1, target_seq_len, source_seq_len), dtype=torch.float32
@@ -375,4 +386,4 @@
 
     # Compute the output projection.
     y = self.output_projection(y)
-    return y
+    return y if kv_cache is None else (y, kv_cache)
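
Because the per-layer cache is now passed in rather than owned by the layer, a block's return type depends on whether an entry was supplied. A rough sketch of exercising a single TransformerBlock with the fake TinyLlama config; the single-token shapes and the all-zero mask are illustrative choices, not taken from the package.

import torch

import ai_edge_torch.generative.layers.attention_utils as attn_utils
from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
from ai_edge_torch.generative.layers import attention
from ai_edge_torch.generative.layers import kv_cache as kv_utils

config = tiny_llama.get_fake_model_config(kv_cache_max_len=16)
block_config = config.block_config(0)
block = attention.TransformerBlock(block_config, config)
cache = kv_utils.KVCache.from_model_config(config)

# One query token at position 0, plus the matching rope slice and a zero mask.
input_pos = torch.tensor([0], dtype=torch.int)
attn_config = block_config.attn_config
cos, sin = attn_utils.build_rope_cache(
    size=config.kv_cache_max,
    dim=int(attn_config.rotary_percentage * attn_config.head_dim),
    base=10_000,
    condense_ratio=1,
    dtype=torch.float32,
)
rope = (cos.index_select(0, input_pos), sin.index_select(0, input_pos))
mask = torch.zeros((1, 1, 1, config.kv_cache_max), dtype=torch.float32)
x = torch.zeros((1, 1, config.embedding_dim))

# Without a cache entry the block returns only the activation;
# with an entry it returns (activation, updated KVCacheEntry).
out = block(x, rope, mask, input_pos)
out, updated_entry = block(x, rope, mask, input_pos, cache.caches[0])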
ai_edge_torch/generative/layers/builder.py

@@ -13,6 +13,8 @@
 # limitations under the License.
 # ==============================================================================
 # Builder class for individual components.
+from typing import Callable
+
 import ai_edge_torch.generative.layers.feed_forward as feed_forward
 import ai_edge_torch.generative.layers.model_config as cfg
 import ai_edge_torch.generative.layers.normalization as normalization
@@ -21,20 +23,34 @@ from torch import nn
 import torch.nn.functional as F
 
 
-class GeGLU(nn.Module):
-  """GeGLU is an activation function which is a variant of GELU.
+def build_glu(
+    act: Callable[[torch.Tensor], torch.Tensor], gate_is_front: bool = False
+) -> Callable[[torch.Tensor], torch.Tensor]:
+  """Builds an activation function with GLU (Gated Linear Unit).
+
+  If gate_is_front is True,
+    f(x) = act(x) * y
+  otherwise,
+    f(x) = x * act(y),
+  where x is the first half of the input and y is the second half of the input.
+
+  Args:
+    act (Callable[[torch.Tensor], torch.Tensor]): activation function to apply
+      to the gate.
+    gate_is_front: whether the gate is in front half of the input. Other part is
+      the output in GLU.
 
-  GeGLU(x) = (xW+b) * GELU(xV+c)
-  See: https://arxiv.org/abs/2002.05202v1
+  Returns:
+    A callable activation function with GLU.
   """
 
-  def __init__(self, d_in: int, d_out: int):
-    super().__init__()
-    self.proj = nn.Linear(d_in, d_out * 2)
+  def _glu(x):
+    x, y = x.chunk(2, dim=-1)
+    if gate_is_front:
+      return act(x) * y
+    return x * act(y)
 
-  def forward(self, x: torch.Tensor):
-    x, gate = self.proj(x).chunk(2, dim=-1)
-    return x * F.gelu(gate)
+  return _glu
 
 
 def build_norm(dim: int, config: cfg.NormalizationConfig):
@@ -59,9 +75,11 @@ def build_norm(dim: int, config: cfg.NormalizationConfig):
        zero_centered_gamma=config.zero_centered,
     )
   elif config.type == cfg.NormalizationType.LAYER_NORM:
-    return nn.LayerNorm(dim, eps=config.epsilon)
+    return normalization.LayerNorm(dim, config.epsilon, config.enable_hlfb)
   elif config.type == cfg.NormalizationType.GROUP_NORM:
-    return nn.GroupNorm(config.group_num, dim, config.epsilon)
+    return normalization.GroupNorm(
+        config.group_num, dim, config.epsilon, config.enable_hlfb
+    )
   else:
     raise ValueError("Unsupported norm type.")
 
@@ -71,7 +89,7 @@ def build_ff(dim: int, config: cfg.FeedForwardConfig):
 
   Args:
     dim (int): dimension of the input tensor.
-    config (`ModelConfig` object): the model configuration.
+    config (`FeedForwardConfig` object): the model configuration.
 
   Returns:
     The constructed `nn.Module` feedforward layer.
@@ -97,6 +115,10 @@ def build_ff(dim: int, config: cfg.FeedForwardConfig):
      hidden_dim=config.intermediate_size,
       activation=activation,
       use_bias=config.use_bias,
+      use_glu=(
+          config.activation.type == cfg.ActivationType.GE_GLU
+          or config.activation.type == cfg.ActivationType.SILU_GLU
+      ),
       pre_ff_norm=pre_ff_norm,
       post_ff_norm=post_ff_norm,
  )
@@ -127,8 +149,10 @@ def get_activation(config: cfg.ActivationConfig):
    # See: https://github.com/hendrycks/GELUs
     return lambda x: x * F.sigmoid(1.702 * x)
   elif config.type == cfg.ActivationType.GE_GLU:
-    return GeGLU(config.dim_in, config.dim_out)
+    return build_glu(F.gelu, config.gate_is_front)
   elif config.type == cfg.ActivationType.RELU:
     return F.relu
+  elif config.type == cfg.ActivationType.SILU_GLU:
+    return build_glu(F.silu, config.gate_is_front)
   else:
     raise ValueError("Unsupported activation type.")
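
Note that the projection the old GeGLU module owned now has to come from the feed-forward layer itself (hence the new use_glu flag passed to build_ff); build_glu only performs the chunk-and-gate step. A small sketch of the gating behaviour, with made-up tensor sizes:

import torch
import torch.nn.functional as F

from ai_edge_torch.generative.layers.builder import build_glu

geglu = build_glu(F.gelu)  # gate_is_front=False: x * gelu(y)
silu_glu = build_glu(F.silu, gate_is_front=True)  # silu(x) * y

h = torch.randn(2, 4, 16)  # last dim must be even; output is half its size
assert geglu(h).shape == (2, 4, 8)

# Matches the removed GeGLU module once the preceding Linear doubles the width.
x, y = h.chunk(2, dim=-1)
assert torch.allclose(geglu(h), x * F.gelu(y))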