ai-edge-torch-nightly 0.5.0.dev20250424__py3-none-any.whl → 0.5.0.dev20250425__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (22) hide show
  1. ai_edge_torch/_convert/conversion.py +1 -3
  2. ai_edge_torch/_convert/fx_passes/__init__.py +0 -1
  3. ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py +63 -2
  4. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +2 -1
  5. ai_edge_torch/generative/examples/deepseek/convert_to_tflite.py +38 -4
  6. ai_edge_torch/generative/examples/deepseek/deepseek.py +1 -0
  7. ai_edge_torch/generative/examples/qwen/convert_to_tflite.py +37 -2
  8. ai_edge_torch/generative/examples/qwen/qwen.py +1 -0
  9. ai_edge_torch/generative/layers/scaled_dot_product_attention.py +51 -0
  10. ai_edge_torch/generative/layers/sdpa_with_kv_update.py +4 -6
  11. ai_edge_torch/generative/test/test_model_conversion.py +3 -33
  12. ai_edge_torch/generative/test/test_model_conversion_large.py +3 -75
  13. ai_edge_torch/generative/utilities/converter.py +5 -0
  14. ai_edge_torch/odml_torch/lowerings/_decomp_registry.py +2 -0
  15. ai_edge_torch/version.py +1 -1
  16. {ai_edge_torch_nightly-0.5.0.dev20250424.dist-info → ai_edge_torch_nightly-0.5.0.dev20250425.dist-info}/METADATA +1 -1
  17. {ai_edge_torch_nightly-0.5.0.dev20250424.dist-info → ai_edge_torch_nightly-0.5.0.dev20250425.dist-info}/RECORD +20 -22
  18. ai_edge_torch/_convert/fx_passes/build_interpolate_composite_pass.py +0 -129
  19. ai_edge_torch/generative/layers/experimental/scaled_dot_product_attention.py +0 -93
  20. {ai_edge_torch_nightly-0.5.0.dev20250424.dist-info → ai_edge_torch_nightly-0.5.0.dev20250425.dist-info}/LICENSE +0 -0
  21. {ai_edge_torch_nightly-0.5.0.dev20250424.dist-info → ai_edge_torch_nightly-0.5.0.dev20250425.dist-info}/WHEEL +0 -0
  22. {ai_edge_torch_nightly-0.5.0.dev20250424.dist-info → ai_edge_torch_nightly-0.5.0.dev20250425.dist-info}/top_level.txt +0 -0
@@ -35,13 +35,11 @@ def _run_convert_passes(
35
35
  )
36
36
 
37
37
  passes = [
38
- fx_passes.CastInputsBf16ToF32Pass(),
39
- fx_passes.BuildInterpolateCompositePass(),
40
- fx_passes.CanonicalizePass(),
41
38
  fx_passes.OptimizeLayoutTransposesPass(),
42
39
  fx_passes.CanonicalizePass(),
43
40
  fx_passes.BuildAtenCompositePass(),
44
41
  fx_passes.RemoveNonUserOutputsPass(),
42
+ fx_passes.CastInputsBf16ToF32Pass(),
45
43
  ]
46
44
 
47
45
  # Debuginfo is not injected automatically by odml_torch. Only inject
@@ -16,7 +16,6 @@
16
16
  from typing import Sequence, Union
17
17
 
18
18
  from ai_edge_torch._convert.fx_passes.build_aten_composite_pass import BuildAtenCompositePass
19
- from ai_edge_torch._convert.fx_passes.build_interpolate_composite_pass import BuildInterpolateCompositePass
20
19
  from ai_edge_torch._convert.fx_passes.cast_inputs_bf16_to_f32_pass import CastInputsBf16ToF32Pass
21
20
  from ai_edge_torch._convert.fx_passes.inject_mlir_debuginfo_pass import InjectMlirDebuginfoPass
22
21
  from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import OptimizeLayoutTransposesPass
@@ -20,7 +20,8 @@ import torch
20
20
  import torch.utils._pytree as pytree
21
21
 
22
22
  _composite_builders: dict[
23
- Callable, Callable[[torch.fx.GraphModule, torch.fx.Node], None]
23
+ Callable[[Any, ...], Any],
24
+ Callable[[torch.fx.GraphModule, torch.fx.Node], None],
24
25
  ] = {}
25
26
 
26
27
 
@@ -272,13 +273,73 @@ def _aten_embedding(gm: torch.fx.GraphModule, node: torch.fx.Node):
272
273
  output = op(**full_kwargs)
273
274
  output = builder.mark_outputs(output)
274
275
 
275
- # Explicitly reshape back to the original shape. This places the ReshapeOp outside of the HLFB.
276
+ # Explicitly reshape back to the original shape. This places the ReshapeOp
277
+ # outside of the HLFB.
276
278
  output = torch.reshape(output, (*(original_idx_shape), embedding_dim))
277
279
  return output
278
280
 
279
281
  node.target = embedding
280
282
 
281
283
 
284
+ @_register_composite_builder(torch.ops.aten.upsample_bilinear2d.vec)
285
+ def _aten_upsample_bilinear2d_vec(_, node: torch.fx.Node):
286
+ """Build a composite for aten.upsample_bilinear2d.vec."""
287
+ op = node.target
288
+ args_mapper = TorchOpArgumentsMapper(op)
289
+ # Assumes later FX passes does not change the args/kwargs of the op.
290
+ # Which is a valid assumption for, given that composite/mark_tensor wrapper
291
+ # should semantically prevents any future mutations on the op.
292
+ output_h, output_w = node.meta["val"].shape[-2:]
293
+
294
+ def upsample_bilinear2d_vec(*args, **kwargs):
295
+ nonlocal op, args_mapper
296
+ full_kwargs = args_mapper.get_full_kwargs(args, kwargs)
297
+
298
+ builder = lowertools.StableHLOCompositeBuilder(
299
+ name="odml.upsample_bilinear2d",
300
+ attr={
301
+ "size": (int(output_h), int(output_w)),
302
+ "align_corners": full_kwargs["align_corners"],
303
+ "is_nchw_op": True,
304
+ },
305
+ )
306
+ full_kwargs["input"] = builder.mark_inputs(full_kwargs["input"])
307
+ output = op(**full_kwargs)
308
+ output = builder.mark_outputs(output)
309
+ return output
310
+
311
+ node.target = upsample_bilinear2d_vec
312
+
313
+
314
+ @_register_composite_builder(torch.ops.aten.upsample_nearest2d.vec)
315
+ def _aten_upsample_nearest2d_vec(_, node: torch.fx.Node):
316
+ """Build a composite for aten.upsample_nearest2d.vec."""
317
+ op = node.target
318
+ args_mapper = TorchOpArgumentsMapper(op)
319
+ # Assumes later FX passes does not change the args/kwargs of the op.
320
+ # Which is a valid assumption for, given that composite/mark_tensor wrapper
321
+ # should semantically prevents any future mutations on the op.
322
+ output_h, output_w = node.meta["val"].shape[-2:]
323
+
324
+ def upsample_nearest2d_vec(*args, **kwargs):
325
+ nonlocal op, args_mapper
326
+ full_kwargs = args_mapper.get_full_kwargs(args, kwargs)
327
+
328
+ builder = lowertools.StableHLOCompositeBuilder(
329
+ name="tfl.resize_nearest_neighbor",
330
+ attr={
331
+ "size": (int(output_h), int(output_w)),
332
+ "is_nchw_op": True,
333
+ },
334
+ )
335
+ full_kwargs["input"] = builder.mark_inputs(full_kwargs["input"])
336
+ output = op(**full_kwargs)
337
+ output = builder.mark_outputs(output)
338
+ return output
339
+
340
+ node.target = upsample_nearest2d_vec
341
+
342
+
282
343
  class BuildAtenCompositePass(fx_infra.PassBase):
