ai-edge-torch-nightly 0.5.0.dev20250423__py3-none-any.whl → 0.5.0.dev20250424__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
@@ -17,10 +17,10 @@

  from typing import List, Optional, Tuple

+ from ai_edge_torch.generative.layers import attention
  from ai_edge_torch.generative.layers import builder
  from ai_edge_torch.generative.layers import kv_cache as kv_utils
  import ai_edge_torch.generative.layers.attention_utils as attn_utils
- from ai_edge_torch.generative.layers.experimental import attention
  import ai_edge_torch.generative.layers.model_config as cfg
  import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb
  from ai_edge_torch.generative.utilities import export_config as export_cfg
@@ -21,6 +21,7 @@ from ai_edge_torch.generative.layers import builder
  from ai_edge_torch.generative.layers import kv_cache as kv_utils
  from ai_edge_torch.generative.layers import lora as lora_utils
  from ai_edge_torch.generative.layers import scaled_dot_product_attention as sdpa
+ from ai_edge_torch.generative.layers import sdpa_with_kv_update
  import ai_edge_torch.generative.layers.model_config as cfg
  import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb
  import torch
@@ -142,11 +143,6 @@ class CausalSelfAttention(nn.Module):
  self.key_norm = builder.build_norm(config.head_dim, config.key_norm_config)
  self.config = config
  self.enable_hlfb = enable_hlfb
- self.sdpa_func = (
- sdpa.scaled_dot_product_attention_with_hlfb
- if enable_hlfb
- else sdpa.scaled_dot_product_attention
- )

  def forward(
  self,
@@ -174,7 +170,7 @@ class CausalSelfAttention(nn.Module):
  KV Cach Entry (if passed in).
  """
  # Batch size, sequence length, embedding dimensionality.
- B, T, E = x.size()
+ B, T, _ = x.size()
  qkv = self.qkv_projection(x)

  # Assemble into a number of query groups to support MHA, MQA and GQA.
@@ -218,19 +214,9 @@ class CausalSelfAttention(nn.Module):
  cos, sin = rope
  q, k = rotary_pos_emb.apply_rope_inline(q, k, cos, sin)

- if kv_cache is not None:
- kv_cache = kv_utils.update(kv_cache, input_pos, k, v)
- k, v = kv_cache.k_cache, kv_cache.v_cache
-
- sdpa_out = self.sdpa_func(
- q,
- k,
- v,
- self.config.head_dim,
- mask=mask,
- softcap=self.config.logit_softcap,
+ sdpa_out, kv_cache = sdpa_with_kv_update.sdpa_with_kv_update(
+ q, k, v, kv_cache, input_pos, mask, self.config, self.enable_hlfb
  )
- sdpa_out = sdpa_out.reshape(B, T, -1)

  # Compute the output projection.
  y = self.output_projection(sdpa_out)
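
The hunk above replaces the inline "update the KV cache, then run scaled dot product attention" logic with a single call into sdpa_with_kv_update. As a rough illustration of that consolidated pattern only (not the library's API), here is a minimal self-contained PyTorch sketch; the function name, shapes, and the bare tuple cache are assumptions for the sketch:

import torch
import torch.nn.functional as F

def toy_sdpa_with_kv_update(q, k, v, k_cache, v_cache, input_pos, mask=None):
  # q, k, v: (batch, seq, heads, head_dim); caches: (batch, max_seq, heads, head_dim).
  # input_pos: 1-D long tensor of write positions; toy uses the same head count
  # for q and k/v, and no mask for brevity.
  k_cache = k_cache.index_copy(1, input_pos, k)
  v_cache = v_cache.index_copy(1, input_pos, v)
  # torch's SDPA expects (batch, heads, seq, head_dim).
  out = F.scaled_dot_product_attention(
      q.transpose(1, 2),
      k_cache.transpose(1, 2),
      v_cache.transpose(1, 2),
      attn_mask=mask,
  )
  # Back to (batch, seq, heads * head_dim), as the output projection expects.
  out = out.transpose(1, 2).reshape(q.shape[0], q.shape[1], -1)
  return out, (k_cache, v_cache)

q = torch.randn(1, 1, 8, 64)  # one decode step
k = torch.randn(1, 1, 8, 64)
v = torch.randn(1, 1, 8, 64)
k_cache = torch.zeros(1, 128, 8, 64)
v_cache = torch.zeros(1, 128, 8, 64)
out, (k_cache, v_cache) = toy_sdpa_with_kv_update(
    q, k, v, k_cache, v_cache, torch.tensor([5])
)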
@@ -12,16 +12,16 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.
  # ==============================================================================
- # Common utility functions for data loading etc.
- from dataclasses import dataclass
+
+ """Common utility functions for data loading etc."""
+
  from typing import Tuple
+
  from ai_edge_torch.generative.layers import kv_cache as kv_utils
  from ai_edge_torch.generative.layers import scaled_dot_product_attention as sdpa_default
  from ai_edge_torch.generative.layers.experimental import kv_cache as kv_utils_experimental
  from ai_edge_torch.generative.layers.experimental import scaled_dot_product_attention as sdpa
  import ai_edge_torch.generative.layers.model_config as cfg
- from ai_edge_torch.generative.utilities import types
- from multipledispatch import dispatch
  import torch


@@ -33,32 +33,27 @@ def sdpa_with_kv_update(
  input_pos: torch.Tensor,
  mask: torch.Tensor,
  config: cfg.AttentionConfig,
+ enable_hlfb: bool,
  ) -> Tuple[torch.Tensor, kv_utils.KVCacheEntry]:
- return sdpa_with_kv_update_impl(
- kv.kv_layout[0](), # key layout
- kv.kv_layout[1](), # value layout
- query=query,
- key=key,
- value=value,
- kv=kv,
- input_pos=input_pos,
- mask=mask,
- config=config,
+ """Wrapper function for scaled dot product attention with KV cache update."""
+ if kv is not None and kv.kv_layout == kv_utils.KV_LAYOUT_TRANSPOSED:
+ return _sdpa_with_kv_update_transposed(
+ query, key, value, kv, input_pos, mask, config
+ )
+ return _sdpa_with_kv_update_default(
+ query, key, value, kv, input_pos, mask, config, enable_hlfb
  )


- @dispatch(types.BNTH, types.BNHT)
- def sdpa_with_kv_update_impl(
- k_type, v_type, *args, **kwargs
+ def _sdpa_with_kv_update_transposed(
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ kv: kv_utils.KVCacheEntry,
+ input_pos: torch.Tensor,
+ mask: torch.Tensor,
+ config: cfg.AttentionConfig,
  ) -> Tuple[torch.Tensor, kv_utils.KVCacheEntry]:
- query = kwargs["query"]
- key = kwargs["key"]
- value = kwargs["value"]
- kv = kwargs["kv"]
- input_pos = kwargs["input_pos"]
- mask = kwargs["mask"]
- config = kwargs["config"]
-
  # Transpose k/v to specific layout for GPU implementation.
  b, seq_len, n, h = query.shape
  g = n // config.num_query_groups
@@ -74,9 +69,8 @@ def sdpa_with_kv_update_impl(
  1, -1, config.head_dim, seq_len
  ) # 1, bk, h, s

- if kv is not None:
- kv = kv_utils_experimental.update(kv, input_pos, key, value)
- key, value = kv.k_cache, kv.v_cache
+ kv = kv_utils_experimental.update(kv, input_pos, key, value)
+ key, value = kv.k_cache, kv.v_cache

  sdpa_out = sdpa.scaled_dot_product_attention(
  kv,
@@ -95,24 +89,26 @@ def sdpa_with_kv_update_impl(
  return sdpa_out, kv


- @dispatch(object, object)
- def sdpa_with_kv_update_impl(
- k_type, v_type, *args, **kwargs
+ def _sdpa_with_kv_update_default(
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ kv: kv_utils.KVCacheEntry,
+ input_pos: torch.Tensor,
+ mask: torch.Tensor,
+ config: cfg.AttentionConfig,
+ enable_hlfb: bool,
  ) -> Tuple[torch.Tensor, kv_utils.KVCacheEntry]:
- query = kwargs["query"]
- key = kwargs["key"]
- value = kwargs["value"]
- kv = kwargs["kv"]
- input_pos = kwargs["input_pos"]
- mask = kwargs["mask"]
- config = kwargs["config"]
-
  b, seq_len, _, _ = query.shape
  if kv is not None:
  kv = kv_utils.update(kv, input_pos, key, value)
  key, value = kv.k_cache, kv.v_cache

- sdpa_out = sdpa_default.scaled_dot_product_attention(
+ if enable_hlfb:
+ sdpa_func = sdpa_default.scaled_dot_product_attention_with_hlfb
+ else:
+ sdpa_func = sdpa_default.scaled_dot_product_attention
+ sdpa_out = sdpa_func(
  query,
  key,
  value,
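
For orientation, the transposed path above stores the cache in the GPU-oriented layouts implied by the removed @dispatch(types.BNTH, types.BNHT) signature: keys roughly as (1, batch*kv_heads, seq, head_dim) and values as (1, batch*kv_heads, head_dim, seq), versus the default (batch, seq, kv_heads, head_dim). A self-contained sketch of that reshaping with illustrative sizes; the exact layouts are an inference from this diff, not a confirmed spec:

import torch

b, s, n_kv, h = 2, 16, 4, 64            # batch, cached seq len, kv heads, head dim
k_default = torch.zeros(b, s, n_kv, h)  # default layout: (b, s, n_kv, h)
v_default = torch.zeros(b, s, n_kv, h)

# BNTH for keys: (1, b*n_kv, s, h); BNHT for values: (1, b*n_kv, h, s).
k_transposed = k_default.permute(0, 2, 1, 3).reshape(1, b * n_kv, s, h)
v_transposed = v_default.permute(0, 2, 3, 1).reshape(1, b * n_kv, h, s)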
@@ -41,7 +41,7 @@ class TestModelConversion(googletest.TestCase):
  )
  )

- def _get_params(self, enable_hlfb: bool):
+ def _get_params(self, enable_hlfb: bool, kv_layout: kv_cache.KVLayout):
  """Returns a model, edge model and the kwargs to use for testing."""
  config = toy_model_with_kv_cache.get_model_config()
  config.enable_hlfb = enable_hlfb
@@ -49,7 +49,7 @@ class TestModelConversion(googletest.TestCase):
  tokens, input_pos = torch.tensor([[1]], dtype=torch.int), torch.tensor(
  [10], dtype=torch.int
  )
- kv = kv_cache.KVCache.from_model_config(config)
+ kv = kv_cache.KVCache.from_model_config(config, kv_layout=kv_layout)
  kwargs = {
  "tokens": tokens,
  "input_pos": input_pos,
@@ -65,8 +65,12 @@ class TestModelConversion(googletest.TestCase):
  )
  return pytorch_model, edge_model, kwargs

- def _test_model_with_kv_cache(self, enable_hlfb: bool):
- pytorch_model, edge_model, kwargs = self._get_params(enable_hlfb)
+ def _test_model_with_kv_cache(
+ self,
+ enable_hlfb: bool = False,
+ kv_layout: kv_cache.KVLayout = kv_cache.KV_LAYOUT_DEFAULT,
+ ):
+ pytorch_model, edge_model, kwargs = self._get_params(enable_hlfb, kv_layout)

  self.assertTrue(
  test_utils.compare_tflite_torch(
@@ -95,13 +99,22 @@ class TestModelConversion(googletest.TestCase):
  def test_toy_model_with_kv_cache_with_hlfb(self):
  self._test_model_with_kv_cache(enable_hlfb=True)

+ @googletest.skipIf(
+ ai_edge_torch.config.in_oss,
+ reason="tests with custom ops are not supported in oss",
+ )
+ def test_toy_model_with_kv_cache_transposed(self):
+ self._test_model_with_kv_cache(kv_layout=kv_cache.KV_LAYOUT_TRANSPOSED)
+
  @googletest.skipIf(
  ai_edge_torch.config.in_oss,
  reason="tests with custom ops are not supported in oss",
  )
  def test_toy_model_has_dus_op(self):
  """Tests that the model has the dynamic update slice op."""
- _, edge_model, _ = self._get_params(enable_hlfb=True)
+ _, edge_model, _ = self._get_params(
+ enable_hlfb=True, kv_layout=kv_cache.KV_LAYOUT_DEFAULT
+ )
  interpreter_ = interpreter.InterpreterWithCustomOps(
  custom_op_registerers=["GenAIOpsRegisterer"],
  model_content=edge_model.tflite_model(),
@@ -112,7 +125,14 @@ class TestModelConversion(googletest.TestCase):
  op_names = [op["op_name"] for op in interpreter_._get_ops_details()]
  self.assertIn("DYNAMIC_UPDATE_SLICE", op_names)

- def _test_multisig_model(self, config, pytorch_model, atol, rtol):
+ def _test_multisig_model(
+ self,
+ config,
+ pytorch_model,
+ atol,
+ rtol,
+ kv_layout=kv_cache.KV_LAYOUT_DEFAULT,
+ ):
  # prefill
  seq_len = 10
  prefill_tokens = torch.zeros((1, seq_len), dtype=torch.int, device="cpu")
@@ -124,7 +144,7 @@ class TestModelConversion(googletest.TestCase):
  decode_token = torch.tensor([[1]], dtype=torch.int)
  decode_input_pos = torch.tensor([5], dtype=torch.int)

- kv = kv_cache.KVCache.from_model_config(config)
+ kv = kv_cache.KVCache.from_model_config(config, kv_layout=kv_layout)

  edge_model = (
  ai_edge_torch.signature(
@@ -160,7 +180,7 @@ class TestModelConversion(googletest.TestCase):
  kv,
  signature_name="prefill",
  atol=atol,
- rtol=atol,
+ rtol=rtol,
  )
  )

@@ -173,7 +193,7 @@ class TestModelConversion(googletest.TestCase):
  kv,
  signature_name="decode",
  atol=atol,
- rtol=atol,
+ rtol=rtol,
  )
  )

@@ -186,6 +206,21 @@ class TestModelConversion(googletest.TestCase):
  pytorch_model = tiny_llama.TinyLlama(config).eval()
  self._test_multisig_model(config, pytorch_model, atol=1e-5, rtol=1e-5)

+ @googletest.skipIf(
+ ai_edge_torch.config.in_oss,
+ reason="tests with custom ops are not supported in oss",
+ )
+ def test_tiny_llama_multisig_kv_layout_transposed(self):
+ config = tiny_llama.get_fake_model_config()
+ pytorch_model = tiny_llama.TinyLlama(config).eval()
+ self._test_multisig_model(
+ config,
+ pytorch_model,
+ atol=1e-5,
+ rtol=1e-5,
+ kv_layout=kv_cache.KV_LAYOUT_TRANSPOSED,
+ )
+

  if __name__ == "__main__":
  googletest.main()
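
The new tests exercise the kv_layout plumbing end to end. A hedged usage sketch follows, using only names that appear in this diff; how `config` is built is left abstract (the tests use a toy model config), and the helper name make_cache is hypothetical:

from ai_edge_torch.generative.layers import kv_cache

def make_cache(config, transposed=False):
  # Pick the transposed (GPU-oriented) cache layout or the default one.
  layout = (
      kv_cache.KV_LAYOUT_TRANSPOSED if transposed else kv_cache.KV_LAYOUT_DEFAULT
  )
  return kv_cache.KVCache.from_model_config(config, kv_layout=layout)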
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
  # limitations under the License.
  # ==============================================================================

- __version__ = "0.5.0.dev20250423"
+ __version__ = "0.5.0.dev20250424"
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: ai-edge-torch-nightly
- Version: 0.5.0.dev20250423
+ Version: 0.5.0.dev20250424
  Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
  Home-page: https://github.com/google-ai-edge/ai-edge-torch
  Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
@@ -2,7 +2,7 @@ ai_edge_torch/__init__.py,sha256=8sPR_5uXJA4NEE0nIwNdSl-ADOJEoR8hAgYvBQDY70Y,120
  ai_edge_torch/_config.py,sha256=AiqhbcheF7j_ozIGDLC89k1we95aVgFDa-tR6h7UI0s,2529
  ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
  ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
- ai_edge_torch/version.py,sha256=DjzQwP8czvLmUu-dJhnWVQJHOuaOqJJKuH2_TOViMvg,706
+ ai_edge_torch/version.py,sha256=Nixp49eAXZPPMWEWkqpm_M4Mi_WGPx-I8q2noKuh0hw,706
  ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
  ai_edge_torch/_convert/conversion.py,sha256=dOr3TUfF0UCvkmlUrMqKvgaN4jh3lJ9XFuO-sHaAmIw,5521
  ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -67,7 +67,7 @@ ai_edge_torch/generative/examples/gemma/verify_gemma2.py,sha256=IoBhEMwH07-tFm5-
  ai_edge_torch/generative/examples/gemma/verify_util.py,sha256=tR8RflXocDZqvuStyw9aFlzuiTllEC8rNnjrxms6_Is,5727
  ai_edge_torch/generative/examples/gemma3/__init__.py,sha256=JaAnrFoXTl3RJX97XspklkTyqOHVyAgRJsZtzNDd10c,671
  ai_edge_torch/generative/examples/gemma3/convert_gemma3_to_tflite.py,sha256=szssSBrIUYdNIoU7LHdAq7wCqgjaY6qbV8yvTgg796Q,2945
- ai_edge_torch/generative/examples/gemma3/decoder.py,sha256=n6ZQfqNEHuOhY7Pu21bb8Eax8yn2Sx5osTKJKmhonXY,15659
+ ai_edge_torch/generative/examples/gemma3/decoder.py,sha256=eXWE5CSX0KeUMsPevgsYOfvyajl9F1RFF4DCWhHcYPA,15646
  ai_edge_torch/generative/examples/gemma3/gemma3.py,sha256=GACDBI_MsFowR8A3wAWrpzradPYe-AUgB9ZjXaVBG-s,6485
  ai_edge_torch/generative/examples/gemma3/image_encoder.py,sha256=uRoLoBWzFtQz5wFZfPCxbkvZsgPAqSkUUsV3977GbYc,5184
  ai_edge_torch/generative/examples/gemma3/verify_gemma3.py,sha256=v8oNXFICmVOtQxfO7IhZ8GnbvotEkDi9lzYHjoQyOso,2464
@@ -150,7 +150,7 @@ ai_edge_torch/generative/examples/tiny_llama/verify.py,sha256=LRu6PSw7Lqu6HGbv1t
  ai_edge_torch/generative/fx_passes/__init__.py,sha256=PFSMsA1vfBfrV9ssBCkYJNl8Hx_bLdWjN01iyjPM5jE,1094
  ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py,sha256=myGjal5A8yIBoqgArd2k40rZmCgD1Ya369KR7182bhI,2129
  ai_edge_torch/generative/layers/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
- ai_edge_torch/generative/layers/attention.py,sha256=wLZ1jgUlcODBWgK3hnnhclHuuQDqYuGOZdYAI9EooOM,13247
+ ai_edge_torch/generative/layers/attention.py,sha256=uK1ih2kxPZherwi-pGSm8B--NNWnQ8npEAfgcjMIkEY,12964
  ai_edge_torch/generative/layers/attention_utils.py,sha256=zBVwlBUTs-nStIKCZG0ks5ra7tsqc9ShfakFJKH5rds,7344
  ai_edge_torch/generative/layers/builder.py,sha256=LXGuSHIx6QZAzLFm7aJvlzoMPgQwbXLFchGEKYwOOUA,5090
  ai_edge_torch/generative/layers/feed_forward.py,sha256=hdICat-8gW7-vxDAevJQ8NQ-mynllPiqLdXQMF6JMnc,4189
@@ -160,9 +160,8 @@ ai_edge_torch/generative/layers/model_config.py,sha256=nLXvTkDAIHJQ0PTaWODF8oxJQ
  ai_edge_torch/generative/layers/normalization.py,sha256=MbwH-n80Fob5YvjBzdqDjBizMHLzSJGYRDdbD-rL5C0,6174
  ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=975zR202MdIrILJ7blceAcxrNqX1ZCN0ECKG1gz-bV8,2655
  ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=vp8dVx6tOe99neJhpbrtIt5fvN5NFw19JVH1v0yi5Mg,4154
- ai_edge_torch/generative/layers/sdpa_with_kv_update.py,sha256=oo9h7pi0GcuylRgp2yUuvUJCrhj03aoWt_fP7EDP4LM,3775
+ ai_edge_torch/generative/layers/sdpa_with_kv_update.py,sha256=D4rATT2Ppa9Su7yuRHYnQPJ1dFvUDAyH1GrFnCed7p8,3810
  ai_edge_torch/generative/layers/experimental/__init__.py,sha256=nz-K0h8DfiATHzR6s1_bCw2akUmHWffU1bDRSkIzSqI,592
- ai_edge_torch/generative/layers/experimental/attention.py,sha256=XYbo1KlmiMEuwArye0Ul86jEsdxLr1RG-usRpidZiT8,8001
  ai_edge_torch/generative/layers/experimental/kv_cache.py,sha256=zgpFVftOfllvjh9-UEBSvUbm152SnQETn29rUMMMvAM,2978
  ai_edge_torch/generative/layers/experimental/scaled_dot_product_attention.py,sha256=YFcIGOkaNb-vvQKjI-G9-bC2Z1W0O_qRyIZPlsLl72U,2797
  ai_edge_torch/generative/layers/unet/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
@@ -181,7 +180,7 @@ ai_edge_torch/generative/test/test_custom_dus.py,sha256=MjIhTvkTko872M35XMciobvI
  ai_edge_torch/generative/test/test_kv_cache.py,sha256=1sXN2RPntq0PP3IEy0NkvIbzQ0Y8JhPIwRSFwO9JLlE,5728
  ai_edge_torch/generative/test/test_loader.py,sha256=9mQUeeZKOVApOWSWl2cN9c10axZjMKM1-0Zd823CCS4,3449
  ai_edge_torch/generative/test/test_lora.py,sha256=6QIM6RLTc2HrodGpp_aS3OxM9Rco2KAzEnYgotkg41M,5310
- ai_edge_torch/generative/test/test_model_conversion.py,sha256=jfqkECCX7XKHeBAuDXrkwQJf0vM72eG3LMc5rluha84,6191
+ ai_edge_torch/generative/test/test_model_conversion.py,sha256=jSNJ0Eex6VYCkGn3FXbCOOJ2S3-F_QuwJctu3VycjR4,7200
  ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=-v2Vj7Qdd3GyBn4k7BWVgyGzrbcL30Su3nxZYLtwkCs,14787
  ai_edge_torch/generative/test/test_quantize.py,sha256=bEJMhpQ9bIDUZVBXTW888728FcH-i3SyE4JSZZUgU0A,6071
  ai_edge_torch/generative/test/utils.py,sha256=tF6aCfAGJnc9dmzCnZCEOuKNVimfWOqscv9og0DDLHU,2656
@@ -245,8 +244,8 @@ ai_edge_torch/testing/__init__.py,sha256=_yGgvnBZWb7T3IN3mc4x1sS4vM96HZwM8pwIcPG
  ai_edge_torch/testing/export.py,sha256=k5mGDGzwc23Z4zaIVDs8CNh-oOt64gsf9MS9NjhbPy4,3293
  ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
  ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
- ai_edge_torch_nightly-0.5.0.dev20250423.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
- ai_edge_torch_nightly-0.5.0.dev20250423.dist-info/METADATA,sha256=PGzcX4WVfFW0wE0TSKLAuRB94iemrNff4L8CL_VUMnQ,2051
- ai_edge_torch_nightly-0.5.0.dev20250423.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
- ai_edge_torch_nightly-0.5.0.dev20250423.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
- ai_edge_torch_nightly-0.5.0.dev20250423.dist-info/RECORD,,
+ ai_edge_torch_nightly-0.5.0.dev20250424.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
+ ai_edge_torch_nightly-0.5.0.dev20250424.dist-info/METADATA,sha256=Gz8c2qvL6qiK7lrd001P55TXltKdycDvDaAq4d4Y-eQ,2051
+ ai_edge_torch_nightly-0.5.0.dev20250424.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
+ ai_edge_torch_nightly-0.5.0.dev20250424.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
+ ai_edge_torch_nightly-0.5.0.dev20250424.dist-info/RECORD,,
@@ -1,231 +0,0 @@
- # Copyright 2024 The AI Edge Torch Authors.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ==============================================================================
-
- """Common building blocks for a GPU-specific Attention layer.
-
- This is a temporary implemenation for the GPU. It is subject to change/removal
- at any time.
- """
-
- from typing import Optional, Tuple, Union
-
- from ai_edge_torch.generative.layers import builder
- from ai_edge_torch.generative.layers import kv_cache as kv_utils
- from ai_edge_torch.generative.layers import lora as lora_utils
- from ai_edge_torch.generative.layers import sdpa_with_kv_update
- import ai_edge_torch.generative.layers.model_config as cfg
- import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb
- import torch
- from torch import nn
-
-
- class TransformerBlock(nn.Module):
-
- def __init__(
- self,
- config: cfg.TransformerBlockConfig,
- model_config: cfg.ModelConfig,
- ) -> None:
- """Initialize an instance of the TransformerBlock.
-
- Args:
- config (cfg.TransformerBlockConfig): the configuration object for this
- transformer block.
- model_config (cfg.ModelConfig): the configuration object for the model
- this transformer block belongs to.
- """
- super().__init__()
- self.pre_atten_norm = builder.build_norm(
- model_config.embedding_dim,
- config.pre_attention_norm_config,
- )
- self.atten_func = CausalSelfAttention(
- model_config.embedding_dim,
- config.attn_config,
- model_config.enable_hlfb,
- )
- self.post_atten_norm = builder.build_norm(
- model_config.embedding_dim,
- config.post_attention_norm_config,
- )
- self.ff = builder.build_ff(model_config.embedding_dim, config.ff_config)
- self.config = config
-
- def forward(
- self,
- x: torch.Tensor,
- rope: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
- mask: Optional[torch.Tensor] = None,
- input_pos: Optional[torch.Tensor] = None,
- kv_cache: kv_utils.KVCacheEntry = None,
- lora: Optional[lora_utils.LoRAEntry] = None,
- ) -> Union[torch.Tensor, Tuple[torch.Tensor, kv_utils.KVCacheEntry]]:
- """Forward function of the TransformerBlock.
-
- Args:
- x (torch.Tensor): the input tensor.
- rope (Tuple[torch.Tensor, torch.Tensor]): the input rope tensor.
- mask (torch.Tensor): the optional mask tensor.
- input_pos (torch.Tensor): the optional input position tensor.
- kv_cache (KVCacheEntry): the optional kv cache entry.
- lora (LoRAEntry): the optional lora entry.
-
- Returns:
- output activation from this transformer block, and updated kv cache (if
- passed in).
- """
- kv = None
- if self.config.parallel_residual:
- x_norm = self.pre_atten_norm(x)
- atten_func_out = self.atten_func(
- x_norm, rope, mask, input_pos, kv_cache, lora
- )
- if kv_cache is None:
- attn_out = atten_func_out
- else:
- attn_out, kv = atten_func_out
- ff_out = self.ff(x_norm)
- output = x + attn_out + ff_out
- else:
- x_norm = self.pre_atten_norm(x)
- atten_func_out = self.atten_func(
- x_norm, rope, mask, input_pos, kv_cache, lora
- )
- if kv_cache is None:
- attn_out = atten_func_out
- else:
- attn_out, kv = atten_func_out
- x = x + attn_out
- x_norm = self.post_atten_norm(x)
- output = x + self.ff(x_norm)
-
- return output if kv is None else (output, kv)
-
-
- class CausalSelfAttention(nn.Module):
-
- def __init__(
- self,
- dim: int,
- config: cfg.AttentionConfig,
- enable_hlfb: bool,
- ) -> None:
- """Initialize an instance of CausalSelfAttention.
-
- Args:
- dim (int): causal attention's input/output dimmension.
- config (cfg.AttentionConfig): attention specific configurations.
- enable_hlfb (bool): whether hlfb is enabled or not.
- """
- super().__init__()
- self.kv_cache = None
- qkv_shape = (
- config.num_heads + 2 * config.num_query_groups
- ) * config.head_dim
- output_shape = config.num_heads * config.head_dim
- # Key, query, value projections for all heads.
- self.qkv_projection = nn.Linear(dim, qkv_shape, bias=config.qkv_use_bias)
- self.output_projection = nn.Linear(
- output_shape, dim, bias=config.output_proj_use_bias
- )
- self.query_norm = builder.build_norm(
- config.head_dim, config.query_norm_config
- )
- self.key_norm = builder.build_norm(config.head_dim, config.key_norm_config)
- self.config = config
- self.enable_hlfb = enable_hlfb
-
- def forward(
- self,
- x: torch.Tensor,
- rope: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
- mask: Optional[torch.Tensor] = None,
- input_pos: Optional[torch.Tensor] = None,
- kv_cache: Optional[kv_utils.KVCacheEntry] = None,
- lora: Optional[lora_utils.LoRAEntry] = None,
- ) -> Union[torch.Tensor, Tuple[torch.Tensor, kv_utils.KVCacheEntry]]:
- """Forward function of the CausalSelfAttention layer, which can support
-
- MQA, GQA and MHA.
-
- Args:
- x (torch.Tensor): the input tensor.
- rope (Tuple[torch.Tensor, torch.Tensor]): the input rope tensor.
- mask (torch.Tensor): the optional mask tensor.
- input_pos (torch.Tensor): the optional input position tensor.
- kv_cache (KVCacheEntry): the KV cache entry corresponding to this module.
- lora (LoRAEntry): the optional lora entry.
-
- Returns:
- output activation from this self attention layer, and the updated
- KV Cach Entry (if passed in).
- """
- # Batch size, sequence length, embedding dimensionality.
- B, T, E = x.size()
-
- qkv = self.qkv_projection(x)
-
- # Assemble into a number of query groups to support MHA, MQA and GQA.
- q_per_kv = self.config.num_heads // self.config.num_query_groups
- # Each group has >=1 queries, 1 key, and 1 value.
- if self.config.qkv_transpose_before_split:
- qkv = qkv.view(B, T, -1, self.config.head_dim)
- q, k, v = qkv.split(
- (
- q_per_kv * self.config.num_query_groups,
- self.config.num_query_groups,
- self.config.num_query_groups,
- ),
- dim=-2,
- )
- else:
- qkv = qkv.view(B, T, self.config.num_query_groups, -1)
- q, k, v = qkv.split(
- (
- q_per_kv * self.config.head_dim,
- self.config.head_dim,
- self.config.head_dim,
- ),
- dim=-1,
- )
-
- if lora is not None:
- q += lora_utils.apply_lora(x, lora.attention.query, shape=q.shape)
- k += lora_utils.apply_lora(x, lora.attention.key, shape=k.shape)
- v += lora_utils.apply_lora(x, lora.attention.value, shape=v.shape)
-
- q = self.query_norm(q)
- k = self.key_norm(k)
-
- q = q.reshape(B, T, -1, self.config.head_dim)
- k = k.reshape(B, T, -1, self.config.head_dim)
- v = v.reshape(B, T, -1, self.config.head_dim)
-
- if rope is not None:
- # Compute rotary positional embedding for query and key.
- n_elem = int(self.config.rotary_percentage * self.config.head_dim)
- cos, sin = rope
- q, k = rotary_pos_emb.apply_rope_inline(q, k, cos, sin)
-
- sdpa_out, kv_cache = sdpa_with_kv_update.sdpa_with_kv_update(
- q, k, v, kv_cache, input_pos, mask, self.config
- )
-
- # Compute the output projection.
- y = self.output_projection(sdpa_out)
- if lora is not None:
- y += lora_utils.apply_lora(sdpa_out, lora.attention.output)
-
- return y if kv_cache is None else (y, kv_cache)