ai-edge-torch-nightly 0.5.0.dev20250423__py3-none-any.whl → 0.5.0.dev20250425__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (25)
  1. ai_edge_torch/_convert/conversion.py +1 -3
  2. ai_edge_torch/_convert/fx_passes/__init__.py +0 -1
  3. ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py +63 -2
  4. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +2 -1
  5. ai_edge_torch/generative/examples/deepseek/convert_to_tflite.py +38 -4
  6. ai_edge_torch/generative/examples/deepseek/deepseek.py +1 -0
  7. ai_edge_torch/generative/examples/gemma3/decoder.py +1 -1
  8. ai_edge_torch/generative/examples/qwen/convert_to_tflite.py +37 -2
  9. ai_edge_torch/generative/examples/qwen/qwen.py +1 -0
  10. ai_edge_torch/generative/layers/attention.py +4 -18
  11. ai_edge_torch/generative/layers/scaled_dot_product_attention.py +51 -0
  12. ai_edge_torch/generative/layers/sdpa_with_kv_update.py +38 -44
  13. ai_edge_torch/generative/test/test_model_conversion.py +38 -33
  14. ai_edge_torch/generative/test/test_model_conversion_large.py +3 -75
  15. ai_edge_torch/generative/utilities/converter.py +5 -0
  16. ai_edge_torch/odml_torch/lowerings/_decomp_registry.py +2 -0
  17. ai_edge_torch/version.py +1 -1
  18. {ai_edge_torch_nightly-0.5.0.dev20250423.dist-info → ai_edge_torch_nightly-0.5.0.dev20250425.dist-info}/METADATA +1 -1
  19. {ai_edge_torch_nightly-0.5.0.dev20250423.dist-info → ai_edge_torch_nightly-0.5.0.dev20250425.dist-info}/RECORD +22 -25
  20. ai_edge_torch/_convert/fx_passes/build_interpolate_composite_pass.py +0 -129
  21. ai_edge_torch/generative/layers/experimental/attention.py +0 -231
  22. ai_edge_torch/generative/layers/experimental/scaled_dot_product_attention.py +0 -93
  23. {ai_edge_torch_nightly-0.5.0.dev20250423.dist-info → ai_edge_torch_nightly-0.5.0.dev20250425.dist-info}/LICENSE +0 -0
  24. {ai_edge_torch_nightly-0.5.0.dev20250423.dist-info → ai_edge_torch_nightly-0.5.0.dev20250425.dist-info}/WHEEL +0 -0
  25. {ai_edge_torch_nightly-0.5.0.dev20250423.dist-info → ai_edge_torch_nightly-0.5.0.dev20250425.dist-info}/top_level.txt +0 -0
--- a/ai_edge_torch/_convert/conversion.py
+++ b/ai_edge_torch/_convert/conversion.py
@@ -35,13 +35,11 @@ def _run_convert_passes(
   )

   passes = [
-      fx_passes.CastInputsBf16ToF32Pass(),
-      fx_passes.BuildInterpolateCompositePass(),
-      fx_passes.CanonicalizePass(),
       fx_passes.OptimizeLayoutTransposesPass(),
       fx_passes.CanonicalizePass(),
       fx_passes.BuildAtenCompositePass(),
       fx_passes.RemoveNonUserOutputsPass(),
+      fx_passes.CastInputsBf16ToF32Pass(),
   ]

   # Debuginfo is not injected automatically by odml_torch. Only inject
--- a/ai_edge_torch/_convert/fx_passes/__init__.py
+++ b/ai_edge_torch/_convert/fx_passes/__init__.py
@@ -16,7 +16,6 @@
 from typing import Sequence, Union

 from ai_edge_torch._convert.fx_passes.build_aten_composite_pass import BuildAtenCompositePass
-from ai_edge_torch._convert.fx_passes.build_interpolate_composite_pass import BuildInterpolateCompositePass
 from ai_edge_torch._convert.fx_passes.cast_inputs_bf16_to_f32_pass import CastInputsBf16ToF32Pass
 from ai_edge_torch._convert.fx_passes.inject_mlir_debuginfo_pass import InjectMlirDebuginfoPass
 from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import OptimizeLayoutTransposesPass
--- a/ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py
+++ b/ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py
@@ -20,7 +20,8 @@ import torch
 import torch.utils._pytree as pytree

 _composite_builders: dict[
-    Callable, Callable[[torch.fx.GraphModule, torch.fx.Node], None]
+    Callable[[Any, ...], Any],
+    Callable[[torch.fx.GraphModule, torch.fx.Node], None],
 ] = {}


@@ -272,13 +273,73 @@ def _aten_embedding(gm: torch.fx.GraphModule, node: torch.fx.Node):
     output = op(**full_kwargs)
     output = builder.mark_outputs(output)

-    # Explicitly reshape back to the original shape. This places the ReshapeOp outside of the HLFB.
+    # Explicitly reshape back to the original shape. This places the ReshapeOp
+    # outside of the HLFB.
     output = torch.reshape(output, (*(original_idx_shape), embedding_dim))
     return output

   node.target = embedding


+@_register_composite_builder(torch.ops.aten.upsample_bilinear2d.vec)
+def _aten_upsample_bilinear2d_vec(_, node: torch.fx.Node):
+  """Build a composite for aten.upsample_bilinear2d.vec."""
+  op = node.target
+  args_mapper = TorchOpArgumentsMapper(op)
+  # Assumes later FX passes does not change the args/kwargs of the op.
+  # Which is a valid assumption for, given that composite/mark_tensor wrapper
+  # should semantically prevents any future mutations on the op.
+  output_h, output_w = node.meta["val"].shape[-2:]
+
+  def upsample_bilinear2d_vec(*args, **kwargs):
+    nonlocal op, args_mapper
+    full_kwargs = args_mapper.get_full_kwargs(args, kwargs)
+
+    builder = lowertools.StableHLOCompositeBuilder(
+        name="odml.upsample_bilinear2d",
+        attr={
+            "size": (int(output_h), int(output_w)),
+            "align_corners": full_kwargs["align_corners"],
+            "is_nchw_op": True,
+        },
+    )
+    full_kwargs["input"] = builder.mark_inputs(full_kwargs["input"])
+    output = op(**full_kwargs)
+    output = builder.mark_outputs(output)
+    return output
+
+  node.target = upsample_bilinear2d_vec
+
+
+@_register_composite_builder(torch.ops.aten.upsample_nearest2d.vec)
+def _aten_upsample_nearest2d_vec(_, node: torch.fx.Node):
+  """Build a composite for aten.upsample_nearest2d.vec."""
+  op = node.target
+  args_mapper = TorchOpArgumentsMapper(op)
+  # Assumes later FX passes does not change the args/kwargs of the op.
+  # Which is a valid assumption for, given that composite/mark_tensor wrapper
+  # should semantically prevents any future mutations on the op.
+  output_h, output_w = node.meta["val"].shape[-2:]
+
+  def upsample_nearest2d_vec(*args, **kwargs):
+    nonlocal op, args_mapper
+    full_kwargs = args_mapper.get_full_kwargs(args, kwargs)
+
+    builder = lowertools.StableHLOCompositeBuilder(
+        name="tfl.resize_nearest_neighbor",
+        attr={
+            "size": (int(output_h), int(output_w)),
+            "is_nchw_op": True,
+        },
+    )
+    full_kwargs["input"] = builder.mark_inputs(full_kwargs["input"])
+    output = op(**full_kwargs)
+    output = builder.mark_outputs(output)
+    return output
+
+  node.target = upsample_nearest2d_vec
+
+
 class BuildAtenCompositePass(fx_infra.PassBase):

   def call(self, graph_module: torch.fx.GraphModule):
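Note: the two new builders fold the interpolate composites that previously lived in the separate build_interpolate_composite_pass.py (removed in this release, see file 20) into BuildAtenCompositePass. As a rough sketch of how they get exercised, and not code taken from the package, a model whose bilinear resize traces to aten.upsample_bilinear2d.vec should come out of conversion with an "odml.upsample_bilinear2d" composite marking that region:

    import torch
    import torch.nn.functional as F
    import ai_edge_torch


    class Upsample2x(torch.nn.Module):
      """Toy module; F.interpolate traces to aten.upsample_bilinear2d.vec."""

      def forward(self, x):
        return F.interpolate(
            x, scale_factor=2, mode="bilinear", align_corners=False
        )


    sample_args = (torch.randn(1, 3, 16, 16),)
    edge_model = ai_edge_torch.convert(Upsample2x().eval(), sample_args)
    edge_model.export("/tmp/upsample2x.tflite")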
--- a/ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py
+++ b/ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py
@@ -17,6 +17,7 @@
 import operator

 import ai_edge_torch
+from ai_edge_torch import lowertools
 from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import layout_mark
 from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import op_func_registry
 from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import utils
@@ -24,7 +25,7 @@ import torch
 import torch.utils._pytree as pytree

 aten = torch.ops.aten
-StableHLOCompositeBuilder = ai_edge_torch.hlfb.StableHLOCompositeBuilder
+StableHLOCompositeBuilder = lowertools.StableHLOCompositeBuilder

 __all__ = ["rewrite_nhwc_node", "has_nhwc_rewriter"]
--- a/ai_edge_torch/generative/examples/deepseek/convert_to_tflite.py
+++ b/ai_edge_torch/generative/examples/deepseek/convert_to_tflite.py
@@ -17,11 +17,43 @@

 from absl import app
 from ai_edge_torch.generative.examples.deepseek import deepseek
+from ai_edge_torch.generative.layers import kv_cache
 from ai_edge_torch.generative.utilities import converter
-from ai_edge_torch.generative.utilities import export_config
+from ai_edge_torch.generative.utilities.model_builder import export_cfg
+import torch
+
+flags = converter.define_conversion_flags('deepseek')
+ExportConfig = export_cfg.ExportConfig
+
+
+def _create_mask(mask_len, kv_cache_max_len):
+  mask = torch.full(
+      (mask_len, kv_cache_max_len), float('-inf'), dtype=torch.float32
+  )
+  mask = torch.triu(mask, diagonal=1).unsqueeze(0).unsqueeze(0)
+  return mask
+
+
+def _create_export_config(
+    prefill_seq_lens: list[int], kv_cache_max_len: int
+) -> ExportConfig:
+  """Creates the export config for the model."""
+  export_config = ExportConfig()
+  if isinstance(prefill_seq_lens, list):
+    prefill_mask = [_create_mask(i, kv_cache_max_len) for i in prefill_seq_lens]
+  else:
+    prefill_mask = _create_mask(prefill_seq_lens, kv_cache_max_len)
+
+  export_config.prefill_mask = prefill_mask
+
+  decode_mask = torch.full(
+      (1, kv_cache_max_len), float('-inf'), dtype=torch.float32
+  )
+  decode_mask = torch.triu(decode_mask, diagonal=1).unsqueeze(0).unsqueeze(0)
+  export_config.decode_mask = decode_mask
+  export_config.kvcache_layout = kv_cache.KV_LAYOUT_TRANSPOSED
+  return export_config

-flags = converter.define_conversion_flags("deepseek")
-ExportConfig = export_config.ExportConfig

 def main(_):
   pytorch_model = deepseek.build_model(
@@ -34,7 +66,9 @@ def main(_):
       prefill_seq_len=flags.FLAGS.prefill_seq_lens,
       quantize=flags.FLAGS.quantize,
       lora_ranks=flags.FLAGS.lora_ranks,
-      export_config=ExportConfig(),
+      export_config=_create_export_config(
+          flags.FLAGS.prefill_seq_lens, flags.FLAGS.kv_cache_max_len
+      ),
   )
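Note: both the DeepSeek and Qwen converters now build additive causal masks ahead of time and attach them to the export config. A small worked example (values chosen purely for illustration) of what _create_mask produces: zeros on and below the diagonal, -inf above it, with two leading singleton dimensions so the mask broadcasts over the batch and head axes:

    import torch

    def _create_mask(mask_len, kv_cache_max_len):
      mask = torch.full(
          (mask_len, kv_cache_max_len), float('-inf'), dtype=torch.float32
      )
      return torch.triu(mask, diagonal=1).unsqueeze(0).unsqueeze(0)

    mask = _create_mask(3, 5)
    print(mask.shape)  # torch.Size([1, 1, 3, 5])
    print(mask[0, 0])
    # tensor([[0., -inf, -inf, -inf, -inf],
    #         [0., 0., -inf, -inf, -inf],
    #         [0., 0., 0., -inf, -inf]])

For DeepSeek this mask setup and the transposed KV-cache layout are now applied unconditionally, while the Qwen converter below opts in through a flag.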
--- a/ai_edge_torch/generative/examples/deepseek/deepseek.py
+++ b/ai_edge_torch/generative/examples/deepseek/deepseek.py
@@ -53,6 +53,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
   norm_config = cfg.NormalizationConfig(
       type=cfg.NormalizationType.RMS_NORM,
       epsilon=1e-06,
+      enable_hlfb=True,
   )
   block_config = cfg.TransformerBlockConfig(
       attn_config=attn_config,
--- a/ai_edge_torch/generative/examples/gemma3/decoder.py
+++ b/ai_edge_torch/generative/examples/gemma3/decoder.py
@@ -17,10 +17,10 @@

 from typing import List, Optional, Tuple

+from ai_edge_torch.generative.layers import attention
 from ai_edge_torch.generative.layers import builder
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
-from ai_edge_torch.generative.layers.experimental import attention
 import ai_edge_torch.generative.layers.model_config as cfg
 import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb
 from ai_edge_torch.generative.utilities import export_config as export_cfg
--- a/ai_edge_torch/generative/examples/qwen/convert_to_tflite.py
+++ b/ai_edge_torch/generative/examples/qwen/convert_to_tflite.py
@@ -17,13 +17,14 @@

 from absl import app
 from ai_edge_torch.generative.examples.qwen import qwen
+from ai_edge_torch.generative.layers import kv_cache
 from ai_edge_torch.generative.utilities import converter
 from ai_edge_torch.generative.utilities import export_config
+import torch

 flags = converter.define_conversion_flags('qwen')
 ExportConfig = export_config.ExportConfig

-
 _MODEL_SIZE = flags.DEFINE_enum(
     'model_size',
     '3b',
@@ -37,6 +38,36 @@ _BUILDER = {
     '3b': qwen.build_3b_model,
 }

+
+def _create_mask(mask_len, kv_cache_max_len):
+  mask = torch.full(
+      (mask_len, kv_cache_max_len), float('-inf'), dtype=torch.float32
+  )
+  mask = torch.triu(mask, diagonal=1).unsqueeze(0).unsqueeze(0)
+  return mask
+
+
+def _create_export_config(
+    prefill_seq_lens: list[int], kv_cache_max_len: int
+) -> ExportConfig:
+  """Creates the export config for the model."""
+  export_config = ExportConfig()
+  if isinstance(prefill_seq_lens, list):
+    prefill_mask = [_create_mask(i, kv_cache_max_len) for i in prefill_seq_lens]
+  else:
+    prefill_mask = _create_mask(prefill_seq_lens, kv_cache_max_len)
+
+  export_config.prefill_mask = prefill_mask
+
+  decode_mask = torch.full(
+      (1, kv_cache_max_len), float('-inf'), dtype=torch.float32
+  )
+  decode_mask = torch.triu(decode_mask, diagonal=1).unsqueeze(0).unsqueeze(0)
+  export_config.decode_mask = decode_mask
+  export_config.kvcache_layout = kv_cache.KV_LAYOUT_TRANSPOSED
+  return export_config
+
+
 def main(_):
   pytorch_model = _BUILDER[_MODEL_SIZE.value](
       flags.FLAGS.checkpoint_path, kv_cache_max_len=flags.FLAGS.kv_cache_max_len
@@ -48,7 +79,11 @@ def main(_):
       prefill_seq_len=flags.FLAGS.prefill_seq_lens,
       quantize=flags.FLAGS.quantize,
       lora_ranks=flags.FLAGS.lora_ranks,
-      export_config=ExportConfig(),
+      export_config=_create_export_config(
+          flags.FLAGS.prefill_seq_lens, flags.FLAGS.kv_cache_max_len
+      )
+      if flags.FLAGS.transpose_kv_cache
+      else ExportConfig(),
   )
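Note: for Qwen the transposed layout is opt-in: _create_export_config is only used when flags.FLAGS.transpose_kv_cache is set, otherwise the plain ExportConfig() is kept. The sketch below is illustrative (it assumes it runs inside this updated module so the helper and the kv_cache import are in scope) and shows what the helper returns for a multi-length prefill:

    export_config = _create_export_config([128, 512], 1024)
    assert isinstance(export_config.prefill_mask, list)
    assert len(export_config.prefill_mask) == 2
    assert export_config.prefill_mask[0].shape == (1, 1, 128, 1024)
    assert export_config.decode_mask.shape == (1, 1, 1, 1024)
    assert export_config.kvcache_layout == kv_cache.KV_LAYOUT_TRANSPOSED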
--- a/ai_edge_torch/generative/examples/qwen/qwen.py
+++ b/ai_edge_torch/generative/examples/qwen/qwen.py
@@ -53,6 +53,7 @@ def get_3b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
   norm_config = cfg.NormalizationConfig(
       type=cfg.NormalizationType.RMS_NORM,
       epsilon=1e-06,
+      enable_hlfb=True,
   )
   block_config = cfg.TransformerBlockConfig(
       attn_config=attn_config,
--- a/ai_edge_torch/generative/layers/attention.py
+++ b/ai_edge_torch/generative/layers/attention.py
@@ -21,6 +21,7 @@ from ai_edge_torch.generative.layers import builder
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
 from ai_edge_torch.generative.layers import lora as lora_utils
 from ai_edge_torch.generative.layers import scaled_dot_product_attention as sdpa
+from ai_edge_torch.generative.layers import sdpa_with_kv_update
 import ai_edge_torch.generative.layers.model_config as cfg
 import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb
 import torch
@@ -142,11 +143,6 @@ class CausalSelfAttention(nn.Module):
     self.key_norm = builder.build_norm(config.head_dim, config.key_norm_config)
     self.config = config
     self.enable_hlfb = enable_hlfb
-    self.sdpa_func = (
-        sdpa.scaled_dot_product_attention_with_hlfb
-        if enable_hlfb
-        else sdpa.scaled_dot_product_attention
-    )

   def forward(
       self,
@@ -174,7 +170,7 @@ class CausalSelfAttention(nn.Module):
       KV Cach Entry (if passed in).
     """
     # Batch size, sequence length, embedding dimensionality.
-    B, T, E = x.size()
+    B, T, _ = x.size()
     qkv = self.qkv_projection(x)

     # Assemble into a number of query groups to support MHA, MQA and GQA.
@@ -218,19 +214,9 @@ class CausalSelfAttention(nn.Module):
       cos, sin = rope
       q, k = rotary_pos_emb.apply_rope_inline(q, k, cos, sin)

-    if kv_cache is not None:
-      kv_cache = kv_utils.update(kv_cache, input_pos, k, v)
-      k, v = kv_cache.k_cache, kv_cache.v_cache
-
-    sdpa_out = self.sdpa_func(
-        q,
-        k,
-        v,
-        self.config.head_dim,
-        mask=mask,
-        softcap=self.config.logit_softcap,
+    sdpa_out, kv_cache = sdpa_with_kv_update.sdpa_with_kv_update(
+        q, k, v, kv_cache, input_pos, mask, self.config, self.enable_hlfb
     )
-    sdpa_out = sdpa_out.reshape(B, T, -1)

     # Compute the output projection.
     y = self.output_projection(sdpa_out)
--- a/ai_edge_torch/generative/layers/scaled_dot_product_attention.py
+++ b/ai_edge_torch/generative/layers/scaled_dot_product_attention.py
@@ -17,6 +17,8 @@
 import math
 from typing import Optional

+from ai_edge_torch.generative.custom_ops import bmm_4d as bmm_lib
+from ai_edge_torch.generative.layers import kv_cache as kv_utils
 from ai_edge_torch.hlfb import StableHLOCompositeBuilder
 import torch
 import torch.nn.functional as F
@@ -142,3 +144,52 @@ def scaled_dot_product_attention_with_hlfb(
   result = y.transpose(1, 2)
   result = builder.mark_outputs(result)
   return result
+
+
+def scaled_dot_product_attention_transposed(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    head_size: int,
+    mask: Optional[torch.Tensor] = None,
+    scale: Optional[float] = None,
+    softcap: Optional[float] = None,
+):
+  """Scaled dot product attention with transposed key and value.
+
+  Args:
+    query: Query tensor, with shape [B, T, N, H].
+    key: Key tensor, with shape [B, T, KV_LEN, H].
+    value: Value tensor, with shape [B, T, KV_LEN, H].
+    head_size (int): head dimension.
+    mask (torch.Tensor): the optional mask tensor.
+    scale (float): the optional scale factor.
+    softcap (float): the optional softcap for the logits.
+
+  Returns:
+    The output tensor of scaled_dot_product_attention_transposed.
+  """
+
+  if scale is None:
+    scale = 1.0 / math.sqrt(head_size)
+
+  query = query * scale
+
+  assert mask is not None, "Mask should not be None!"
+  t = mask.shape[2]
+
+  logits = bmm_lib.bmm_4d(query, key)
+
+  _, bk, gt, s = logits.shape
+  g = gt // t
+  logits = logits.reshape((bk, g, t, s))
+  if softcap is not None:
+    logits = torch.tanh(logits / softcap)
+    logits = logits * softcap
+
+  padded_logits = logits + mask
+  padded_logits = padded_logits.reshape(1, bk, gt, s)
+  probs = F.softmax(padded_logits, dim=-1).type_as(key)
+  encoded = bmm_lib.bmm_4d(probs, value)
+
+  return encoded  # 1, bk, gt, h
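Note: scaled_dot_product_attention_transposed consumes KV caches stored in the transposed layout, where the cached key already has its head dimension in the matmul contraction position. bmm_4d is a custom op from ai_edge_torch.generative.custom_ops; assuming it behaves like a plain 4-D batched matrix multiply, the shape flow of the new function can be traced with stock torch ops (the shapes below are illustrative and chosen only so the algebra works out, not taken from the package):

    import math
    import torch
    import torch.nn.functional as F

    # bk: batch * KV heads, g: query heads per KV head, t: query length,
    # s: KV cache length, h: head dimension.
    bk, g, t, s, h = 4, 2, 8, 32, 16

    query = torch.randn(1, bk, g * t, h) * (1.0 / math.sqrt(h))  # pre-scaled query
    key = torch.randn(1, bk, h, s)    # cached K with head dim before sequence
    value = torch.randn(1, bk, s, h)  # cached V with sequence before head dim
    mask = torch.zeros(1, 1, t, s)    # all-zero mask, only to exercise shapes

    logits = torch.matmul(query, key)            # [1, bk, g*t, s], stand-in for bmm_4d
    logits = logits.reshape(bk, g, t, s) + mask  # mask broadcasts over bk and g
    probs = F.softmax(logits.reshape(1, bk, g * t, s), dim=-1)
    encoded = torch.matmul(probs, value)         # [1, bk, g*t, h]
    assert encoded.shape == (1, bk, g * t, h)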
--- a/ai_edge_torch/generative/layers/sdpa_with_kv_update.py
+++ b/ai_edge_torch/generative/layers/sdpa_with_kv_update.py
@@ -12,16 +12,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-# Common utility functions for data loading etc.
-from dataclasses import dataclass
+
+"""Common utility functions for data loading etc."""
+
 from typing import Tuple
+
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
-from ai_edge_torch.generative.layers import scaled_dot_product_attention as sdpa_default
+from ai_edge_torch.generative.layers import scaled_dot_product_attention as sdpa
 from ai_edge_torch.generative.layers.experimental import kv_cache as kv_utils_experimental
-from ai_edge_torch.generative.layers.experimental import scaled_dot_product_attention as sdpa
 import ai_edge_torch.generative.layers.model_config as cfg
-from ai_edge_torch.generative.utilities import types
-from multipledispatch import dispatch
 import torch


@@ -33,32 +32,27 @@ def sdpa_with_kv_update(
     input_pos: torch.Tensor,
     mask: torch.Tensor,
     config: cfg.AttentionConfig,
+    enable_hlfb: bool,
 ) -> Tuple[torch.Tensor, kv_utils.KVCacheEntry]:
-  return sdpa_with_kv_update_impl(
-      kv.kv_layout[0](),  # key layout
-      kv.kv_layout[1](),  # value layout
-      query=query,
-      key=key,
-      value=value,
-      kv=kv,
-      input_pos=input_pos,
-      mask=mask,
-      config=config,
+  """Wrapper function for scaled dot product attention with KV cache update."""
+  if kv is not None and kv.kv_layout == kv_utils.KV_LAYOUT_TRANSPOSED:
+    return _sdpa_with_kv_update_transposed(
+        query, key, value, kv, input_pos, mask, config
+    )
+  return _sdpa_with_kv_update_default(
+      query, key, value, kv, input_pos, mask, config, enable_hlfb
   )


-@dispatch(types.BNTH, types.BNHT)
-def sdpa_with_kv_update_impl(
-    k_type, v_type, *args, **kwargs
+def _sdpa_with_kv_update_transposed(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    kv: kv_utils.KVCacheEntry,
+    input_pos: torch.Tensor,
+    mask: torch.Tensor,
+    config: cfg.AttentionConfig,
 ) -> Tuple[torch.Tensor, kv_utils.KVCacheEntry]:
-  query = kwargs["query"]
-  key = kwargs["key"]
-  value = kwargs["value"]
-  kv = kwargs["kv"]
-  input_pos = kwargs["input_pos"]
-  mask = kwargs["mask"]
-  config = kwargs["config"]
-
   # Transpose k/v to specific layout for GPU implementation.
   b, seq_len, n, h = query.shape
   g = n // config.num_query_groups
@@ -74,12 +68,10 @@ def sdpa_with_kv_update_impl(
       1, -1, config.head_dim, seq_len
   )  # 1, bk, h, s

-  if kv is not None:
-    kv = kv_utils_experimental.update(kv, input_pos, key, value)
-    key, value = kv.k_cache, kv.v_cache
+  kv = kv_utils_experimental.update(kv, input_pos, key, value)
+  key, value = kv.k_cache, kv.v_cache

-  sdpa_out = sdpa.scaled_dot_product_attention(
-      kv,
+  sdpa_out = sdpa.scaled_dot_product_attention_transposed(
       query,
       key,
       value,
@@ -95,24 +87,26 @@ def sdpa_with_kv_update_impl(
   return sdpa_out, kv


-@dispatch(object, object)
-def sdpa_with_kv_update_impl(
-    k_type, v_type, *args, **kwargs
+def _sdpa_with_kv_update_default(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    kv: kv_utils.KVCacheEntry,
+    input_pos: torch.Tensor,
+    mask: torch.Tensor,
+    config: cfg.AttentionConfig,
+    enable_hlfb: bool,
 ) -> Tuple[torch.Tensor, kv_utils.KVCacheEntry]:
-  query = kwargs["query"]
-  key = kwargs["key"]
-  value = kwargs["value"]
-  kv = kwargs["kv"]
-  input_pos = kwargs["input_pos"]
-  mask = kwargs["mask"]
-  config = kwargs["config"]
-
   b, seq_len, _, _ = query.shape
   if kv is not None:
     kv = kv_utils.update(kv, input_pos, key, value)
     key, value = kv.k_cache, kv.v_cache

-  sdpa_out = sdpa_default.scaled_dot_product_attention(
+  if enable_hlfb:
+    sdpa_func = sdpa.scaled_dot_product_attention_with_hlfb
+  else:
+    sdpa_func = sdpa.scaled_dot_product_attention
+  sdpa_out = sdpa_func(
       query,
       key,
       value,
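Note: dispatch now keys off the cache layout instead of multipledispatch types: sdpa_with_kv_update routes to the transposed path whenever the KV cache entry was created with KV_LAYOUT_TRANSPOSED. A hedged sketch of how a caller ends up on one path or the other (kv_cache is ai_edge_torch.generative.layers.kv_cache; config stands in for any generative ModelConfig):

    from ai_edge_torch.generative.layers import kv_cache

    # Default layout: sdpa_with_kv_update takes the _sdpa_with_kv_update_default path.
    kv_default = kv_cache.KVCache.from_model_config(config)

    # Transposed layout: the same call routes to _sdpa_with_kv_update_transposed.
    kv_transposed = kv_cache.KVCache.from_model_config(
        config, kv_layout=kv_cache.KV_LAYOUT_TRANSPOSED
    )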
--- a/ai_edge_torch/generative/test/test_model_conversion.py
+++ b/ai_edge_torch/generative/test/test_model_conversion.py
@@ -32,16 +32,14 @@ class TestModelConversion(googletest.TestCase):

   def setUp(self):
     super().setUp()
-    # Builder function for an Interpreter that supports custom ops.
     self._interpreter_builder = (
-        lambda tflite_model: lambda: interpreter.InterpreterWithCustomOps(
-            custom_op_registerers=["GenAIOpsRegisterer"],
+        lambda tflite_model: lambda: interpreter.Interpreter(
            model_content=tflite_model,
            experimental_default_delegate_latest_features=True,
        )
    )

-  def _get_params(self, enable_hlfb: bool):
+  def _get_params(self, enable_hlfb: bool, kv_layout: kv_cache.KVLayout):
     """Returns a model, edge model and the kwargs to use for testing."""
     config = toy_model_with_kv_cache.get_model_config()
     config.enable_hlfb = enable_hlfb
@@ -49,7 +47,7 @@ class TestModelConversion(googletest.TestCase):
     tokens, input_pos = torch.tensor([[1]], dtype=torch.int), torch.tensor(
         [10], dtype=torch.int
     )
-    kv = kv_cache.KVCache.from_model_config(config)
+    kv = kv_cache.KVCache.from_model_config(config, kv_layout=kv_layout)
     kwargs = {
         "tokens": tokens,
         "input_pos": input_pos,
@@ -65,8 +63,12 @@ class TestModelConversion(googletest.TestCase):
     )
     return pytorch_model, edge_model, kwargs

-  def _test_model_with_kv_cache(self, enable_hlfb: bool):
-    pytorch_model, edge_model, kwargs = self._get_params(enable_hlfb)
+  def _test_model_with_kv_cache(
+      self,
+      enable_hlfb: bool = False,
+      kv_layout: kv_cache.KVLayout = kv_cache.KV_LAYOUT_DEFAULT,
+  ):
+    pytorch_model, edge_model, kwargs = self._get_params(enable_hlfb, kv_layout)

     self.assertTrue(
         test_utils.compare_tflite_torch(
@@ -81,38 +83,34 @@ class TestModelConversion(googletest.TestCase):
         )
     )

-  @googletest.skipIf(
-      ai_edge_torch.config.in_oss,
-      reason="tests with custom ops are not supported in oss",
-  )
   def test_toy_model_with_kv_cache(self):
     self._test_model_with_kv_cache(enable_hlfb=False)

-  @googletest.skipIf(
-      ai_edge_torch.config.in_oss,
-      reason="tests with custom ops are not supported in oss",
-  )
   def test_toy_model_with_kv_cache_with_hlfb(self):
     self._test_model_with_kv_cache(enable_hlfb=True)

-  @googletest.skipIf(
-      ai_edge_torch.config.in_oss,
-      reason="tests with custom ops are not supported in oss",
-  )
+  def test_toy_model_with_kv_cache_transposed(self):
+    self._test_model_with_kv_cache(kv_layout=kv_cache.KV_LAYOUT_TRANSPOSED)
+
   def test_toy_model_has_dus_op(self):
     """Tests that the model has the dynamic update slice op."""
-    _, edge_model, _ = self._get_params(enable_hlfb=True)
-    interpreter_ = interpreter.InterpreterWithCustomOps(
-        custom_op_registerers=["GenAIOpsRegisterer"],
-        model_content=edge_model.tflite_model(),
-        experimental_default_delegate_latest_features=True,
+    _, edge_model, _ = self._get_params(
+        enable_hlfb=True, kv_layout=kv_cache.KV_LAYOUT_DEFAULT
     )
+    interpreter = self._interpreter_builder(edge_model.tflite_model())()

     # pylint: disable=protected-access
-    op_names = [op["op_name"] for op in interpreter_._get_ops_details()]
+    op_names = [op["op_name"] for op in interpreter._get_ops_details()]
     self.assertIn("DYNAMIC_UPDATE_SLICE", op_names)

-  def _test_multisig_model(self, config, pytorch_model, atol, rtol):
+  def _test_multisig_model(
+      self,
+      config,
+      pytorch_model,
+      atol,
+      rtol,
+      kv_layout=kv_cache.KV_LAYOUT_DEFAULT,
+  ):
     # prefill
     seq_len = 10
     prefill_tokens = torch.zeros((1, seq_len), dtype=torch.int, device="cpu")
@@ -124,7 +122,7 @@ class TestModelConversion(googletest.TestCase):
     decode_token = torch.tensor([[1]], dtype=torch.int)
     decode_input_pos = torch.tensor([5], dtype=torch.int)

-    kv = kv_cache.KVCache.from_model_config(config)
+    kv = kv_cache.KVCache.from_model_config(config, kv_layout=kv_layout)

     edge_model = (
         ai_edge_torch.signature(
@@ -160,7 +158,7 @@ class TestModelConversion(googletest.TestCase):
             kv,
             signature_name="prefill",
             atol=atol,
-            rtol=atol,
+            rtol=rtol,
         )
     )

@@ -173,19 +171,26 @@ class TestModelConversion(googletest.TestCase):
             kv,
             signature_name="decode",
             atol=atol,
-            rtol=atol,
+            rtol=rtol,
         )
     )

-  @googletest.skipIf(
-      ai_edge_torch.config.in_oss,
-      reason="tests with custom ops are not supported in oss",
-  )
   def test_tiny_llama_multisig(self):
     config = tiny_llama.get_fake_model_config()
     pytorch_model = tiny_llama.TinyLlama(config).eval()
     self._test_multisig_model(config, pytorch_model, atol=1e-5, rtol=1e-5)

+  def test_tiny_llama_multisig_kv_layout_transposed(self):
+    config = tiny_llama.get_fake_model_config()
+    pytorch_model = tiny_llama.TinyLlama(config).eval()
+    self._test_multisig_model(
+        config,
+        pytorch_model,
+        atol=1e-5,
+        rtol=1e-5,
+        kv_layout=kv_cache.KV_LAYOUT_TRANSPOSED,
+    )
+

 if __name__ == "__main__":
   googletest.main()