283
344
 
284
345
  def call(self, graph_module: torch.fx.GraphModule):
@@ -17,6 +17,7 @@
17
17
  import operator
18
18
 
19
19
  import ai_edge_torch
20
+ from ai_edge_torch import lowertools
20
21
  from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import layout_mark
21
22
  from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import op_func_registry
22
23
  from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import utils
@@ -24,7 +25,7 @@ import torch
24
25
  import torch.utils._pytree as pytree
25
26
 
26
27
  aten = torch.ops.aten
27
- StableHLOCompositeBuilder = ai_edge_torch.hlfb.StableHLOCompositeBuilder
28
+ StableHLOCompositeBuilder = lowertools.StableHLOCompositeBuilder
28
29
 
29
30
  __all__ = ["rewrite_nhwc_node", "has_nhwc_rewriter"]
30
31
 
@@ -17,11 +17,43 @@
17
17
 
18
18
  from absl import app
19
19
  from ai_edge_torch.generative.examples.deepseek import deepseek
20
+ from ai_edge_torch.generative.layers import kv_cache
20
21
  from ai_edge_torch.generative.utilities import converter
21
- from ai_edge_torch.generative.utilities import export_config
22
+ from ai_edge_torch.generative.utilities.model_builder import export_cfg
23
+ import torch
24
+
25
+ flags = converter.define_conversion_flags('deepseek')
26
+ ExportConfig = export_cfg.ExportConfig
27
+
28
+
29
+ def _create_mask(mask_len, kv_cache_max_len):
30
+ mask = torch.full(
31
+ (mask_len, kv_cache_max_len), float('-inf'), dtype=torch.float32
32
+ )
33
+ mask = torch.triu(mask, diagonal=1).unsqueeze(0).unsqueeze(0)
34
+ return mask
35
+
36
+
37
+ def _create_export_config(
38
+ prefill_seq_lens: list[int], kv_cache_max_len: int
39
+ ) -> ExportConfig:
40
+ """Creates the export config for the model."""
41
+ export_config = ExportConfig()
42
+ if isinstance(prefill_seq_lens, list):
43
+ prefill_mask = [_create_mask(i, kv_cache_max_len) for i in prefill_seq_lens]
44
+ else:
45
+ prefill_mask = _create_mask(prefill_seq_lens, kv_cache_max_len)
46
+
47
+ export_config.prefill_mask = prefill_mask
48
+
49
+ decode_mask = torch.full(
50
+ (1, kv_cache_max_len), float('-inf'), dtype=torch.float32
51
+ )
52
+ decode_mask = torch.triu(decode_mask, diagonal=1).unsqueeze(0).unsqueeze(0)
53
+ export_config.decode_mask = decode_mask
54
+ export_config.kvcache_layout = kv_cache.KV_LAYOUT_TRANSPOSED
55
+ return export_config
22
56
 
23
- flags = converter.define_conversion_flags("deepseek")
24
- ExportConfig = export_config.ExportConfig
25
57
 
26
58
  def main(_):
27
59
  pytorch_model = deepseek.build_model(
@@ -34,7 +66,9 @@ def main(_):
34
66
  prefill_seq_len=flags.FLAGS.prefill_seq_lens,
35
67
  quantize=flags.FLAGS.quantize,
36
68
  lora_ranks=flags.FLAGS.lora_ranks,
37
- export_config=ExportConfig(),
69
+ export_config=_create_export_config(
70
+ flags.FLAGS.prefill_seq_lens, flags.FLAGS.kv_cache_max_len
71
+ ),
38
72
  )
39
73
 
40
74
 
@@ -53,6 +53,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
53
53
  norm_config = cfg.NormalizationConfig(
54
54
  type=cfg.NormalizationType.RMS_NORM,
55
55
  epsilon=1e-06,
56
+ enable_hlfb=True,
56
57
  )
57
58
  block_config = cfg.TransformerBlockConfig(
58
59
  attn_config=attn_config,
@@ -17,13 +17,14 @@
17
17
 
18
18
  from absl import app
19
19
  from ai_edge_torch.generative.examples.qwen import qwen
20
+ from ai_edge_torch.generative.layers import kv_cache
20
21
  from ai_edge_torch.generative.utilities import converter
21
22
  from ai_edge_torch.generative.utilities import export_config
23
+ import torch
22
24
 
23
25
  flags = converter.define_conversion_flags('qwen')
24
26
  ExportConfig = export_config.ExportConfig
25
27
 
26
-
27
28
  _MODEL_SIZE = flags.DEFINE_enum(
28
29
  'model_size',
29
30
  '3b',
@@ -37,6 +38,36 @@ _BUILDER = {
37
38
  '3b': qwen.build_3b_model,
38
39
  }
39
40
 
41
+
42
+ def _create_mask(mask_len, kv_cache_max_len):
43
+ mask = torch.full(
44
+ (mask_len, kv_cache_max_len), float('-inf'), dtype=torch.float32
45
+ )
46
+ mask = torch.triu(mask, diagonal=1).unsqueeze(0).unsqueeze(0)
47
+ return mask
48
+
49
+
50
+ def _create_export_config(
51
+ prefill_seq_lens: list[int], kv_cache_max_len: int
52
+ ) -> ExportConfig:
53
+ """Creates the export config for the model."""
54
+ export_config = ExportConfig()
55
+ if isinstance(prefill_seq_lens, list):
56
+ prefill_mask = [_create_mask(i, kv_cache_max_len) for i in prefill_seq_lens]
57
+ else:
58
+ prefill_mask = _create_mask(prefill_seq_lens, kv_cache_max_len)
59
+
60
+ export_config.prefill_mask = prefill_mask
61
+
62
+ decode_mask = torch.full(
63
+ (1, kv_cache_max_len), float('-inf'), dtype=torch.float32
64
+ )
65
+ decode_mask = torch.triu(decode_mask, diagonal=1).unsqueeze(0).unsqueeze(0)
66
+ export_config.decode_mask = decode_mask
67
+ export_config.kvcache_layout = kv_cache.KV_LAYOUT_TRANSPOSED
68
+ return export_config
69
+
70
+
40
71
  def main(_):
41
72
  pytorch_model = _BUILDER[_MODEL_SIZE.value](
42
73
  flags.FLAGS.checkpoint_path, kv_cache_max_len=flags.FLAGS.kv_cache_max_len
@@ -48,7 +79,11 @@ def main(_):
48
79
  prefill_seq_len=flags.FLAGS.prefill_seq_lens,
49
80
  quantize=flags.FLAGS.quantize,
50
81
  lora_ranks=flags.FLAGS.lora_ranks,
51
- export_config=ExportConfig(),
82
+ export_config=_create_export_config(
83
+ flags.FLAGS.prefill_seq_lens, flags.FLAGS.kv_cache_max_len
84
+ )
85
+ if flags.FLAGS.transpose_kv_cache
86
+ else ExportConfig(),
52
87
  )
53
88
 
54
89
 
@@ -53,6 +53,7 @@ def get_3b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
53
53
  norm_config = cfg.NormalizationConfig(
54
54
  type=cfg.NormalizationType.RMS_NORM,
55
55
  epsilon=1e-06,
56
+ enable_hlfb=True,
56
57
  )
57
58
  block_config = cfg.TransformerBlockConfig(
58
59
  attn_config=attn_config,
@@ -17,6 +17,8 @@
17
17
  import math
18
18
  from typing import Optional
19
19
 
20
+ from ai_edge_torch.generative.custom_ops import bmm_4d as bmm_lib
21
+ from ai_edge_torch.generative.layers import kv_cache as kv_utils
20
22
  from ai_edge_torch.hlfb import StableHLOCompositeBuilder
21
23
  import torch
22
24
  import torch.nn.functional as F
@@ -142,3 +144,52 @@ def scaled_dot_product_attention_with_hlfb(
142
144
  result = y.transpose(1, 2)
143
145
  result = builder.mark_outputs(result)
144
146
  return result
147
+
148
+
149
+ def scaled_dot_product_attention_transposed(
150
+ query: torch.Tensor,
151
+ key: torch.Tensor,
152
+ value: torch.Tensor,
153
+ head_size: int,
154
+ mask: Optional[torch.Tensor] = None,
155
+ scale: Optional[float] = None,
156
+ softcap: Optional[float] = None,
157
+ ):
158
+ """Scaled dot product attention with transposed key and value.
159
+
160
+ Args:
161
+ query: Query tensor, with shape [B, T, N, H].
162
+ key: Key tensor, with shape [B, T, KV_LEN, H].
163
+ value: Value tensor, with shape [B, T, KV_LEN, H].
164
+ head_size (int): head dimension.
165
+ mask (torch.Tensor): the optional mask tensor.
166
+ scale (float): the optional scale factor.
167
+ softcap (float): the optional softcap for the logits.
168
+
169
+ Returns:
170
+ The output tensor of scaled_dot_product_attention_transposed.
171
+ """
172
+
173
+ if scale is None:
174
+ scale = 1.0 / math.sqrt(head_size)
175
+
176
+ query = query * scale
177
+
178
+ assert mask is not None, "Mask should not be None!"
179
+ t = mask.shape[2]
180
+
181
+ logits = bmm_lib.bmm_4d(query, key)
182
+
183
+ _, bk, gt, s = logits.shape
184
+ g = gt // t
185
+ logits = logits.reshape((bk, g, t, s))
186
+ if softcap is not None:
187
+ logits = torch.tanh(logits / softcap)
188
+ logits = logits * softcap
189
+
190
+ padded_logits = logits + mask
191
+ padded_logits = padded_logits.reshape(1, bk, gt, s)
192
+ probs = F.softmax(padded_logits, dim=-1).type_as(key)
193
+ encoded = bmm_lib.bmm_4d(probs, value)
194
+
195
+ return encoded # 1, bk, gt, h
@@ -18,9 +18,8 @@
18
18
  from typing import Tuple
19
19
 
20
20
  from ai_edge_torch.generative.layers import kv_cache as kv_utils
21
- from ai_edge_torch.generative.layers import scaled_dot_product_attention as sdpa_default
21
+ from ai_edge_torch.generative.layers import scaled_dot_product_attention as sdpa
22
22
  from ai_edge_torch.generative.layers.experimental import kv_cache as kv_utils_experimental
23
- from ai_edge_torch.generative.layers.experimental import scaled_dot_product_attention as sdpa
24
23
  import ai_edge_torch.generative.layers.model_config as cfg
25
24
  import torch
26
25
 
@@ -72,8 +71,7 @@ def _sdpa_with_kv_update_transposed(
72
71
  kv = kv_utils_experimental.update(kv, input_pos, key, value)
73
72
  key, value = kv.k_cache, kv.v_cache
74
73
 
75
- sdpa_out = sdpa.scaled_dot_product_attention(
76
- kv,
74
+ sdpa_out = sdpa.scaled_dot_product_attention_transposed(
77
75
  query,
78
76
  key,
79
77
  value,
@@ -105,9 +103,9 @@ def _sdpa_with_kv_update_default(
105
103
  key, value = kv.k_cache, kv.v_cache
106
104
 
107
105
  if enable_hlfb:
108
- sdpa_func = sdpa_default.scaled_dot_product_attention_with_hlfb
106
+ sdpa_func = sdpa.scaled_dot_product_attention_with_hlfb
109
107
  else:
110
- sdpa_func = sdpa_default.scaled_dot_product_attention
108
+ sdpa_func = sdpa.scaled_dot_product_attention
111
109
  sdpa_out = sdpa_func(
112
110
  query,
113
111
  key,
@@ -32,10 +32,8 @@ class TestModelConversion(googletest.TestCase):
32
32
 
33
33
  def setUp(self):
34
34
  super().setUp()
35
- # Builder function for an Interpreter that supports custom ops.
36
35
  self._interpreter_builder = (
37
- lambda tflite_model: lambda: interpreter.InterpreterWithCustomOps(
38
- custom_op_registerers=["GenAIOpsRegisterer"],
36
+ lambda tflite_model: lambda: interpreter.Interpreter(
39
37
  model_content=tflite_model,
40
38
  experimental_default_delegate_latest_features=True,
41
39
  )
@@ -85,44 +83,24 @@ class TestModelConversion(googletest.TestCase):
85
83
  )
86
84
  )
87
85
 
88
- @googletest.skipIf(
89
- ai_edge_torch.config.in_oss,
90
- reason="tests with custom ops are not supported in oss",
91
- )
92
86
  def test_toy_model_with_kv_cache(self):
93
87
  self._test_model_with_kv_cache(enable_hlfb=False)
94
88
 
95
- @googletest.skipIf(
96
- ai_edge_torch.config.in_oss,
97
- reason="tests with custom ops are not supported in oss",
98
- )
99
89
  def test_toy_model_with_kv_cache_with_hlfb(self):
100
90
  self._test_model_with_kv_cache(enable_hlfb=True)
101
91
 
102
- @googletest.skipIf(
103
- ai_edge_torch.config.in_oss,
104
- reason="tests with custom ops are not supported in oss",
105
- )
106
92
  def test_toy_model_with_kv_cache_transposed(self):
107
93
  self._test_model_with_kv_cache(kv_layout=kv_cache.KV_LAYOUT_TRANSPOSED)
108
94
 
109
- @googletest.skipIf(
110
- ai_edge_torch.config.in_oss,
111
- reason="tests with custom ops are not supported in oss",
112
- )
113
95
  def test_toy_model_has_dus_op(self):
114
96
  """Tests that the model has the dynamic update slice op."""
115
97
  _, edge_model, _ = self._get_params(
116
98
  enable_hlfb=True, kv_layout=kv_cache.KV_LAYOUT_DEFAULT
117
99
  )
118
- interpreter_ = interpreter.InterpreterWithCustomOps(
119
- custom_op_registerers=["GenAIOpsRegisterer"],
120
- model_content=edge_model.tflite_model(),
121
- experimental_default_delegate_latest_features=True,
122
- )
100
+ interpreter = self._interpreter_builder(edge_model.tflite_model())()
123
101
 
124
102
  # pylint: disable=protected-access
125
- op_names = [op["op_name"] for op in interpreter_._get_ops_details()]
103
+ op_names = [op["op_name"] for op in interpreter._get_ops_details()]
126
104
  self.assertIn("DYNAMIC_UPDATE_SLICE", op_names)
127
105
 
128
106
  def _test_multisig_model(
@@ -197,19 +175,11 @@ class TestModelConversion(googletest.TestCase):
197
175
  )
198
176
  )
199
177
 
200
- @googletest.skipIf(
201
- ai_edge_torch.config.in_oss,
202
- reason="tests with custom ops are not supported in oss",
203
- )
204
178
  def test_tiny_llama_multisig(self):
205
179
  config = tiny_llama.get_fake_model_config()
206
180
  pytorch_model = tiny_llama.TinyLlama(config).eval()
207
181
  self._test_multisig_model(config, pytorch_model, atol=1e-5, rtol=1e-5)
208
182
 
209
- @googletest.skipIf(
210
- ai_edge_torch.config.in_oss,
211
- reason="tests with custom ops are not supported in oss",
212
- )
213
183
  def test_tiny_llama_multisig_kv_layout_transposed(self):
214
184
  config = tiny_llama.get_fake_model_config()
215
185
  pytorch_model = tiny_llama.TinyLlama(config).eval()
@@ -48,10 +48,8 @@ class TestModelConversion(googletest.TestCase):
48
48
 
49
49
  def setUp(self):
50
50
  super().setUp()
51
- # Builder function for an Interpreter that supports custom ops.
52
51
  self._interpreter_builder = (
53
- lambda tflite_model: lambda: interpreter.InterpreterWithCustomOps(
54
- custom_op_registerers=["GenAIOpsRegisterer"],
52
+ lambda tflite_model: lambda: interpreter.Interpreter(
55
53
  model_content=tflite_model,
56
54
  experimental_default_delegate_latest_features=True,
57
55
  )
@@ -94,110 +92,62 @@ class TestModelConversion(googletest.TestCase):
94
92
  )
95
93
  )
96
94
 
97
- @googletest.skipIf(
98
- ai_edge_torch.config.in_oss,
99
- reason="tests with custom ops are not supported in oss",
100
- )
101
95
  def test_gemma1(self):
102
96
  config = gemma1.get_fake_model_config()
103
97
  pytorch_model = gemma1.Gemma1(config).eval()
104
98
  self._test_model(config, pytorch_model, "prefill", atol=1e-3, rtol=1e-5)
105
99
 
106
- @googletest.skipIf(
107
- ai_edge_torch.config.in_oss,
108
- reason="tests with custom ops are not supported in oss",
109
- )
110
100
  def test_gemma2(self):
111
101
  config = gemma2.get_fake_model_config()
112
102
  pytorch_model = gemma2.Gemma2(config).eval()
113
103
  self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)
114
104
 
115
- @googletest.skipIf(
116
- ai_edge_torch.config.in_oss,
117
- reason="tests with custom ops are not supported in oss",
118
- )
119
105
  def test_llama(self):
120
106
  config = llama.get_fake_model_config()
121
107
  pytorch_model = llama.Llama(config).eval()
122
108
  self._test_model(config, pytorch_model, "prefill", atol=1e-3, rtol=1e-5)
123
109
 
124
- @googletest.skipIf(
125
- ai_edge_torch.config.in_oss,
126
- reason="tests with custom ops are not supported in oss",
127
- )
128
110
  def test_phi2(self):
129
111
  config = phi2.get_fake_model_config()
130
112
  pytorch_model = phi2.Phi2(config).eval()
131
113
  # Phi-2 logits are very big, so we need a larger absolute tolerance.
132
114
  self._test_model(config, pytorch_model, "prefill", atol=1e-3, rtol=1e-5)
133
115
 
134
- @googletest.skipIf(
135
- ai_edge_torch.config.in_oss,
136
- reason="tests with custom ops are not supported in oss",
137
- )
138
116
  def test_phi3(self):
139
117
  config = phi3.get_fake_model_config()
140
118
  pytorch_model = phi3.Phi3_5Mini(config).eval()
141
119
  self._test_model(config, pytorch_model, "prefill", atol=1e-5, rtol=1e-5)
142
120
 
143
- @googletest.skipIf(
144
- ai_edge_torch.config.in_oss,
145
- reason="tests with custom ops are not supported in oss",
146
- )
147
121
  def test_phi4(self):
148
122
  config = phi4.get_fake_model_config()
149
123
  pytorch_model = phi4.Phi4Mini(config).eval()
150
124
  self._test_model(config, pytorch_model, "prefill", atol=1e-3, rtol=1e-5)
151
125
 
152
- @googletest.skipIf(
153
- ai_edge_torch.config.in_oss,
154
- reason="tests with custom ops are not supported in oss",
155
- )
156
126
  def test_smollm(self):
157
127
  config = smollm.get_fake_model_config()
158
128
  pytorch_model = smollm.SmolLM(config).eval()
159
129
  self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)
160
130
 
161
- @googletest.skipIf(
162
- ai_edge_torch.config.in_oss,
163
- reason="tests with custom ops are not supported in oss",
164
- )
165
131
  def test_smollm2(self):
166
132
  config = smollm.get_fake_model_config_v2()
167
133
  pytorch_model = smollm.SmolLM2(config).eval()
168
134
  self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)
169
135
 
170
- @googletest.skipIf(
171
- ai_edge_torch.config.in_oss,
172
- reason="tests with custom ops are not supported in oss",
173
- )
174
136
  def test_openelm(self):
175
137
  config = openelm.get_fake_model_config()
176
138
  pytorch_model = openelm.OpenELM(config).eval()
177
139
  self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)
178
140
 
179
- @googletest.skipIf(
180
- ai_edge_torch.config.in_oss,
181
- reason="tests with custom ops are not supported in oss",
182
- )
183
141
  def test_qwen(self):
184
142
  config = qwen.get_fake_model_config()
185
143
  pytorch_model = qwen.Qwen(config).eval()
186
144
  self._test_model(config, pytorch_model, "prefill", atol=1e-3, rtol=1e-5)
187
145
 
188
- @googletest.skipIf(
189
- ai_edge_torch.config.in_oss,
190
- reason="tests with custom ops are not supported in oss",
191
- )
192
146
  def test_deepseek(self):
193
147
  config = deepseek.get_fake_model_config()
194
148
  pytorch_model = deepseek.DeepSeekDistillQwen(config).eval()
195
149
  self._test_model(config, pytorch_model, "prefill", atol=1e-5, rtol=1e-5)
196
150
 
197
- @googletest.skipIf(
198
- ai_edge_torch.config.in_oss,
199
- reason="tests with custom ops are not supported in oss",
200
- )
201
151
  def test_amd_llama_135m(self):
202
152
  config = amd_llama_135m.get_fake_model_config()
203
153
  pytorch_model = amd_llama_135m.AmdLlama(config).eval()
@@ -246,19 +196,11 @@ class TestModelConversion(googletest.TestCase):
246
196
  )
247
197
  )
248
198
 
249
- @googletest.skipIf(
250
- ai_edge_torch.config.in_oss,
251
- reason="tests with custom ops are not supported in oss",
252
- )
253
199
  def test_paligemma1(self):
254
200
  self._test_paligemma_model(
255
201
  decoder.Decoder, decoder.get_fake_decoder_config, atol=1e-3, rtol=1e-5
256
202
  )
257
203
 
258
- @googletest.skipIf(
259
- ai_edge_torch.config.in_oss,
260
- reason="tests with custom ops are not supported in oss",
261
- )
262
204
  def test_paligemma2(self):
263
205
  self._test_paligemma_model(
264
206
  decoder2.Decoder2,
@@ -267,10 +209,6 @@ class TestModelConversion(googletest.TestCase):
267
209
  rtol=1e-5,
268
210
  )
269
211
 
270
- @googletest.skipIf(
271
- ai_edge_torch.config.in_oss,
272
- reason="tests with custom ops are not supported in oss",
273
- )
274
212
  def test_qwen_vl_model(self):
275
213
  config = qwen_vl.get_fake_model_config()
276
214
  pytorch_model = qwen_vl.QwenVL(config).eval()
@@ -316,10 +254,7 @@ class TestModelConversion(googletest.TestCase):
316
254
  )
317
255
  )
318
256
 
319
- @googletest.skipIf(
320
- ai_edge_torch.config.in_oss,
321
- reason="tests with custom ops are not supported in oss",
322
- )
257
+ @googletest.skipIf(ai_edge_torch.config.in_oss, reason="flaky")
323
258
  def test_stable_diffusion_clip(self):
324
259
  config = sd_clip.get_fake_model_config()
325
260
  prompt_tokens = torch.from_numpy(
@@ -348,10 +283,7 @@ class TestModelConversion(googletest.TestCase):
348
283
  )
349
284
  )
350
285
 
351
- @googletest.skipIf(
352
- ai_edge_torch.config.in_oss,
353
- reason="tests with custom ops are not supported in oss",
354
- )
286
+ @googletest.skipIf(ai_edge_torch.config.in_oss, reason="b/413106901")
355
287
  def test_stable_diffusion_diffusion(self):
356
288
  config = sd_diffusion.get_fake_model_config(2)
357
289
  # Reduce stddev(scale) of input values to avoid too big output logits which
@@ -390,10 +322,6 @@ class TestModelConversion(googletest.TestCase):
390
322
  )
391
323
  )
392
324
 
393
- @googletest.skipIf(
394
- ai_edge_torch.config.in_oss,
395
- reason="tests with custom ops are not supported in oss",
396
- )
397
325
  def test_stable_diffusion_decoder(self):
398
326
  config = sd_decoder.get_fake_model_config()
399
327
  # Reduce stddev(scale) of input values to avoid too big output logits which
@@ -81,6 +81,11 @@ def define_conversion_flags(model_name: str):
81
81
  'If set, the model will be converted with the provided list of LoRA'
82
82
  ' ranks.',
83
83
  )
84
+ flags.DEFINE_bool(
85
+ 'transpose_kv_cache',
86
+ False,
87
+ 'If set, the model will be converted with transposed KV cache.',
88
+ )
84
89
 
85
90
  return flags
86
91
 
@@ -34,6 +34,8 @@ fx_infra.decomp.update_pre_lower_decomp(
34
34
  torch.ops.aten.replication_pad1d,
35
35
  torch.ops.aten.replication_pad2d,
36
36
  torch.ops.aten.replication_pad3d,
37
+ torch.ops.aten.upsample_bilinear2d.vec,
38
+ torch.ops.aten.upsample_nearest2d.vec,
37
39
  torch.ops.aten.addmm,
38
40
  ])
39
41
  )
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- __version__ = "0.5.0.dev20250424"
16
+ __version__ = "0.5.0.dev20250425"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: ai-edge-torch-nightly
3
- Version: 0.5.0.dev20250424
3
+ Version: 0.5.0.dev20250425
4
4
  Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
5
5
  Home-page: https://github.com/google-ai-edge/ai-edge-torch
6
6
  Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
@@ -2,16 +2,15 @@ ai_edge_torch/__init__.py,sha256=8sPR_5uXJA4NEE0nIwNdSl-ADOJEoR8hAgYvBQDY70Y,120
2
2
  ai_edge_torch/_config.py,sha256=AiqhbcheF7j_ozIGDLC89k1we95aVgFDa-tR6h7UI0s,2529
3
3
  ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
4
4
  ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
5
- ai_edge_torch/version.py,sha256=Nixp49eAXZPPMWEWkqpm_M4Mi_WGPx-I8q2noKuh0hw,706
5
+ ai_edge_torch/version.py,sha256=_aF64u6MXH8zPBTEg6odQq2WazbUIxQYlfJNXzfkMdM,706
6
6
  ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
7
- ai_edge_torch/_convert/conversion.py,sha256=dOr3TUfF0UCvkmlUrMqKvgaN4jh3lJ9XFuO-sHaAmIw,5521
7
+ ai_edge_torch/_convert/conversion.py,sha256=QVugYVfbyaeBgSKKbhFzHG5oXA7t3M-40JcpcdSu6W8,5436
8
8
  ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
9
9
  ai_edge_torch/_convert/converter.py,sha256=075F8LRewk_033Ebsnft7FJr3KgtIbtZ_-8udIPy6ho,9980
10
10
  ai_edge_torch/_convert/signature.py,sha256=-YKJdLk-eNEHfhdPCtcQVtZf915SoVePEFxKXPPf16c,2572
11
11
  ai_edge_torch/_convert/to_channel_last_io.py,sha256=_31phf7TYgZY2ftpNbrdlB1RhDium1lz_BXEQ6IsMFc,2893
12
- ai_edge_torch/_convert/fx_passes/__init__.py,sha256=6LtGzzqT2IXprfI_vPYKhE7IuN5XmPG0xy-v0UtZ9yk,1361
13
- ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py,sha256=a1KhqLetFb_efRHjX4T-zH0vF-U37Ha5I1CPIAsIluE,9211
14
- ai_edge_torch/_convert/fx_passes/build_interpolate_composite_pass.py,sha256=3JyjiHpn17Zhfq3yGQXK5LMH71DQPXHb_4GOkP9uAjY,4251
12
+ ai_edge_torch/_convert/fx_passes/__init__.py,sha256=jbRCZmSduG_1qmngaEEtbofAyL1PKZ8P1uxzzsXQhsw,1253
13
+ ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py,sha256=dgUO-lI9Id9hIOHP5XmegVlu5Fl79GR4_b-lDUehzoo,11428
15
14
  ai_edge_torch/_convert/fx_passes/cast_inputs_bf16_to_f32_pass.py,sha256=90YxLVAAkiA3qKr4Um__JmPeC1bTeA2PxBCj0GETq1Q,1748
16
15
  ai_edge_torch/_convert/fx_passes/inject_mlir_debuginfo_pass.py,sha256=Z6E3U7SYZvMl3Ivpqa3burVOLKFndEZuNmWKNxjq2mM,2386
17
16
  ai_edge_torch/_convert/fx_passes/remove_non_user_outputs_pass.py,sha256=HCOkj0k3NhaYbtfjE8HDXVmYhZ9fL5V_u6VunVh9mN4,2116
@@ -19,7 +18,7 @@ ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/__init__.py,sha
19
18
  ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/_decomp_registry.py,sha256=aWO_zHDF4j_hokoKJQNFIFmua4ysXztsgS6pcyBUht0,1082
20
19
  ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_check.py,sha256=7yEKSfXskXUk4tsd7c8vL155O-iU4eUjXCU5RSZHrbw,8204
21
20
  ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_mark.py,sha256=4RyGUwR22bZqkn_TnptenFJodc_Q43f4_SBG7gmTbos,1621
22
- ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py,sha256=OCFcPP618zH8IE12KTBQm2hRTtsaSeO3egvlOBUpNxA,13911
21
+ ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py,sha256=IhEh3tTP3-AmQlpt24stKKEl0AIRyuo2REZIbhkmgJo,13940
23
22
  ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/op_func_registry.py,sha256=OhisegHY2j4cv_m9auCh9Mq9qmm1lUqpFLVO9X-oBlc,1032
24
23
  ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/pass_body.py,sha256=mr0MiLbaQmU-3S3KT-vb58kRWbNT3VJiCKY-K7_3tFg,10556
25
24
  ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/utils.py,sha256=YLMttMg5PdvXTtQ8lxpKb434UGVvYVALV1-xeuH4UGc,2131
@@ -54,8 +53,8 @@ ai_edge_torch/generative/examples/amd_llama_135m/amd_llama_135m.py,sha256=urNif8
54
53
  ai_edge_torch/generative/examples/amd_llama_135m/convert_to_tflite.py,sha256=z5MWiZLnsQzhNYMiQbcI9i0ki-dtkbimCptkiTFZxwo,1586
55
54
  ai_edge_torch/generative/examples/amd_llama_135m/verify.py,sha256=o13NkFlBgawBsjdJup05VMUjAPvDRAmig6VyEkX8q6U,2426
56
55
  ai_edge_torch/generative/examples/deepseek/__init__.py,sha256=JaAnrFoXTl3RJX97XspklkTyqOHVyAgRJsZtzNDd10c,671
57
- ai_edge_torch/generative/examples/deepseek/convert_to_tflite.py,sha256=r6Pb5_LRKvw2QrOMn3PzunrVxPB-LSdyU2H1XORZo9A,1553
58
- ai_edge_torch/generative/examples/deepseek/deepseek.py,sha256=AOAJ7ltXwY5IbmcCP2nVHW9FmRwexzfNxnoDlR-sW9c,2885
56
+ ai_edge_torch/generative/examples/deepseek/convert_to_tflite.py,sha256=1wz4h3bjyX2qMRZ310UKGNYTORegzxinVFmYz2Fupm4,2666
57
+ ai_edge_torch/generative/examples/deepseek/deepseek.py,sha256=yhS_i2kR0GJWpWciCt4p9Z9nHYh6A5uJ8Ycy2ebFN9w,2909
59
58
  ai_edge_torch/generative/examples/deepseek/verify.py,sha256=iYldze-pvZGvPkkqr6zA7EmitPnH9sXkzjNVx353IcE,2403
60
59
  ai_edge_torch/generative/examples/gemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
61
60
  ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py,sha256=tSEtGeS-Ndcc_cTm7c4CT4FqRiwrHedEv1oJk4Y_zYU,1552
@@ -104,8 +103,8 @@ ai_edge_torch/generative/examples/phi/verify.py,sha256=YPFCdbnfmvq38fbpBNr0kHPfS
104
103
  ai_edge_torch/generative/examples/phi/verify_phi3.py,sha256=kVYaBVvddfQng0IyZGxyTJEzhiPO0G4VFJm2WOc2Q94,2360
105
104
  ai_edge_torch/generative/examples/phi/verify_phi4.py,sha256=BoCa5kUBRHtMQ-5ql6yD4pG4xHJMyUiQlpMOWVx-JgY,2356
106
105
  ai_edge_torch/generative/examples/qwen/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
107
- ai_edge_torch/generative/examples/qwen/convert_to_tflite.py,sha256=om3lXL1RnA87PkfU_cRfP6RnPgXrCmaB-cK98H-nqbA,1802
108
- ai_edge_torch/generative/examples/qwen/qwen.py,sha256=Zi_qiQ1JPokXZ95jgSEnQp3F-LKzFCvWvFLKhJjnASo,4199
106
+ ai_edge_torch/generative/examples/qwen/convert_to_tflite.py,sha256=-Xe5koexhNUkWjS2XgS9Ggg7XOQAlMO8QcBJRTNjJa4,2972
107
+ ai_edge_torch/generative/examples/qwen/qwen.py,sha256=m8APYzo9N0SXsdvCxC8HtCcbN3W7gLKkRBL-Tg0BWXU,4223
109
108
  ai_edge_torch/generative/examples/qwen/verify.py,sha256=9_AyEJTeUfvhhID64Rto2bflFPyXMFokdQLsseLUMiI,2775
110
109
  ai_edge_torch/generative/examples/qwen_vl/__init__.py,sha256=JaAnrFoXTl3RJX97XspklkTyqOHVyAgRJsZtzNDd10c,671
111
110
  ai_edge_torch/generative/examples/qwen_vl/convert_to_tflite.py,sha256=yVebRatt2SLCsGvrYTBXOM-0S2REhkpikHTyy5MCjUw,2222
@@ -159,11 +158,10 @@ ai_edge_torch/generative/layers/lora.py,sha256=hsvWLLOnW7HQ0AysOZu30x_cetMquDd1t
159
158
  ai_edge_torch/generative/layers/model_config.py,sha256=nLXvTkDAIHJQ0PTaWODF8oxJQoJ-K8D10cKR9229SAw,8355
160
159
  ai_edge_torch/generative/layers/normalization.py,sha256=MbwH-n80Fob5YvjBzdqDjBizMHLzSJGYRDdbD-rL5C0,6174
161
160
  ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=975zR202MdIrILJ7blceAcxrNqX1ZCN0ECKG1gz-bV8,2655
162
- ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=vp8dVx6tOe99neJhpbrtIt5fvN5NFw19JVH1v0yi5Mg,4154
163
- ai_edge_torch/generative/layers/sdpa_with_kv_update.py,sha256=D4rATT2Ppa9Su7yuRHYnQPJ1dFvUDAyH1GrFnCed7p8,3810
161
+ ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=efqqGRZPJ55hKn1MQJ-cXfrJD85uS1v7W_juyGyts58,5648
162
+ ai_edge_torch/generative/layers/sdpa_with_kv_update.py,sha256=Hn8Zw-jiB9GH2uZ-yaRMcDdpmjECcW4uCy-YNH9zV8c,3693
164
163
  ai_edge_torch/generative/layers/experimental/__init__.py,sha256=nz-K0h8DfiATHzR6s1_bCw2akUmHWffU1bDRSkIzSqI,592
165
164
  ai_edge_torch/generative/layers/experimental/kv_cache.py,sha256=zgpFVftOfllvjh9-UEBSvUbm152SnQETn29rUMMMvAM,2978
166
- ai_edge_torch/generative/layers/experimental/scaled_dot_product_attention.py,sha256=YFcIGOkaNb-vvQKjI-G9-bC2Z1W0O_qRyIZPlsLl72U,2797
167
165
  ai_edge_torch/generative/layers/unet/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
168
166
  ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=ZteHZXK6HKyxYji49DQ46sA9aIy7U3Jnz0HZp6hfevY,28996
169
167
  ai_edge_torch/generative/layers/unet/builder.py,sha256=zAqWXdimmMrQRhmE_t9XkS68mh6PSrzwb-2NZZXrR5I,1901
@@ -180,12 +178,12 @@ ai_edge_torch/generative/test/test_custom_dus.py,sha256=MjIhTvkTko872M35XMciobvI
180
178
  ai_edge_torch/generative/test/test_kv_cache.py,sha256=1sXN2RPntq0PP3IEy0NkvIbzQ0Y8JhPIwRSFwO9JLlE,5728
181
179
  ai_edge_torch/generative/test/test_loader.py,sha256=9mQUeeZKOVApOWSWl2cN9c10axZjMKM1-0Zd823CCS4,3449
182
180
  ai_edge_torch/generative/test/test_lora.py,sha256=6QIM6RLTc2HrodGpp_aS3OxM9Rco2KAzEnYgotkg41M,5310
183
- ai_edge_torch/generative/test/test_model_conversion.py,sha256=jSNJ0Eex6VYCkGn3FXbCOOJ2S3-F_QuwJctu3VycjR4,7200
184
- ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=-v2Vj7Qdd3GyBn4k7BWVgyGzrbcL30Su3nxZYLtwkCs,14787
181
+ ai_edge_torch/generative/test/test_model_conversion.py,sha256=mhNJikLnGVGi9NKmXB8FhnqeDy9gtrvC3yEbrTABZ4Y,6163
182
+ ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=6LkLnFOvlnt7JVVDYKMaZClPRBEvdjq6xnSjIFYNdI8,12554
185
183
  ai_edge_torch/generative/test/test_quantize.py,sha256=bEJMhpQ9bIDUZVBXTW888728FcH-i3SyE4JSZZUgU0A,6071
186
184
  ai_edge_torch/generative/test/utils.py,sha256=tF6aCfAGJnc9dmzCnZCEOuKNVimfWOqscv9og0DDLHU,2656
187
185
  ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
188
- ai_edge_torch/generative/utilities/converter.py,sha256=LtBHjnslhL-uf4sDRoC8JIbbUD73g0QW3FiWsHUdV1g,10631
186
+ ai_edge_torch/generative/utilities/converter.py,sha256=z3CvNJxKzglu1BU_5ri91RUeGHh7urhoWFbk0oq7i2M,10768
189
187
  ai_edge_torch/generative/utilities/export_config.py,sha256=8-795nyd3M34LkGhgW7hwHlJyTc2Oz1iipHK8yBhdFs,1633
190
188
  ai_edge_torch/generative/utilities/loader.py,sha256=7p__m2JryWphGlYOuRxdoT4id4_tWJEVOV7y2X4H-Ak,13737
191
189
  ai_edge_torch/generative/utilities/model_builder.py,sha256=ZYX1TxpFdj573du2QCyHJlFjx4q1m12R74fp4Gwl92A,6343
@@ -227,7 +225,7 @@ ai_edge_torch/odml_torch/lowerings/__init__.py,sha256=uJ-niilt1c-D6QJzLwgvCUf62l
227
225
  ai_edge_torch/odml_torch/lowerings/_basic.py,sha256=fEWjIdEpDIqT1EYLZE13O9A41OuaNdbfBrv3vNxS9gI,11601
228
226
  ai_edge_torch/odml_torch/lowerings/_batch_norm.py,sha256=PaLI0BB6pdBW1VyfW8VTOT_Be-ZcqYdNOsyfzKfq8Cg,2064
229
227
  ai_edge_torch/odml_torch/lowerings/_convolution.py,sha256=Q0aDzyUcZMoSzSbOU-r3LJMgPe6fble0QwdYVIOHHHk,6887
230
- ai_edge_torch/odml_torch/lowerings/_decomp_registry.py,sha256=VhmeGFnB5hrUsALiVWV96JJOqPDrTIWouHjTvLuT5eU,2477
228
+ ai_edge_torch/odml_torch/lowerings/_decomp_registry.py,sha256=ybOdoFE5HIJTkyiYcc73zpyUyUpioVnAca6k0wyJPs4,2572
231
229
  ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py,sha256=tkaDo232HjuZvJHyua0n6tdHecifUuVzclJAGq7PPYs,11428
232
230
  ai_edge_torch/odml_torch/lowerings/_layer_norm.py,sha256=khJIvDVk2s332Nd2Be-5dM6-wp5DGff61HCV5lskHmQ,3011
233
231
  ai_edge_torch/odml_torch/lowerings/_quantized_decomposed.py,sha256=XDZ0zLej_XaQDJnaAAxhNFAd7NfQm5SOVEp_nno_krA,6178
@@ -244,8 +242,8 @@ ai_edge_torch/testing/__init__.py,sha256=_yGgvnBZWb7T3IN3mc4x1sS4vM96HZwM8pwIcPG
244
242
  ai_edge_torch/testing/export.py,sha256=k5mGDGzwc23Z4zaIVDs8CNh-oOt64gsf9MS9NjhbPy4,3293
245
243
  ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
246
244
  ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
247
- ai_edge_torch_nightly-0.5.0.dev20250424.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
248
- ai_edge_torch_nightly-0.5.0.dev20250424.dist-info/METADATA,sha256=Gz8c2qvL6qiK7lrd001P55TXltKdycDvDaAq4d4Y-eQ,2051
249
- ai_edge_torch_nightly-0.5.0.dev20250424.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
250
- ai_edge_torch_nightly-0.5.0.dev20250424.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
251
- ai_edge_torch_nightly-0.5.0.dev20250424.dist-info/RECORD,,
245
+ ai_edge_torch_nightly-0.5.0.dev20250425.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
246
+ ai_edge_torch_nightly-0.5.0.dev20250425.dist-info/METADATA,sha256=owGeoLcv0XFf4tXFatFjXLSisoaRBBwrtyLx3LFq8PM,2051
247
+ ai_edge_torch_nightly-0.5.0.dev20250425.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
248
+ ai_edge_torch_nightly-0.5.0.dev20250425.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
249
+ ai_edge_torch_nightly-0.5.0.dev20250425.dist-info/RECORD,,
@@ -1,129 +0,0 @@
1
- # Copyright 2024 The AI Edge Torch Authors.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
- """Build interpolate composite pass."""
16
-
17
- import functools
18
-
19
- from ai_edge_torch import fx_infra
20
- from ai_edge_torch.hlfb import mark_pattern
21
- from ai_edge_torch.hlfb.mark_pattern import pattern as pattern_module
22
- import torch
23
-
24
- # For torch nightly released after mid June 2024,
25
- # torch.nn.functional.interpolate no longer gets exported into decomposed graph
26
- # but a single aten op:
27
- # torch.ops.aten.upsample_nearest2d.vec/torch.ops.aten.upsample_bilinear2d.vec.
28
- # This would interefere with our pattern matching based composite builder.
29
- # Here we register the now missing decompositions first.
30
- _INTERPOLATE_DECOMPOSITIONS = torch._decomp.get_decompositions([
31
- torch.ops.aten.upsample_bilinear2d.vec,
32
- torch.ops.aten.upsample_nearest2d.vec,
33
- ])
34
-
35
-
36
- @functools.cache
37
- def _get_upsample_bilinear2d_pattern():
38
- pattern = pattern_module.Pattern(
39
- "odml.upsample_bilinear2d",
40
- lambda x: torch.nn.functional.interpolate(
41
- x, scale_factor=2, mode="bilinear", align_corners=False
42
- ),
43
- export_args=(torch.rand(1, 3, 100, 100),),
44
- extra_decomp_table=_INTERPOLATE_DECOMPOSITIONS,
45
- )
46
-
47
- @pattern.register_attr_builder
48
- def attr_builder(pattern, graph_module, internal_match):
49
- output = internal_match.returning_nodes[0]
50
- output_h, output_w = output.meta["val"].shape[-2:]
51
- return {
52
- "size": (int(output_h), int(output_w)),
53
- "align_corners": False,
54
- "is_nchw_op": True,
55
- }
56
-
57
- return pattern
58
-
59
-
60
- @functools.cache
61
- def _get_upsample_bilinear2d_align_corners_pattern():
62
- pattern = pattern_module.Pattern(
63
- "odml.upsample_bilinear2d",
64
- lambda x: torch.nn.functional.interpolate(
65
- x, scale_factor=2, mode="bilinear", align_corners=True
66
- ),
67
- export_args=(torch.rand(1, 3, 100, 100),),
68
- extra_decomp_table=_INTERPOLATE_DECOMPOSITIONS,
69
- )
70
-
71
- @pattern.register_attr_builder
72
- def attr_builder(graph_module, pattern, internal_match):
73
- output = internal_match.returning_nodes[0]
74
- output_h, output_w = output.meta["val"].shape[-2:]
75
- return {
76
- "size": (int(output_h), int(output_w)),
77
- "align_corners": True,
78
- "is_nchw_op": True,
79
- }
80
-
81
- return pattern
82
-
83
-
84
- @functools.cache
85
- def _get_interpolate_nearest2d_pattern():
86
- pattern = pattern_module.Pattern(
87
- "tfl.resize_nearest_neighbor",
88
- lambda x: torch.nn.functional.interpolate(
89
- x, scale_factor=2, mode="nearest"
90
- ),
91
- export_args=(torch.rand(1, 3, 100, 100),),
92
- extra_decomp_table=_INTERPOLATE_DECOMPOSITIONS,
93
- )
94
-
95
- @pattern.register_attr_builder
96
- def attr_builder(pattern, graph_module, internal_match):
97
- output = internal_match.returning_nodes[0]
98
- output_h, output_w = output.meta["val"].shape[-2:]
99
- return {
100
- "size": (int(output_h), int(output_w)),
101
- "is_nchw_op": True,
102
- }
103
-
104
- return pattern
105
-
106
-
107
- class BuildInterpolateCompositePass(fx_infra.ExportedProgramPassBase):
108
-
109
- def __init__(self):
110
- super().__init__()
111
- self._patterns = [
112
- _get_upsample_bilinear2d_pattern(),
113
- _get_upsample_bilinear2d_align_corners_pattern(),
114
- _get_interpolate_nearest2d_pattern(),
115
- ]
116
-
117
- def call(self, exported_program: torch.export.ExportedProgram):
118
- exported_program = fx_infra.safe_run_decompositions(
119
- exported_program,
120
- _INTERPOLATE_DECOMPOSITIONS,
121
- )
122
-
123
- graph_module = exported_program.graph_module
124
- for pattern in self._patterns:
125
- graph_module = mark_pattern.mark_pattern(graph_module, pattern)
126
-
127
- graph_module.graph.lint()
128
- graph_module.recompile()
129
- return fx_infra.ExportedProgramPassResult(exported_program, True)
@@ -1,93 +0,0 @@
1
- # Copyright 2025 The AI Edge Torch Authors.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
- # Implements scaled dot product attention. This is experimental and
16
- # GPU-specific code.
17
-
18
- import math
19
- from typing import Optional
20
-
21
- from ai_edge_torch.generative.custom_ops import bmm_4d as bmm_lib
22
- from ai_edge_torch.generative.layers import kv_cache as kv_utils
23
- from ai_edge_torch.generative.utilities import types
24
- from ai_edge_torch.hlfb import StableHLOCompositeBuilder
25
- from multipledispatch import dispatch
26
- import torch
27
- import torch.nn.functional as F
28
-
29
-
30
- def scaled_dot_product_attention(
31
- kv: kv_utils.KVCacheEntry,
32
- query: torch.Tensor,
33
- key: torch.Tensor,
34
- value: torch.Tensor,
35
- head_size: int,
36
- mask: Optional[torch.Tensor] = None,
37
- scale: Optional[float] = None,
38
- softcap: Optional[float] = None,
39
- ):
40
- if hasattr(kv, "kv_layout"):
41
- return _sdpa(
42
- kv.kv_layout[0](), # key layout
43
- kv.kv_layout[1](), # value layout
44
- query=query,
45
- key=key,
46
- value=value,
47
- head_size=head_size,
48
- mask=mask,
49
- scale=scale,
50
- softcap=softcap,
51
- )
52
- raise ValueError("No kv_layout attribute found in kv.")
53
-
54
-
55
- @dispatch(types.BNTH, types.BNHT)
56
- def _sdpa(k_type, v_type, *args, **kwargs):
57
- query = kwargs["query"]
58
- key = kwargs["key"]
59
- value = kwargs["value"]
60
- head_size = kwargs["head_size"]
61
- mask = kwargs.get("mask", None)
62
- scale = kwargs.get("scale", None)
63
- softcap = kwargs.get("softcap", None)
64
-
65
- if scale is None:
66
- scale = 1.0 / math.sqrt(head_size)
67
-
68
- query = query * scale
69
-
70
- assert mask is not None, "Mask should not be None!"
71
- t = mask.shape[2]
72
-
73
- logits = bmm_lib.bmm_4d(query, key)
74
-
75
- _, bk, gt, s = logits.shape
76
- g = gt // t
77
- logits = logits.reshape((bk, g, t, s))
78
- if softcap is not None:
79
- logits = torch.tanh(logits / softcap)
80
- logits = logits * softcap
81
-
82
- padded_logits = logits + mask
83
- padded_logits = padded_logits.reshape(1, bk, gt, s)
84
- probs = F.softmax(padded_logits, dim=-1).type_as(key)
85
- encoded = bmm_lib.bmm_4d(probs, value)
86
-
87
- return encoded # 1, bk, gt, h
88
-
89
-
90
- @dispatch(object, object)
91
- def _sdpa(k_type, v_type, *args, **kwargs):
92
-
93
- raise ValueError(f"No implementations for k={k_type} and v={v_type}")